@@ -22,6 +22,7 @@ use oauth2_types::{
22
22
use psl:: Psl ;
23
23
use rand:: distributions:: { Alphanumeric , DistString } ;
24
24
use serde:: Serialize ;
25
+ use sha2:: Digest as _;
25
26
use thiserror:: Error ;
26
27
use tracing:: info;
27
28
use url:: Url ;
@@ -50,6 +51,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
50
51
impl_from_error_for_route ! ( mas_policy:: LoadError ) ;
51
52
impl_from_error_for_route ! ( mas_policy:: EvaluationError ) ;
52
53
impl_from_error_for_route ! ( mas_keystore:: aead:: Error ) ;
54
+ impl_from_error_for_route ! ( serde_json:: Error ) ;
53
55
54
56
impl IntoResponse for RouteError {
55
57
fn into_response ( self ) -> axum:: response:: Response {
@@ -204,7 +206,10 @@ pub(crate) async fn post(
204
206
// Propagate any JSON extraction error
205
207
let Json ( body) = body?;
206
208
207
- info ! ( ?body, "Client registration" ) ;
209
+ // We need to serialize the body to compute the hash, and to log it
210
+ let body_json = serde_json:: to_string ( & body) ?;
211
+
212
+ info ! ( body = body_json, "Client registration" ) ;
208
213
209
214
let user_agent = user_agent. map ( |ua| ua. to_string ( ) ) ;
210
215
@@ -276,34 +281,59 @@ pub(crate) async fn post(
276
281
_ => ( None , None ) ,
277
282
} ;
278
283
279
- let client = repo
280
- . oauth2_client ( )
281
- . add (
282
- & mut rng,
283
- & clock,
284
- metadata. redirect_uris ( ) . to_vec ( ) ,
285
- encrypted_client_secret,
286
- metadata. application_type . clone ( ) ,
287
- //&metadata.response_types(),
288
- metadata. grant_types ( ) . to_vec ( ) ,
289
- metadata
290
- . client_name
291
- . clone ( )
292
- . map ( Localized :: to_non_localized) ,
293
- metadata. logo_uri . clone ( ) . map ( Localized :: to_non_localized) ,
294
- metadata. client_uri . clone ( ) . map ( Localized :: to_non_localized) ,
295
- metadata. policy_uri . clone ( ) . map ( Localized :: to_non_localized) ,
296
- metadata. tos_uri . clone ( ) . map ( Localized :: to_non_localized) ,
297
- metadata. jwks_uri . clone ( ) ,
298
- metadata. jwks . clone ( ) ,
299
- // XXX: those might not be right, should be function calls
300
- metadata. id_token_signed_response_alg . clone ( ) ,
301
- metadata. userinfo_signed_response_alg . clone ( ) ,
302
- metadata. token_endpoint_auth_method . clone ( ) ,
303
- metadata. token_endpoint_auth_signing_alg . clone ( ) ,
304
- metadata. initiate_login_uri . clone ( ) ,
305
- )
306
- . await ?;
284
+ // If the client doesn't have a secret, we may be able to deduplicate it. To
285
+ // do so, we hash the client metadata, and look for it in the database
286
+ let ( digest_hash, existing_client) = if client_secret. is_none ( ) {
287
+ // XXX: One interesting caveat is that we hash *before* saving to the database.
288
+ // It means it takes into account fields that we don't care about *yet*.
289
+ //
290
+ // This means that if later we start supporting a particular field, we
291
+ // will still serve the 'old' client_id, without updating the client in the
292
+ // database
293
+ let hash = sha2:: Sha256 :: digest ( body_json) ;
294
+ let hash = hex:: encode ( hash) ;
295
+ let client = repo. oauth2_client ( ) . find_by_metadata_digest ( & hash) . await ?;
296
+ ( Some ( hash) , client)
297
+ } else {
298
+ ( None , None )
299
+ } ;
300
+
301
+ let client = if let Some ( client) = existing_client {
302
+ tracing:: info!( %client. id, "Reusing existing client" ) ;
303
+ client
304
+ } else {
305
+ let client = repo
306
+ . oauth2_client ( )
307
+ . add (
308
+ & mut rng,
309
+ & clock,
310
+ metadata. redirect_uris ( ) . to_vec ( ) ,
311
+ digest_hash,
312
+ encrypted_client_secret,
313
+ metadata. application_type . clone ( ) ,
314
+ //&metadata.response_types(),
315
+ metadata. grant_types ( ) . to_vec ( ) ,
316
+ metadata
317
+ . client_name
318
+ . clone ( )
319
+ . map ( Localized :: to_non_localized) ,
320
+ metadata. logo_uri . clone ( ) . map ( Localized :: to_non_localized) ,
321
+ metadata. client_uri . clone ( ) . map ( Localized :: to_non_localized) ,
322
+ metadata. policy_uri . clone ( ) . map ( Localized :: to_non_localized) ,
323
+ metadata. tos_uri . clone ( ) . map ( Localized :: to_non_localized) ,
324
+ metadata. jwks_uri . clone ( ) ,
325
+ metadata. jwks . clone ( ) ,
326
+ // XXX: those might not be right, should be function calls
327
+ metadata. id_token_signed_response_alg . clone ( ) ,
328
+ metadata. userinfo_signed_response_alg . clone ( ) ,
329
+ metadata. token_endpoint_auth_method . clone ( ) ,
330
+ metadata. token_endpoint_auth_signing_alg . clone ( ) ,
331
+ metadata. initiate_login_uri . clone ( ) ,
332
+ )
333
+ . await ?;
334
+ tracing:: info!( %client. id, "Registered new client" ) ;
335
+ client
336
+ } ;
307
337
308
338
let response = ClientRegistrationResponse {
309
339
client_id : client. client_id . clone ( ) ,
@@ -490,4 +520,51 @@ mod tests {
490
520
let response: ClientRegistrationResponse = response. json ( ) ;
491
521
assert ! ( response. client_secret. is_some( ) ) ;
492
522
}
523
+ #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
524
+ async fn test_registration_dedupe ( pool : PgPool ) {
525
+ setup ( ) ;
526
+ let state = TestState :: from_pool ( pool) . await . unwrap ( ) ;
527
+
528
+ // Post a client registration twice, we should get the same client ID
529
+ let request =
530
+ Request :: post ( mas_router:: OAuth2RegistrationEndpoint :: PATH ) . json ( serde_json:: json!( {
531
+ "client_uri" : "https://example.com/" ,
532
+ "redirect_uris" : [ "https://example.com/" ] ,
533
+ "response_types" : [ "code" ] ,
534
+ "grant_types" : [ "authorization_code" ] ,
535
+ "token_endpoint_auth_method" : "none" ,
536
+ } ) ) ;
537
+
538
+ let response = state. request ( request. clone ( ) ) . await ;
539
+ response. assert_status ( StatusCode :: CREATED ) ;
540
+ let response: ClientRegistrationResponse = response. json ( ) ;
541
+ let client_id = response. client_id ;
542
+
543
+ let response = state. request ( request) . await ;
544
+ response. assert_status ( StatusCode :: CREATED ) ;
545
+ let response: ClientRegistrationResponse = response. json ( ) ;
546
+ assert_eq ! ( response. client_id, client_id) ;
547
+
548
+ // Doing that with a client that has a client_secret should not deduplicate
549
+ let request =
550
+ Request :: post ( mas_router:: OAuth2RegistrationEndpoint :: PATH ) . json ( serde_json:: json!( {
551
+ "client_uri" : "https://example.com/" ,
552
+ "redirect_uris" : [ "https://example.com/" ] ,
553
+ "response_types" : [ "code" ] ,
554
+ "grant_types" : [ "authorization_code" ] ,
555
+ "token_endpoint_auth_method" : "client_secret_basic" ,
556
+ } ) ) ;
557
+
558
+ let response = state. request ( request. clone ( ) ) . await ;
559
+ response. assert_status ( StatusCode :: CREATED ) ;
560
+ let response: ClientRegistrationResponse = response. json ( ) ;
561
+ // Sanity check that the client_id is different
562
+ assert_ne ! ( response. client_id, client_id) ;
563
+ let client_id = response. client_id ;
564
+
565
+ let response = state. request ( request) . await ;
566
+ response. assert_status ( StatusCode :: CREATED ) ;
567
+ let response: ClientRegistrationResponse = response. json ( ) ;
568
+ assert_ne ! ( response. client_id, client_id) ;
569
+ }
493
570
}
0 commit comments