@@ -42,12 +42,53 @@ pub(super) trait ValidateToken {
4242 #[ derive( Clone , Debug , Serialize , Deserialize ) ]
4343 pub struct Claims {
4444 /// The intended audience of this token.
45- pub aud : String ,
45+ /// Can be either a single string or an array of strings per JWT spec. (https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3)
46+ #[ serde( deserialize_with = "deserialize_audience" ) ]
47+ pub aud : Vec < String > ,
4648
4749 /// The user who owns this token
4850 pub sub : String ,
4951 }
5052
53+ /// Custom deserializer to handle both single string and array of strings for audience claim
54+ fn deserialize_audience < ' de , D > ( deserializer : D ) -> Result < Vec < String > , D :: Error >
55+ where
56+ D : serde:: Deserializer < ' de > ,
57+ {
58+ use serde:: de:: { SeqAccess , Visitor } ;
59+ use std:: fmt;
60+
61+ struct AudienceVisitor ;
62+
63+ impl < ' de > Visitor < ' de > for AudienceVisitor {
64+ type Value = Vec < String > ;
65+
66+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
67+ formatter. write_str ( "a string or array of strings" )
68+ }
69+
70+ fn visit_str < E > ( self , value : & str ) -> Result < Self :: Value , E >
71+ where
72+ E : serde:: de:: Error ,
73+ {
74+ Ok ( vec ! [ value. to_string( ) ] )
75+ }
76+
77+ fn visit_seq < A > ( self , mut seq : A ) -> Result < Self :: Value , A :: Error >
78+ where
79+ A : SeqAccess < ' de > ,
80+ {
81+ let mut audiences = Vec :: new ( ) ;
82+ while let Some ( aud) = seq. next_element ( ) ? {
83+ audiences. push ( aud) ;
84+ }
85+ Ok ( audiences)
86+ }
87+ }
88+
89+ deserializer. deserialize_any ( AudienceVisitor )
90+ }
91+
5192 let jwt = token. token ( ) ;
5293 let header = decode_header ( jwt) . ok ( ) ?;
5394 let key_id = header. kid . as_ref ( ) ?;
@@ -327,4 +368,104 @@ mod test {
327368 . ok_or ( "Expected warning for validation failure" . to_string ( ) )
328369 } ) ;
329370 }
371+
372+ #[ tokio:: test]
373+ async fn it_validates_jwt_with_array_audience ( ) {
374+ use serde_json:: json;
375+
376+ let key_id = "some-example-id" . to_string ( ) ;
377+ let ( encode_key, decode_key) = create_key ( "DEADBEEF" ) ;
378+ let jwk = Jwk {
379+ alg : KeyAlgorithm :: HS512 ,
380+ decoding_key : decode_key,
381+ } ;
382+
383+ let audience = "test-audience" . to_string ( ) ;
384+ let in_the_future = chrono:: Utc :: now ( ) . timestamp ( ) + 1000 ;
385+
386+ let header = {
387+ let mut h = Header :: new ( Algorithm :: HS512 ) ;
388+ h. kid = Some ( key_id. clone ( ) ) ;
389+ h
390+ } ;
391+
392+ let claims = json ! ( {
393+ "aud" : [ "test-audience" , "another-audience" ] ,
394+ "exp" : in_the_future,
395+ "sub" : "test user"
396+ } ) ;
397+
398+ let token = encode ( & header, & claims, & encode_key) . expect ( "encode JWT" ) ;
399+ let jwt = Authorization :: bearer ( & token) . expect ( "create bearer token" ) ;
400+
401+ let server =
402+ Url :: from_str ( "https://auth.example.com" ) . expect ( "should parse a valid example server" ) ;
403+
404+ let test_validator = TestTokenValidator {
405+ audiences : vec ! [ audience] ,
406+ key_pair : ( key_id, jwk) ,
407+ servers : vec ! [ server] ,
408+ } ;
409+
410+ assert_eq ! (
411+ test_validator
412+ . validate( jwt)
413+ . await
414+ . expect( "valid token" )
415+ . 0
416+ . token( ) ,
417+ token
418+ ) ;
419+ }
420+
421+ #[ traced_test]
422+ #[ tokio:: test]
423+ async fn it_rejects_array_audience_with_no_matches ( ) {
424+ use serde_json:: json;
425+
426+ let key_id = "some-example-id" . to_string ( ) ;
427+ let ( encode_key, decode_key) = create_key ( "DEADBEEF" ) ;
428+ let jwk = Jwk {
429+ alg : KeyAlgorithm :: HS512 ,
430+ decoding_key : decode_key,
431+ } ;
432+
433+ let expected_audience = "expected-audience" . to_string ( ) ;
434+ let in_the_future = chrono:: Utc :: now ( ) . timestamp ( ) + 1000 ;
435+
436+ let header = {
437+ let mut h = Header :: new ( Algorithm :: HS512 ) ;
438+ h. kid = Some ( key_id. clone ( ) ) ;
439+ h
440+ } ;
441+
442+ let claims = json ! ( {
443+ "aud" : [ "wrong-audience-1" , "wrong-audience-2" ] ,
444+ "exp" : in_the_future,
445+ "sub" : "test user"
446+ } ) ;
447+
448+ let token = encode ( & header, & claims, & encode_key) . expect ( "encode JWT" ) ;
449+ let jwt = Authorization :: bearer ( & token) . expect ( "create bearer token" ) ;
450+
451+ let server =
452+ Url :: from_str ( "https://auth.example.com" ) . expect ( "should parse a valid example server" ) ;
453+
454+ let test_validator = TestTokenValidator {
455+ audiences : vec ! [ expected_audience] ,
456+ key_pair : ( key_id, jwk) ,
457+ servers : vec ! [ server] ,
458+ } ;
459+
460+ assert_eq ! ( test_validator. validate( jwt) . await , None ) ;
461+
462+ logs_assert ( |lines : & [ & str ] | {
463+ lines
464+ . iter ( )
465+ . filter ( |line| line. contains ( "WARN" ) )
466+ . any ( |line| line. contains ( "InvalidAudience" ) )
467+ . then_some ( ( ) )
468+ . ok_or ( "Expected warning for validation failure" . to_string ( ) )
469+ } ) ;
470+ }
330471}
0 commit comments