@@ -42,12 +42,31 @@ pub(super) trait ValidateToken {
4242 #[ derive( Clone , Debug , Serialize , Deserialize ) ]
4343 pub struct Claims {
4444 /// The intended audience of this token.
45- pub aud : String ,
45+ /// Can be either a single string or an array of strings per JWT spec. (https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3)
46+ #[ serde( deserialize_with = "deserialize_audience" ) ]
47+ pub aud : Vec < String > ,
4648
4749 /// The user who owns this token
4850 pub sub : String ,
4951 }
5052
53+ fn deserialize_audience < ' de , D > ( deserializer : D ) -> Result < Vec < String > , D :: Error >
54+ where
55+ D : serde:: Deserializer < ' de > ,
56+ {
57+ #[ derive( Deserialize ) ]
58+ #[ serde( untagged) ]
59+ enum Audience {
60+ Single ( String ) ,
61+ Multiple ( Vec < String > ) ,
62+ }
63+
64+ Ok ( match Audience :: deserialize ( deserializer) ? {
65+ Audience :: Single ( s) => vec ! [ s] ,
66+ Audience :: Multiple ( v) => v,
67+ } )
68+ }
69+
5170 let jwt = token. token ( ) ;
5271 let header = decode_header ( jwt) . ok ( ) ?;
5372 let key_id = header. kid . as_ref ( ) ?;
@@ -327,4 +346,104 @@ mod test {
327346 . ok_or ( "Expected warning for validation failure" . to_string ( ) )
328347 } ) ;
329348 }
349+
350+ #[ tokio:: test]
351+ async fn it_validates_jwt_with_array_audience ( ) {
352+ use serde_json:: json;
353+
354+ let key_id = "some-example-id" . to_string ( ) ;
355+ let ( encode_key, decode_key) = create_key ( "DEADBEEF" ) ;
356+ let jwk = Jwk {
357+ alg : KeyAlgorithm :: HS512 ,
358+ decoding_key : decode_key,
359+ } ;
360+
361+ let audience = "test-audience" . to_string ( ) ;
362+ let in_the_future = chrono:: Utc :: now ( ) . timestamp ( ) + 1000 ;
363+
364+ let header = {
365+ let mut h = Header :: new ( Algorithm :: HS512 ) ;
366+ h. kid = Some ( key_id. clone ( ) ) ;
367+ h
368+ } ;
369+
370+ let claims = json ! ( {
371+ "aud" : [ "test-audience" , "another-audience" ] ,
372+ "exp" : in_the_future,
373+ "sub" : "test user"
374+ } ) ;
375+
376+ let token = encode ( & header, & claims, & encode_key) . expect ( "encode JWT" ) ;
377+ let jwt = Authorization :: bearer ( & token) . expect ( "create bearer token" ) ;
378+
379+ let server =
380+ Url :: from_str ( "https://auth.example.com" ) . expect ( "should parse a valid example server" ) ;
381+
382+ let test_validator = TestTokenValidator {
383+ audiences : vec ! [ audience] ,
384+ key_pair : ( key_id, jwk) ,
385+ servers : vec ! [ server] ,
386+ } ;
387+
388+ assert_eq ! (
389+ test_validator
390+ . validate( jwt)
391+ . await
392+ . expect( "valid token" )
393+ . 0
394+ . token( ) ,
395+ token
396+ ) ;
397+ }
398+
399+ #[ traced_test]
400+ #[ tokio:: test]
401+ async fn it_rejects_array_audience_with_no_matches ( ) {
402+ use serde_json:: json;
403+
404+ let key_id = "some-example-id" . to_string ( ) ;
405+ let ( encode_key, decode_key) = create_key ( "DEADBEEF" ) ;
406+ let jwk = Jwk {
407+ alg : KeyAlgorithm :: HS512 ,
408+ decoding_key : decode_key,
409+ } ;
410+
411+ let expected_audience = "expected-audience" . to_string ( ) ;
412+ let in_the_future = chrono:: Utc :: now ( ) . timestamp ( ) + 1000 ;
413+
414+ let header = {
415+ let mut h = Header :: new ( Algorithm :: HS512 ) ;
416+ h. kid = Some ( key_id. clone ( ) ) ;
417+ h
418+ } ;
419+
420+ let claims = json ! ( {
421+ "aud" : [ "wrong-audience-1" , "wrong-audience-2" ] ,
422+ "exp" : in_the_future,
423+ "sub" : "test user"
424+ } ) ;
425+
426+ let token = encode ( & header, & claims, & encode_key) . expect ( "encode JWT" ) ;
427+ let jwt = Authorization :: bearer ( & token) . expect ( "create bearer token" ) ;
428+
429+ let server =
430+ Url :: from_str ( "https://auth.example.com" ) . expect ( "should parse a valid example server" ) ;
431+
432+ let test_validator = TestTokenValidator {
433+ audiences : vec ! [ expected_audience] ,
434+ key_pair : ( key_id, jwk) ,
435+ servers : vec ! [ server] ,
436+ } ;
437+
438+ assert_eq ! ( test_validator. validate( jwt) . await , None ) ;
439+
440+ logs_assert ( |lines : & [ & str ] | {
441+ lines
442+ . iter ( )
443+ . filter ( |line| line. contains ( "WARN" ) )
444+ . any ( |line| line. contains ( "InvalidAudience" ) )
445+ . then_some ( ( ) )
446+ . ok_or ( "Expected warning for validation failure" . to_string ( ) )
447+ } ) ;
448+ }
330449}
0 commit comments