|
1 | 1 | use std::borrow::Cow; |
2 | 2 | use std::collections::HashMap; |
3 | | -use std::ops::{ |
4 | | - Deref, |
5 | | - DerefMut, |
6 | | -}; |
7 | 3 | use std::process::Stdio; |
8 | 4 |
|
9 | 5 | use regex::Regex; |
10 | 6 | use reqwest::Client; |
11 | 7 | use rmcp::model::{ |
12 | | - ErrorCode, |
13 | | - Implementation, |
14 | | - InitializeRequestParam, |
15 | | - ListPromptsResult, |
16 | | - ListToolsResult, |
17 | | - LoggingLevel, |
18 | | - LoggingMessageNotificationParam, |
19 | | - PaginatedRequestParam, |
20 | | - ServerNotification, |
21 | | - ServerRequest, |
| 8 | + CallToolRequestParam, CallToolResult, ErrorCode, GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam, ListPromptsResult, ListToolsResult, LoggingLevel, LoggingMessageNotificationParam, PaginatedRequestParam, ServerNotification, ServerRequest |
22 | 9 | }; |
23 | 10 | use rmcp::service::{ |
24 | 11 | ClientInitializeError, |
@@ -151,30 +138,95 @@ pub enum McpClientError { |
151 | 138 | Auth(#[from] crate::auth::AuthError), |
152 | 139 | } |
153 | 140 |
|
154 | | -pub struct RunningService { |
155 | | - pub inner_service: rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>, |
156 | | - #[allow(dead_code)] |
157 | | - pub auth_dropguard: Option<AuthClientDropGuard>, |
| 141 | +macro_rules! decorate_with_auth_retry { |
| 142 | + ($param_type:ty, $method_name:ident, $return_type:ty) => { |
| 143 | + pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> { |
| 144 | + let first_attempt = match &self.inner_service { |
| 145 | + InnerService::Original(rs) => rs.$method_name(param.clone()).await, |
| 146 | + InnerService::Peer(peer) => peer.$method_name(param.clone()).await, |
| 147 | + }; |
| 148 | + |
| 149 | + match first_attempt { |
| 150 | + Ok(result) => Ok(result), |
| 151 | + Err(e) => { |
| 152 | + // TODO: discern error type prior to retrying |
| 153 | + // Not entirely sure what is thrown when auth is required |
| 154 | + if let Some(auth_client) = self.get_auth_client() { |
| 155 | + let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await; |
| 156 | + match refresh_result { |
| 157 | + Ok(_) => { |
| 158 | + // Retry the operation after token refresh |
| 159 | + match &self.inner_service { |
| 160 | + InnerService::Original(rs) => rs.$method_name(param).await, |
| 161 | + InnerService::Peer(peer) => peer.$method_name(param).await, |
| 162 | + } |
| 163 | + }, |
| 164 | + Err(_) => { |
| 165 | + // If refresh fails, return the original error |
| 166 | + Err(e) |
| 167 | + } |
| 168 | + } |
| 169 | + } else { |
| 170 | + // No auth client available, return original error |
| 171 | + Err(e) |
| 172 | + } |
| 173 | + }, |
| 174 | + } |
| 175 | + } |
| 176 | + }; |
158 | 177 | } |
159 | 178 |
|
160 | | -impl RunningService { |
161 | | - pub fn get_auth_client(&self) -> Option<AuthClient<Client>> { |
162 | | - self.auth_dropguard.as_ref().map(|a| a.auth_client.clone()) |
| 179 | +pub enum InnerService { |
| 180 | + Original(rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>), |
| 181 | + Peer(rmcp::service::Peer<RoleClient>), |
| 182 | +} |
| 183 | + |
| 184 | +impl std::fmt::Debug for InnerService { |
| 185 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 186 | + match self { |
| 187 | + InnerService::Original(_) => f.debug_tuple("Original").field(&"RunningService<..>").finish(), |
| 188 | + InnerService::Peer(peer) => f.debug_tuple("Peer").field(peer).finish(), |
| 189 | + } |
163 | 190 | } |
164 | 191 | } |
165 | 192 |
|
166 | | -impl Deref for RunningService { |
167 | | - type Target = rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>; |
| 193 | +impl Clone for InnerService { |
| 194 | + fn clone(&self) -> Self { |
| 195 | + match self { |
| 196 | + InnerService::Original(rs) => InnerService::Peer((*rs).clone()), |
| 197 | + InnerService::Peer(peer) => InnerService::Peer(peer.clone()) |
| 198 | + } |
| 199 | + } |
| 200 | +} |
168 | 201 |
|
169 | | - fn deref(&self) -> &Self::Target { |
170 | | - &self.inner_service |
| 202 | +#[derive(Debug)] |
| 203 | +pub struct RunningService { |
| 204 | + pub inner_service: InnerService, |
| 205 | + auth_dropguard: Option<AuthClientDropGuard>, |
| 206 | +} |
| 207 | + |
| 208 | +impl Clone for RunningService { |
| 209 | + fn clone(&self) -> Self { |
| 210 | + let auth_dropguard = self.auth_dropguard.as_ref().map(|dg| { |
| 211 | + let mut dg = dg.clone(); |
| 212 | + dg.should_write = false; |
| 213 | + dg |
| 214 | + }); |
| 215 | + |
| 216 | + RunningService { |
| 217 | + inner_service: self.inner_service.clone(), |
| 218 | + auth_dropguard |
| 219 | + } |
171 | 220 | } |
172 | 221 | } |
173 | 222 |
|
174 | | -impl DerefMut for RunningService { |
175 | | - fn deref_mut(&mut self) -> &mut Self::Target { |
176 | | - &mut self.inner_service |
| 223 | +impl RunningService { |
| 224 | + pub fn get_auth_client(&self) -> Option<AuthClient<Client>> { |
| 225 | + self.auth_dropguard.as_ref().map(|a| a.auth_client.clone()) |
177 | 226 | } |
| 227 | + |
| 228 | + decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult); |
| 229 | + decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult); |
178 | 230 | } |
179 | 231 |
|
180 | 232 | pub type StdioTransport = (TokioChildProcess, Option<ChildStderr>); |
@@ -397,7 +449,7 @@ impl McpClientService { |
397 | 449 | }); |
398 | 450 |
|
399 | 451 | Ok(RunningService { |
400 | | - inner_service: service, |
| 452 | + inner_service: InnerService::Original(service), |
401 | 453 | auth_dropguard, |
402 | 454 | }) |
403 | 455 | }); |
|
0 commit comments