Skip to content

Commit 6584d85

Browse files
swcollardDaleSeo
authored andcommitted
Trace Auth with status_code and reason
Changeset Unit tests for auth.rs Add raw_operation as a telemetry attribute Unit test starting a streamable http server Rename self to operation_source
1 parent dda6e1e commit 6584d85

File tree

7 files changed

+133
-11
lines changed

7 files changed

+133
-11
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
### Telemetry: Trace operations and auth - @swcollard PR #375
2+
3+
* Adds traces for the MCP server generating Tools from Operations and performing authorization
4+
* Includes the HTTP status code to the top level HTTP trace

crates/apollo-mcp-server/src/auth.rs

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ impl Config {
8484
}
8585

8686
/// Validate that requests made have a corresponding bearer JWT token
87+
#[tracing::instrument(skip_all, fields(status_code, reason))]
8788
async fn oauth_validate(
8889
State(auth_config): State<Config>,
8990
token: Option<TypedHeader<Authorization<Bearer>>>,
@@ -104,17 +105,85 @@ async fn oauth_validate(
104105
};
105106

106107
let validator = NetworkedTokenValidator::new(&auth_config.audiences, &auth_config.servers);
107-
let token = token.ok_or_else(unauthorized_error)?;
108-
109-
let valid_token = validator
110-
.validate(token.0)
111-
.await
112-
.ok_or_else(unauthorized_error)?;
108+
let token = token.ok_or_else(|| {
109+
tracing::Span::current().record("reason", "missing_token");
110+
tracing::Span::current().record("status_code", StatusCode::UNAUTHORIZED.as_u16());
111+
unauthorized_error()
112+
})?;
113+
114+
let valid_token = validator.validate(token.0).await.ok_or_else(|| {
115+
tracing::Span::current().record("reason", "invalid_token");
116+
tracing::Span::current().record("status_code", StatusCode::UNAUTHORIZED.as_u16());
117+
unauthorized_error()
118+
})?;
113119

114120
// Insert new context to ensure that handlers only use our enforced token verification
115121
// for propagation
116122
request.extensions_mut().insert(valid_token);
117123

118124
let response = next.run(request).await;
125+
tracing::Span::current().record("status_code", response.status().as_u16());
119126
Ok(response)
120127
}
128+
129+
#[cfg(test)]
130+
mod tests {
131+
use super::*;
132+
use axum::middleware::from_fn_with_state;
133+
use axum::routing::get;
134+
use axum::{
135+
Router,
136+
body::Body,
137+
http::{Request, StatusCode},
138+
};
139+
use http::header::{AUTHORIZATION, WWW_AUTHENTICATE};
140+
use tower::ServiceExt; // for .oneshot()
141+
use url::Url;
142+
143+
fn test_config() -> Config {
144+
Config {
145+
servers: vec![Url::parse("http://localhost:1234").unwrap()],
146+
audiences: vec!["test-audience".to_string()],
147+
resource: Url::parse("http://localhost:4000").unwrap(),
148+
resource_documentation: None,
149+
scopes: vec!["read".to_string()],
150+
disable_auth_token_passthrough: false,
151+
}
152+
}
153+
154+
fn test_router(config: Config) -> Router {
155+
Router::new()
156+
.route("/test", get(|| async { "ok" }))
157+
.layer(from_fn_with_state(config, oauth_validate))
158+
}
159+
160+
#[tokio::test]
161+
async fn missing_token_returns_unauthorized() {
162+
let config = test_config();
163+
let app = test_router(config.clone());
164+
let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
165+
let res = app.oneshot(req).await.unwrap();
166+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
167+
let headers = res.headers();
168+
let www_auth = headers.get(WWW_AUTHENTICATE).unwrap().to_str().unwrap();
169+
assert!(www_auth.contains("Bearer"));
170+
assert!(www_auth.contains("resource_metadata"));
171+
}
172+
173+
#[tokio::test]
174+
async fn invalid_token_returns_unauthorized() {
175+
let config = test_config();
176+
let app = test_router(config.clone());
177+
let req = Request::builder()
178+
.uri("/test")
179+
.header(AUTHORIZATION, "Bearer invalidtoken")
180+
.body(Body::empty())
181+
.unwrap();
182+
let res = app.oneshot(req).await.unwrap();
183+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
184+
let headers = res.headers();
185+
let www_auth = headers.get(WWW_AUTHENTICATE).unwrap().to_str().unwrap();
186+
assert!(www_auth.contains("Bearer"));
187+
assert!(www_auth.contains("resource_metadata"));
188+
}
189+
}

