@@ -22,6 +22,7 @@ use oauth2_types::{
22
22
use psl:: Psl ;
23
23
use rand:: distributions:: { Alphanumeric , DistString } ;
24
24
use serde:: Serialize ;
25
+ use sha2:: Digest as _;
25
26
use thiserror:: Error ;
26
27
use tracing:: info;
27
28
use url:: Url ;
@@ -50,6 +51,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
50
51
impl_from_error_for_route ! ( mas_policy:: LoadError ) ;
51
52
impl_from_error_for_route ! ( mas_policy:: EvaluationError ) ;
52
53
impl_from_error_for_route ! ( mas_keystore:: aead:: Error ) ;
54
+ impl_from_error_for_route ! ( serde_json:: Error ) ;
53
55
54
56
impl IntoResponse for RouteError {
55
57
fn into_response ( self ) -> axum:: response:: Response {
@@ -204,7 +206,13 @@ pub(crate) async fn post(
204
206
// Propagate any JSON extraction error
205
207
let Json ( body) = body?;
206
208
207
- info ! ( ?body, "Client registration" ) ;
209
+ // Sort the properties to ensure a stable serialisation order for hashing
210
+ let body = body. sorted ( ) ;
211
+
212
+ // We need to serialize the body to compute the hash, and to log it
213
+ let body_json = serde_json:: to_string ( & body) ?;
214
+
215
+ info ! ( body = body_json, "Client registration" ) ;
208
216
209
217
let user_agent = user_agent. map ( |ua| ua. to_string ( ) ) ;
210
218
@@ -276,34 +284,59 @@ pub(crate) async fn post(
276
284
_ => ( None , None ) ,
277
285
} ;
278
286
279
- let client = repo
280
- . oauth2_client ( )
281
- . add (
282
- & mut rng,
283
- & clock,
284
- metadata. redirect_uris ( ) . to_vec ( ) ,
285
- encrypted_client_secret,
286
- metadata. application_type . clone ( ) ,
287
- //&metadata.response_types(),
288
- metadata. grant_types ( ) . to_vec ( ) ,
289
- metadata
290
- . client_name
291
- . clone ( )
292
- . map ( Localized :: to_non_localized) ,
293
- metadata. logo_uri . clone ( ) . map ( Localized :: to_non_localized) ,
294
- metadata. client_uri . clone ( ) . map ( Localized :: to_non_localized) ,
295
- metadata. policy_uri . clone ( ) . map ( Localized :: to_non_localized) ,
296
- metadata. tos_uri . clone ( ) . map ( Localized :: to_non_localized) ,
297
- metadata. jwks_uri . clone ( ) ,
298
- metadata. jwks . clone ( ) ,
299
- // XXX: those might not be right, should be function calls
300
- metadata. id_token_signed_response_alg . clone ( ) ,
301
- metadata. userinfo_signed_response_alg . clone ( ) ,
302
- metadata. token_endpoint_auth_method . clone ( ) ,
303
- metadata. token_endpoint_auth_signing_alg . clone ( ) ,
304
- metadata. initiate_login_uri . clone ( ) ,
305
- )
306
- . await ?;
287
+ // If the client doesn't have a secret, we may be able to deduplicate it. To
288
+ // do so, we hash the client metadata, and look for it in the database
289
+ let ( digest_hash, existing_client) = if client_secret. is_none ( ) {
290
+ // XXX: One interesting caveat is that we hash *before* saving to the database.
291
+ // It means it takes into account fields that we don't care about *yet*.
292
+ //
293
+ // This means that if later we start supporting a particular field, we
294
+ // will still serve the 'old' client_id, without updating the client in the
295
+ // database
296
+ let hash = sha2:: Sha256 :: digest ( body_json) ;
297
+ let hash = hex:: encode ( hash) ;
298
+ let client = repo. oauth2_client ( ) . find_by_metadata_digest ( & hash) . await ?;
299
+ ( Some ( hash) , client)
300
+ } else {
301
+ ( None , None )
302
+ } ;
303
+
304
+ let client = if let Some ( client) = existing_client {
305
+ tracing:: info!( %client. id, "Reusing existing client" ) ;
306
+ client
307
+ } else {
308
+ let client = repo
309
+ . oauth2_client ( )
310
+ . add (
311
+ & mut rng,
312
+ & clock,
313
+ metadata. redirect_uris ( ) . to_vec ( ) ,
314
+ digest_hash,
315
+ encrypted_client_secret,
316
+ metadata. application_type . clone ( ) ,
317
+ //&metadata.response_types(),
318
+ metadata. grant_types ( ) . to_vec ( ) ,
319
+ metadata
320
+ . client_name
321
+ . clone ( )
322
+ . map ( Localized :: to_non_localized) ,
323
+ metadata. logo_uri . clone ( ) . map ( Localized :: to_non_localized) ,
324
+ metadata. client_uri . clone ( ) . map ( Localized :: to_non_localized) ,
325
+ metadata. policy_uri . clone ( ) . map ( Localized :: to_non_localized) ,
326
+ metadata. tos_uri . clone ( ) . map ( Localized :: to_non_localized) ,
327
+ metadata. jwks_uri . clone ( ) ,
328
+ metadata. jwks . clone ( ) ,
329
+ // XXX: those might not be right, should be function calls
330
+ metadata. id_token_signed_response_alg . clone ( ) ,
331
+ metadata. userinfo_signed_response_alg . clone ( ) ,
332
+ metadata. token_endpoint_auth_method . clone ( ) ,
333
+ metadata. token_endpoint_auth_signing_alg . clone ( ) ,
334
+ metadata. initiate_login_uri . clone ( ) ,
335
+ )
336
+ . await ?;
337
+ tracing:: info!( %client. id, "Registered new client" ) ;
338
+ client
339
+ } ;
307
340
308
341
let response = ClientRegistrationResponse {
309
342
client_id : client. client_id . clone ( ) ,
@@ -490,4 +523,74 @@ mod tests {
490
523
let response: ClientRegistrationResponse = response. json ( ) ;
491
524
assert ! ( response. client_secret. is_some( ) ) ;
492
525
}
526
+ #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
527
+ async fn test_registration_dedupe ( pool : PgPool ) {
528
+ setup ( ) ;
529
+ let state = TestState :: from_pool ( pool) . await . unwrap ( ) ;
530
+
531
+ // Post a client registration twice, we should get the same client ID
532
+ let request =
533
+ Request :: post ( mas_router:: OAuth2RegistrationEndpoint :: PATH ) . json ( serde_json:: json!( {
534
+ "client_uri" : "https://example.com/" ,
535
+ "client_name" : "Example" ,
536
+ "client_name#en" : "Example" ,
537
+ "client_name#fr" : "Exemple" ,
538
+ "client_name#de" : "Beispiel" ,
539
+ "redirect_uris" : [ "https://example.com/" , "https://example.com/callback" ] ,
540
+ "response_types" : [ "code" ] ,
541
+ "grant_types" : [ "authorization_code" , "urn:ietf:params:oauth:grant-type:device_code" ] ,
542
+ "token_endpoint_auth_method" : "none" ,
543
+ } ) ) ;
544
+
545
+ let response = state. request ( request. clone ( ) ) . await ;
546
+ response. assert_status ( StatusCode :: CREATED ) ;
547
+ let response: ClientRegistrationResponse = response. json ( ) ;
548
+ let client_id = response. client_id ;
549
+
550
+ let response = state. request ( request) . await ;
551
+ response. assert_status ( StatusCode :: CREATED ) ;
552
+ let response: ClientRegistrationResponse = response. json ( ) ;
553
+ assert_eq ! ( response. client_id, client_id) ;
554
+
555
+ // Check that the order of some properties doesn't matter
556
+ let request =
557
+ Request :: post ( mas_router:: OAuth2RegistrationEndpoint :: PATH ) . json ( serde_json:: json!( {
558
+ "client_uri" : "https://example.com/" ,
559
+ "client_name" : "Example" ,
560
+ "client_name#de" : "Beispiel" ,
561
+ "client_name#fr" : "Exemple" ,
562
+ "client_name#en" : "Example" ,
563
+ "redirect_uris" : [ "https://example.com/callback" , "https://example.com/" ] ,
564
+ "response_types" : [ "code" ] ,
565
+ "grant_types" : [ "urn:ietf:params:oauth:grant-type:device_code" , "authorization_code" ] ,
566
+ "token_endpoint_auth_method" : "none" ,
567
+ } ) ) ;
568
+
569
+ let response = state. request ( request) . await ;
570
+ response. assert_status ( StatusCode :: CREATED ) ;
571
+ let response: ClientRegistrationResponse = response. json ( ) ;
572
+ assert_eq ! ( response. client_id, client_id) ;
573
+
574
+ // Doing that with a client that has a client_secret should not deduplicate
575
+ let request =
576
+ Request :: post ( mas_router:: OAuth2RegistrationEndpoint :: PATH ) . json ( serde_json:: json!( {
577
+ "client_uri" : "https://example.com/" ,
578
+ "redirect_uris" : [ "https://example.com/" ] ,
579
+ "response_types" : [ "code" ] ,
580
+ "grant_types" : [ "authorization_code" ] ,
581
+ "token_endpoint_auth_method" : "client_secret_basic" ,
582
+ } ) ) ;
583
+
584
+ let response = state. request ( request. clone ( ) ) . await ;
585
+ response. assert_status ( StatusCode :: CREATED ) ;
586
+ let response: ClientRegistrationResponse = response. json ( ) ;
587
+ // Sanity check that the client_id is different
588
+ assert_ne ! ( response. client_id, client_id) ;
589
+ let client_id = response. client_id ;
590
+
591
+ let response = state. request ( request) . await ;
592
+ response. assert_status ( StatusCode :: CREATED ) ;
593
+ let response: ClientRegistrationResponse = response. json ( ) ;
594
+ assert_ne ! ( response. client_id, client_id) ;
595
+ }
493
596
}
0 commit comments