Skip to content

Commit df7b0c7

Browse files
authored
Merge pull request #244 from apollographql/GT-333
Valid token fails validation with multiple audiences
2 parents 7dc713b + 76b345c commit df7b0c7

File tree

1 file changed

+120
-1
lines changed

1 file changed

+120
-1
lines changed

crates/apollo-mcp-server/src/auth/valid_token.rs

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,31 @@ pub(super) trait ValidateToken {
4242
#[derive(Clone, Debug, Serialize, Deserialize)]
4343
pub struct Claims {
4444
/// The intended audience of this token.
45-
pub aud: String,
45+
/// Can be either a single string or an array of strings per JWT spec. (https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3)
46+
#[serde(deserialize_with = "deserialize_audience")]
47+
pub aud: Vec<String>,
4648

4749
/// The user who owns this token
4850
pub sub: String,
4951
}
5052

53+
fn deserialize_audience<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
54+
where
55+
D: serde::Deserializer<'de>,
56+
{
57+
#[derive(Deserialize)]
58+
#[serde(untagged)]
59+
enum Audience {
60+
Single(String),
61+
Multiple(Vec<String>),
62+
}
63+
64+
Ok(match Audience::deserialize(deserializer)? {
65+
Audience::Single(s) => vec![s],
66+
Audience::Multiple(v) => v,
67+
})
68+
}
69+
5170
let jwt = token.token();
5271
let header = decode_header(jwt).ok()?;
5372
let key_id = header.kid.as_ref()?;
@@ -327,4 +346,104 @@ mod test {
327346
.ok_or("Expected warning for validation failure".to_string())
328347
});
329348
}
349+
350+
#[tokio::test]
351+
async fn it_validates_jwt_with_array_audience() {
352+
use serde_json::json;
353+
354+
let key_id = "some-example-id".to_string();
355+
let (encode_key, decode_key) = create_key("DEADBEEF");
356+
let jwk = Jwk {
357+
alg: KeyAlgorithm::HS512,
358+
decoding_key: decode_key,
359+
};
360+
361+
let audience = "test-audience".to_string();
362+
let in_the_future = chrono::Utc::now().timestamp() + 1000;
363+
364+
let header = {
365+
let mut h = Header::new(Algorithm::HS512);
366+
h.kid = Some(key_id.clone());
367+
h
368+
};
369+
370+
let claims = json!({
371+
"aud": ["test-audience", "another-audience"],
372+
"exp": in_the_future,
373+
"sub": "test user"
374+
});
375+
376+
let token = encode(&header, &claims, &encode_key).expect("encode JWT");
377+
let jwt = Authorization::bearer(&token).expect("create bearer token");
378+
379+
let server =
380+
Url::from_str("https://auth.example.com").expect("should parse a valid example server");
381+
382+
let test_validator = TestTokenValidator {
383+
audiences: vec![audience],
384+
key_pair: (key_id, jwk),
385+
servers: vec![server],
386+
};
387+
388+
assert_eq!(
389+
test_validator
390+
.validate(jwt)
391+
.await
392+
.expect("valid token")
393+
.0
394+
.token(),
395+
token
396+
);
397+
}
398+
399+
#[traced_test]
400+
#[tokio::test]
401+
async fn it_rejects_array_audience_with_no_matches() {
402+
use serde_json::json;
403+
404+
let key_id = "some-example-id".to_string();
405+
let (encode_key, decode_key) = create_key("DEADBEEF");
406+
let jwk = Jwk {
407+
alg: KeyAlgorithm::HS512,
408+
decoding_key: decode_key,
409+
};
410+
411+
let expected_audience = "expected-audience".to_string();
412+
let in_the_future = chrono::Utc::now().timestamp() + 1000;
413+
414+
let header = {
415+
let mut h = Header::new(Algorithm::HS512);
416+
h.kid = Some(key_id.clone());
417+
h
418+
};
419+
420+
let claims = json!({
421+
"aud": ["wrong-audience-1", "wrong-audience-2"],
422+
"exp": in_the_future,
423+
"sub": "test user"
424+
});
425+
426+
let token = encode(&header, &claims, &encode_key).expect("encode JWT");
427+
let jwt = Authorization::bearer(&token).expect("create bearer token");
428+
429+
let server =
430+
Url::from_str("https://auth.example.com").expect("should parse a valid example server");
431+
432+
let test_validator = TestTokenValidator {
433+
audiences: vec![expected_audience],
434+
key_pair: (key_id, jwk),
435+
servers: vec![server],
436+
};
437+
438+
assert_eq!(test_validator.validate(jwt).await, None);
439+
440+
logs_assert(|lines: &[&str]| {
441+
lines
442+
.iter()
443+
.filter(|line| line.contains("WARN"))
444+
.any(|line| line.contains("InvalidAudience"))
445+
.then_some(())
446+
.ok_or("Expected warning for validation failure".to_string())
447+
});
448+
}
330449
}

0 commit comments

Comments
 (0)