crates/apollo-mcp-server/src/operations/operation_source.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pub enum OperationSource {
3838
}
3939

4040
impl OperationSource {
41-
#[tracing::instrument]
41+
#[tracing::instrument(skip_all, fields(operation_source = ?self))]
4242
pub async fn into_stream(self) -> impl Stream<Item = Event> {
4343
match self {
4444
OperationSource::Files(paths) => Self::stream_file_changes(paths).boxed(),

crates/apollo-mcp-server/src/server/states/running.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ impl Running {
105105
Ok(self)
106106
}
107107

108+
#[tracing::instrument(skip_all)]
108109
pub(super) async fn update_operations(
109110
self,
110111
operations: Vec<RawOperation>,
@@ -146,6 +147,7 @@ impl Running {
146147
}
147148

148149
/// Notify any peers that tools have changed. Drops unreachable peers from the list.
150+
#[tracing::instrument(skip_all)]
149151
async fn notify_tool_list_changed(peers: Arc<RwLock<Vec<Peer<RoleServer>>>>) {
150152
let mut peers = peers.write().await;
151153
if !peers.is_empty() {

crates/apollo-mcp-server/src/server/states/starting.rs

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,7 @@ impl Starting {
229229
|response: &axum::http::Response<_>,
230230
_latency: std::time::Duration,
231231
span: &tracing::Span| {
232-
span.record(
233-
"status_code",
234-
tracing::field::display(response.status()),
235-
);
232+
span.record("status", tracing::field::display(response.status()));
236233
},
237234
),
238235
);
@@ -334,3 +331,49 @@ async fn health_endpoint(
334331

335332
Ok((status_code, Json(json!(health))))
336333
}
334+
335+
#[cfg(test)]
336+
mod tests {
337+
use http::HeaderMap;
338+
use url::Url;
339+
340+
use crate::health::HealthCheckConfig;
341+
342+
use super::*;
343+
344+
#[tokio::test]
345+
async fn start_basic_server() {
346+
let starting = Starting {
347+
config: Config {
348+
transport: Transport::StreamableHttp {
349+
auth: None,
350+
address: "127.0.0.1".parse().unwrap(),
351+
port: 7799,
352+
stateful_mode: false,
353+
},
354+
endpoint: Url::parse("http://localhost:4000").expect("valid url"),
355+
mutation_mode: MutationMode::All,
356+
execute_introspection: true,
357+
headers: HeaderMap::new(),
358+
validate_introspection: true,
359+
introspect_introspection: true,
360+
search_introspection: true,
361+
introspect_minify: false,
362+
search_minify: false,
363+
explorer_graph_ref: None,
364+
custom_scalar_map: None,
365+
disable_type_description: false,
366+
disable_schema_description: false,
367+
disable_auth_token_passthrough: false,
368+
search_leaf_depth: 5,
369+
index_memory_bytes: 1024 * 1024 * 1024,
370+
health_check: HealthCheckConfig::default(),
371+
},
372+
schema: Schema::parse_and_validate("type Query { hello: String }", "test.graphql")
373+
.expect("Valid schema"),
374+
operations: vec![],
375+
};
376+
let running = starting.start();
377+
assert!(running.await.is_ok());
378+
}
379+
}

crates/apollo-mcp-server/src/telemetry_attributes.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ impl TelemetryAttribute {
2020
TelemetryAttribute::RequestId => {
2121
Key::from_static_str(TelemetryAttribute::RequestId.as_str())
2222
}
23+
TelemetryAttribute::RawOperation => {
24+
Key::from_static_str(TelemetryAttribute::RawOperation.as_str())
25+
}
2326
}
2427
}
2528

crates/apollo-mcp-server/telemetry.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ operation_id = "The operation id - either persisted query id, operation name, or
44
operation_source = "The operation source - either operation (local file/op collection), persisted query, or LLM generated"
55
request_id = "The request id"
66
success = "Sucess flag indicator"
7+
raw_operation = "Graphql operation text and metadata used for Tool generation"
78

89
[metrics.apollo.mcp]
910
"initialize.count" = "Number of times initialize has been called"

0 commit comments

Comments
 (0)