diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 5e592bded..df92616c4 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -1,52 +1,53 @@ +import { useToast } from "@/lib/hooks/useToast"; +import { + getMCPProxyAddress, + getMCPProxyAuthToken, + getMCPServerRequestMaxTotalTimeout, + getMCPServerRequestTimeout, + resetRequestTimeoutOnProgress, +} from "@/utils/configUtils"; +import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { SSEClientTransport, - SseError, SSEClientTransportOptions, + SseError, } from "@modelcontextprotocol/sdk/client/sse.js"; import { StreamableHTTPClientTransport, StreamableHTTPClientTransportOptions, } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { + CancelledNotificationSchema, ClientNotification, ClientRequest, + CompleteResultSchema, CreateMessageRequestSchema, + ErrorCode, ListRootsRequestSchema, - ResourceUpdatedNotificationSchema, LoggingMessageNotificationSchema, + McpError, + Progress, + PromptListChangedNotificationSchema, + PromptReference, Request, + ResourceListChangedNotificationSchema, + ResourceReference, + ResourceUpdatedNotificationSchema, Result, ServerCapabilities, - PromptReference, - ResourceReference, - McpError, - CompleteResultSchema, - ErrorCode, - CancelledNotificationSchema, - ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, - PromptListChangedNotificationSchema, - Progress, } from "@modelcontextprotocol/sdk/types.js"; -import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; import { useState } from "react"; -import { useToast } from "@/lib/hooks/useToast"; import { z } from "zod"; -import { ConnectionStatus } from "../constants"; -import { Notification, StdErrNotificationSchema } from "../notificationTypes"; -import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; -import { InspectorOAuthClientProvider } from "../auth"; import packageJson from "../../../package.json"; -import { - getMCPProxyAddress, - getMCPServerRequestMaxTotalTimeout, - resetRequestTimeoutOnProgress, - getMCPProxyAuthToken, -} from "@/utils/configUtils"; -import { getMCPServerRequestTimeout } from "@/utils/configUtils"; +import { InspectorOAuthClientProvider } from "../auth"; import { InspectorConfig } from "../configurationTypes"; -import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; +import { ConnectionStatus } from "../constants"; +import { Notification, StdErrNotificationSchema } from "../notificationTypes"; +import { isTokenExpired } from "../token-utils"; interface UseConnectionOptions { transportType: "stdio" | "sse" | "streamable-http"; @@ -275,12 +276,32 @@ export function useConnection({ ); }; + const ensureValidToken = async ( + authProvider: InspectorOAuthClientProvider, + ) => { + try { + const tokens = await authProvider.tokens(); + + // If no tokens exist, initiate authorization flow + if (!tokens || isTokenExpired(tokens)) { + const result = await auth(authProvider, { serverUrl: sseUrl }); + return result === "AUTHORIZED"; + } + + return true; // Token is still valid + } catch (error) { + console.error("Token refresh/authorization failed:", error); + return false; + } + }; + const handleAuthError = async (error: unknown) => { if (is401Error(error)) { const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl); - const result = await auth(serverAuthProvider, { serverUrl: sseUrl }); - return result === "AUTHORIZED"; + // Try to use existing tokens (refresh if needed) or initiate new authorization flow + const hasValidToken = await ensureValidToken(serverAuthProvider); + return hasValidToken; } return false; @@ -318,8 +339,13 @@ export function useConnection({ const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl); // Use manually provided bearer token if available, otherwise use OAuth tokens - const token = - bearerToken || (await serverAuthProvider.tokens())?.access_token; + let token = bearerToken; + if (!token) { + // Ensure we have a valid token before proceeding + await ensureValidToken(serverAuthProvider); + token = (await serverAuthProvider.tokens())?.access_token; + } + if (token) { const authHeaderName = headerName || "Authorization"; diff --git a/client/src/lib/token-utils.ts b/client/src/lib/token-utils.ts new file mode 100644 index 000000000..0125457b7 --- /dev/null +++ b/client/src/lib/token-utils.ts @@ -0,0 +1,29 @@ +import { OAuthTokens } from "@modelcontextprotocol/sdk/shared/auth.js"; + +export function isTokenExpired(tokens: OAuthTokens & { issued_at?: number }) { + try { + if (!tokens.access_token) { + console.warn("No access_token provided"); + return true; + } + const jwtParts = tokens.access_token.split("."); + if (jwtParts.length !== 3) { + console.warn("Invalid JWT format"); + return true; + } + const payload = JSON.parse( + atob(jwtParts[1].replace(/-/g, "+").replace(/_/g, "/")), + ); + const exp = Number(payload.exp); + if (isNaN(exp)) { + console.warn("exp field in JWT payload is not a number"); + return true; + } + return Date.now() / 1000 >= exp; + } catch (err) { + console.warn( + `Failed to verify token expiration: ${(err as Error).message}`, + ); + return true; + } +}