Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ version = "0.8.0"
[workspace.dependencies.headers]
version = "0.4.0"

# Hex encoding and decoding
[workspace.dependencies.hex]
version = "0.4.3"

# HTTP request/response
[workspace.dependencies.http]
version = "1.3.1"
Expand Down Expand Up @@ -184,6 +188,11 @@ version = "0.27.5"
features = ["http1", "http2"]
default-features = false

# HashMap which preserves insertion order
[workspace.dependencies.indexmap]
version = "2.8.0"
features = ["serde"]

# Snapshot testing
[workspace.dependencies.insta]
version = "1.42.2"
Expand Down Expand Up @@ -278,6 +287,11 @@ version = "0.5.1"
version = "0.8.22"
features = ["url", "chrono", "preserve_order"]

# SHA-2 cryptographic hash algorithm
[workspace.dependencies.sha2]
version = "0.10.8"
features = ["oid"]

# Query builder
[workspace.dependencies.sea-query]
version = "0.32.3"
Expand Down
5 changes: 5 additions & 0 deletions crates/data-model/src/oauth2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pub struct Client {
/// Client identifier
pub client_id: String,

/// Hash of the client metadata
pub metadata_digest: Option<String>,

pub encrypted_client_secret: Option<String>,

pub application_type: Option<ApplicationType>,
Expand Down Expand Up @@ -177,6 +180,7 @@ impl Client {
Self {
id: Ulid::from_datetime_with_source(now.into(), rng),
client_id: "client1".to_owned(),
metadata_digest: None,
encrypted_client_secret: None,
application_type: Some(ApplicationType::Web),
redirect_uris: vec![
Expand All @@ -202,6 +206,7 @@ impl Client {
Self {
id: Ulid::from_datetime_with_source(now.into(), rng),
client_id: "client2".to_owned(),
metadata_digest: None,
encrypted_client_secret: None,
application_type: Some(ApplicationType::Native),
redirect_uris: vec![Url::parse("https://client2.example.com/redirect").unwrap()],
Expand Down
4 changes: 3 additions & 1 deletion crates/handlers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ base64ct.workspace = true
camino.workspace = true
chrono.workspace = true
elliptic-curve.workspace = true
hex.workspace = true
governor.workspace = true
indexmap = "2.8.0"
indexmap.workspace = true
pkcs8.workspace = true
psl = "2.1.96"
sha2.workspace = true
time = "0.3.41"
url.workspace = true
mime = "0.3.17"
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/graphql/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async fn create_test_client(state: &TestState) -> Client {
vec![],
None,
None,
None,
vec![],
None,
None,
Expand Down
135 changes: 106 additions & 29 deletions crates/handlers/src/oauth2/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use oauth2_types::{
use psl::Psl;
use rand::distributions::{Alphanumeric, DistString};
use serde::Serialize;
use sha2::Digest as _;
use thiserror::Error;
use tracing::info;
use url::Url;
Expand Down Expand Up @@ -50,6 +51,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
Expand Down Expand Up @@ -204,7 +206,10 @@ pub(crate) async fn post(
// Propagate any JSON extraction error
let Json(body) = body?;

info!(?body, "Client registration");
// We need to serialize the body to compute the hash, and to log it
let body_json = serde_json::to_string(&body)?;

info!(body = body_json, "Client registration");

let user_agent = user_agent.map(|ua| ua.to_string());

Expand Down Expand Up @@ -276,34 +281,59 @@ pub(crate) async fn post(
_ => (None, None),
};

let client = repo
.oauth2_client()
.add(
&mut rng,
&clock,
metadata.redirect_uris().to_vec(),
encrypted_client_secret,
metadata.application_type.clone(),
//&metadata.response_types(),
metadata.grant_types().to_vec(),
metadata
.client_name
.clone()
.map(Localized::to_non_localized),
metadata.logo_uri.clone().map(Localized::to_non_localized),
metadata.client_uri.clone().map(Localized::to_non_localized),
metadata.policy_uri.clone().map(Localized::to_non_localized),
metadata.tos_uri.clone().map(Localized::to_non_localized),
metadata.jwks_uri.clone(),
metadata.jwks.clone(),
// XXX: those might not be right, should be function calls
metadata.id_token_signed_response_alg.clone(),
metadata.userinfo_signed_response_alg.clone(),
metadata.token_endpoint_auth_method.clone(),
metadata.token_endpoint_auth_signing_alg.clone(),
metadata.initiate_login_uri.clone(),
)
.await?;
// If the client doesn't have a secret, we may be able to deduplicate it. To
// do so, we hash the client metadata, and look for it in the database
let (digest_hash, existing_client) = if client_secret.is_none() {
// XXX: One interesting caveat is that we hash *before* saving to the database.
// It means it takes into account fields that we don't care about *yet*.
//
// This means that if later we start supporting a particular field, we
// will still serve the 'old' client_id, without updating the client in the
// database
let hash = sha2::Sha256::digest(body_json);
let hash = hex::encode(hash);
let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
(Some(hash), client)
} else {
(None, None)
};

let client = if let Some(client) = existing_client {
tracing::info!(%client.id, "Reusing existing client");
client
} else {
let client = repo
.oauth2_client()
.add(
&mut rng,
&clock,
metadata.redirect_uris().to_vec(),
digest_hash,
encrypted_client_secret,
metadata.application_type.clone(),
//&metadata.response_types(),
metadata.grant_types().to_vec(),
metadata
.client_name
.clone()
.map(Localized::to_non_localized),
metadata.logo_uri.clone().map(Localized::to_non_localized),
metadata.client_uri.clone().map(Localized::to_non_localized),
metadata.policy_uri.clone().map(Localized::to_non_localized),
metadata.tos_uri.clone().map(Localized::to_non_localized),
metadata.jwks_uri.clone(),
metadata.jwks.clone(),
// XXX: those might not be right, should be function calls
metadata.id_token_signed_response_alg.clone(),
metadata.userinfo_signed_response_alg.clone(),
metadata.token_endpoint_auth_method.clone(),
metadata.token_endpoint_auth_signing_alg.clone(),
metadata.initiate_login_uri.clone(),
)
.await?;
tracing::info!(%client.id, "Registered new client");
client
};

let response = ClientRegistrationResponse {
client_id: client.client_id.clone(),
Expand Down Expand Up @@ -490,4 +520,51 @@ mod tests {
let response: ClientRegistrationResponse = response.json();
assert!(response.client_secret.is_some());
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_registration_dedupe(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

// Post a client registration twice, we should get the same client ID
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "none",
}));

let response = state.request(request.clone()).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
let client_id = response.client_id;

let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
assert_eq!(response.client_id, client_id);

// Doing that with a client that has a client_secret should not deduplicate
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "client_secret_basic",
}));

let response = state.request(request.clone()).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
// Sanity check that the client_id is different
assert_ne!(response.client_id, client_id);
let client_id = response.client_id;

let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
assert_ne!(response.client_id, client_id);
}
}
2 changes: 1 addition & 1 deletion crates/jose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ sec1 = "0.7.3"
serde.workspace = true
serde_json.workspace = true
serde_with = "3.12.0"
sha2 = { version = "0.10.8", features = ["oid"] }
sha2.workspace = true
signature = "2.2.0"
thiserror.workspace = true
url.workspace = true
Expand Down
4 changes: 3 additions & 1 deletion crates/oauth2-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ language-tags = { version = "0.3.2", features = ["serde"] }
url.workspace = true
serde_with = { version = "3.12.0", features = ["chrono"] }
chrono.workspace = true
sha2 = "0.10.8"
sha2.workspace = true
thiserror.workspace = true
indexmap.workspace = true

mas-iana.workspace = true
mas-jose.workspace = true

[dev-dependencies]
assert_matches = "1.5.0"
insta.workspace = true
Loading
Loading