Skip to content

Commit 787cc01

Browse files
authored
feat(extension): extract http request part into rmcp extension (#163)
* feat(extension): extract http request part into rmcp extension * perf(extension): insert extension from one entrypoint
1 parent 546490a commit 787cc01

File tree

4 files changed

+67
-8
lines changed

4 files changed

+67
-8
lines changed

crates/rmcp/src/model/meta.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
44
use serde_json::Value;
55

66
use super::{
7-
ClientNotification, ClientRequest, Extensions, JsonObject, NumberOrString, ProgressToken,
8-
ServerNotification, ServerRequest,
7+
ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString,
8+
ProgressToken, ServerNotification, ServerRequest,
99
};
1010

1111
pub trait GetMeta {
@@ -153,3 +153,42 @@ impl DerefMut for Meta {
153153
&mut self.0
154154
}
155155
}
156+
157+
impl<Req, Resp, Noti> JsonRpcMessage<Req, Resp, Noti>
158+
where
159+
Req: GetExtensions,
160+
Noti: GetExtensions,
161+
{
162+
pub fn insert_extension<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
163+
match self {
164+
JsonRpcMessage::Request(json_rpc_request) => {
165+
json_rpc_request.request.extensions_mut().insert(value);
166+
}
167+
JsonRpcMessage::Notification(json_rpc_notification) => {
168+
json_rpc_notification
169+
.notification
170+
.extensions_mut()
171+
.insert(value);
172+
}
173+
JsonRpcMessage::BatchRequest(json_rpc_batch_request_items) => {
174+
for item in json_rpc_batch_request_items {
175+
match item {
176+
super::JsonRpcBatchRequestItem::Request(json_rpc_request) => {
177+
json_rpc_request
178+
.request
179+
.extensions_mut()
180+
.insert(value.clone());
181+
}
182+
super::JsonRpcBatchRequestItem::Notification(json_rpc_notification) => {
183+
json_rpc_notification
184+
.notification
185+
.extensions_mut()
186+
.insert(value.clone());
187+
}
188+
}
189+
}
190+
}
191+
_ => {}
192+
}
193+
}
194+
}

crates/rmcp/src/transport/sse_server.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
33
use axum::{
44
Json, Router,
55
extract::{Query, State},
6-
http::StatusCode,
6+
http::{StatusCode, request::Parts},
77
response::{
88
Response,
99
sse::{Event, KeepAlive, Sse},
@@ -64,7 +64,8 @@ pub struct PostEventQuery {
6464
async fn post_event_handler(
6565
State(app): State<App>,
6666
Query(PostEventQuery { session_id }): Query<PostEventQuery>,
67-
Json(message): Json<ClientJsonRpcMessage>,
67+
parts: Parts,
68+
Json(mut message): Json<ClientJsonRpcMessage>,
6869
) -> Result<StatusCode, StatusCode> {
6970
tracing::debug!(session_id, ?message, "new client message");
7071
let tx = {
@@ -73,6 +74,7 @@ async fn post_event_handler(
7374
.ok_or(StatusCode::NOT_FOUND)?
7475
.clone()
7576
};
77+
message.insert_extension(parts);
7678
if tx.send(message).await.is_err() {
7779
tracing::error!("send message error");
7880
return Err(StatusCode::GONE);

crates/rmcp/src/transport/streamable_http_server/axum.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
33
use axum::{
44
Json, Router,
55
extract::State,
6-
http::{HeaderMap, HeaderValue, StatusCode},
6+
http::{HeaderMap, HeaderValue, StatusCode, request::Parts},
77
response::{
88
IntoResponse, Response,
99
sse::{Event, KeepAlive, Sse},
@@ -68,11 +68,11 @@ fn receiver_as_stream(
6868

6969
async fn post_handler(
7070
State(app): State<App>,
71-
header_map: HeaderMap,
72-
Json(message): Json<ClientJsonRpcMessage>,
71+
parts: Parts,
72+
Json(mut message): Json<ClientJsonRpcMessage>,
7373
) -> Result<Response, Response> {
7474
use futures::StreamExt;
75-
if let Some(session_id) = header_map.get(HEADER_SESSION_ID) {
75+
if let Some(session_id) = parts.headers.get(HEADER_SESSION_ID).cloned() {
7676
let session_id = session_id
7777
.to_str()
7878
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()).into_response())?;
@@ -84,6 +84,8 @@ async fn post_handler(
8484
.ok_or((StatusCode::NOT_FOUND, "session not found").into_response())?;
8585
session.handle().clone()
8686
};
87+
// inject request part
88+
message.insert_extension(parts);
8789
match &message {
8890
ClientJsonRpcMessage::Request(_) | ClientJsonRpcMessage::BatchRequest(_) => {
8991
let receiver = handle.establish_request_wise_channel().await.map_err(|e| {
@@ -128,6 +130,8 @@ async fn post_handler(
128130
} else {
129131
// expect initialize message
130132
let session_id = session_id();
133+
// inject request part
134+
message.insert_extension(parts);
131135
let (session, transport) =
132136
super::session::create_session(session_id.clone(), Default::default());
133137
let Ok(_) = app.transport_tx.send(transport) else {

examples/servers/src/common/counter.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub struct StructRequest {
1717
pub struct Counter {
1818
counter: Arc<Mutex<i32>>,
1919
}
20+
2021
#[tool(tool_box)]
2122
impl Counter {
2223
#[allow(dead_code)]
@@ -194,4 +195,17 @@ impl ServerHandler for Counter {
194195
resource_templates: Vec::new(),
195196
})
196197
}
198+
199+
async fn initialize(
200+
&self,
201+
_request: InitializeRequestParam,
202+
context: RequestContext<RoleServer>,
203+
) -> Result<InitializeResult, McpError> {
204+
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
205+
let initialize_headers = &http_request_part.headers;
206+
let initialize_uri = &http_request_part.uri;
207+
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
208+
}
209+
Ok(self.get_info())
210+
}
197211
}

0 commit comments

Comments
 (0)