Skip to content

Commit 688752e

Browse files
Merge pull request #139 from allenzhou101/oauth-refresh
Add Refresh Token Support for OAuth
2 parents d438760 + 1b13b57 commit 688752e

File tree

4 files changed

+112
-21
lines changed

4 files changed

+112
-21
lines changed

client/src/components/OAuthCallback.tsx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,15 @@ const OAuthCallback = () => {
2424
}
2525

2626
try {
27-
const accessToken = await handleOAuthCallback(serverUrl, code);
28-
// Store the access token for future use
29-
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, accessToken);
27+
const tokens = await handleOAuthCallback(serverUrl, code);
28+
// Store both access and refresh tokens
29+
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token);
30+
if (tokens.refresh_token) {
31+
sessionStorage.setItem(
32+
SESSION_KEYS.REFRESH_TOKEN,
33+
tokens.refresh_token,
34+
);
35+
}
3036
// Redirect back to the main app with server URL to trigger auto-connect
3137
window.location.href = `/?serverUrl=${encodeURIComponent(serverUrl)}`;
3238
} catch (error) {

client/src/lib/auth.ts

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
import pkceChallenge from "pkce-challenge";
22
import { SESSION_KEYS } from "./constants";
3+
import { z } from "zod";
34

4-
export interface OAuthMetadata {
5-
authorization_endpoint: string;
6-
token_endpoint: string;
7-
}
5+
export const OAuthMetadataSchema = z.object({
6+
authorization_endpoint: z.string(),
7+
token_endpoint: z.string(),
8+
});
9+
10+
export type OAuthMetadata = z.infer<typeof OAuthMetadataSchema>;
11+
12+
export const OAuthTokensSchema = z.object({
13+
access_token: z.string(),
14+
refresh_token: z.string().optional(),
15+
expires_in: z.number().optional(),
16+
});
17+
18+
export type OAuthTokens = z.infer<typeof OAuthTokensSchema>;
819

920
export async function discoverOAuthMetadata(
1021
serverUrl: string,
@@ -15,21 +26,23 @@ export async function discoverOAuthMetadata(
1526

1627
if (response.ok) {
1728
const metadata = await response.json();
18-
return {
29+
const validatedMetadata = OAuthMetadataSchema.parse({
1930
authorization_endpoint: metadata.authorization_endpoint,
2031
token_endpoint: metadata.token_endpoint,
21-
};
32+
});
33+
return validatedMetadata;
2234
}
2335
} catch (error) {
2436
console.warn("OAuth metadata discovery failed:", error);
2537
}
2638

2739
// Fall back to default endpoints
2840
const baseUrl = new URL(serverUrl);
29-
return {
41+
const defaultMetadata = {
3042
authorization_endpoint: new URL("/authorize", baseUrl).toString(),
3143
token_endpoint: new URL("/token", baseUrl).toString(),
3244
};
45+
return OAuthMetadataSchema.parse(defaultMetadata);
3346
}
3447

3548
export async function startOAuthFlow(serverUrl: string): Promise<string> {
@@ -60,7 +73,7 @@ export async function startOAuthFlow(serverUrl: string): Promise<string> {
6073
export async function handleOAuthCallback(
6174
serverUrl: string,
6275
code: string,
63-
): Promise<string> {
76+
): Promise<OAuthTokens> {
6477
// Get stored code verifier
6578
const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER);
6679
if (!codeVerifier) {
@@ -69,7 +82,6 @@ export async function handleOAuthCallback(
6982

7083
// Discover OAuth endpoints
7184
const metadata = await discoverOAuthMetadata(serverUrl);
72-
7385
// Exchange code for tokens
7486
const response = await fetch(metadata.token_endpoint, {
7587
method: "POST",
@@ -88,6 +100,35 @@ export async function handleOAuthCallback(
88100
throw new Error("Token exchange failed");
89101
}
90102

91-
const data = await response.json();
92-
return data.access_token;
103+
const tokens = await response.json();
104+
return OAuthTokensSchema.parse(tokens);
105+
}
106+
107+
export async function refreshAccessToken(
108+
serverUrl: string,
109+
): Promise<OAuthTokens> {
110+
const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN);
111+
if (!refreshToken) {
112+
throw new Error("No refresh token available");
113+
}
114+
115+
const metadata = await discoverOAuthMetadata(serverUrl);
116+
117+
const response = await fetch(metadata.token_endpoint, {
118+
method: "POST",
119+
headers: {
120+
"Content-Type": "application/json",
121+
},
122+
body: JSON.stringify({
123+
grant_type: "refresh_token",
124+
refresh_token: refreshToken,
125+
}),
126+
});
127+
128+
if (!response.ok) {
129+
throw new Error("Token refresh failed");
130+
}
131+
132+
const tokens = await response.json();
133+
return OAuthTokensSchema.parse(tokens);
93134
}

client/src/lib/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ export const SESSION_KEYS = {
33
CODE_VERIFIER: "mcp_code_verifier",
44
SERVER_URL: "mcp_server_url",
55
ACCESS_TOKEN: "mcp_access_token",
6+
REFRESH_TOKEN: "mcp_refresh_token",
67
} as const;

client/src/lib/hooks/useConnection.ts

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import {
1616
import { useState } from "react";
1717
import { toast } from "react-toastify";
1818
import { z } from "zod";
19-
import { startOAuthFlow } from "../auth";
19+
import { startOAuthFlow, refreshAccessToken } from "../auth";
2020
import { SESSION_KEYS } from "../constants";
2121
import { Notification, StdErrNotificationSchema } from "../notificationTypes";
2222

@@ -121,7 +121,49 @@ export function useConnection({
121121
}
122122
};
123123

124-
const connect = async () => {
124+
const initiateOAuthFlow = async () => {
125+
sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN);
126+
sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN);
127+
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
128+
const redirectUrl = await startOAuthFlow(sseUrl);
129+
window.location.href = redirectUrl;
130+
};
131+
132+
const handleTokenRefresh = async () => {
133+
try {
134+
const tokens = await refreshAccessToken(sseUrl);
135+
sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token);
136+
if (tokens.refresh_token) {
137+
sessionStorage.setItem(
138+
SESSION_KEYS.REFRESH_TOKEN,
139+
tokens.refresh_token,
140+
);
141+
}
142+
return tokens.access_token;
143+
} catch (error) {
144+
console.error("Token refresh failed:", error);
145+
await initiateOAuthFlow();
146+
throw error;
147+
}
148+
};
149+
150+
const handleAuthError = async (error: unknown) => {
151+
if (error instanceof SseError && error.code === 401) {
152+
if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) {
153+
try {
154+
await handleTokenRefresh();
155+
return true;
156+
} catch (error) {
157+
console.error("Token refresh failed:", error);
158+
}
159+
} else {
160+
await initiateOAuthFlow();
161+
}
162+
}
163+
return false;
164+
};
165+
166+
const connect = async (_e?: unknown, retryCount: number = 0) => {
125167
try {
126168
const client = new Client<Request, Notification, Result>(
127169
{
@@ -182,14 +224,15 @@ export function useConnection({
182224
await client.connect(clientTransport);
183225
} catch (error) {
184226
console.error("Failed to connect to MCP server:", error);
227+
const shouldRetry = await handleAuthError(error);
228+
if (shouldRetry) {
229+
return connect(undefined, retryCount + 1);
230+
}
231+
185232
if (error instanceof SseError && error.code === 401) {
186-
// Store the server URL for the callback handler
187-
sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl);
188-
const redirectUrl = await startOAuthFlow(sseUrl);
189-
window.location.href = redirectUrl;
233+
// Don't set error state if we're about to redirect for auth
190234
return;
191235
}
192-
193236
throw error;
194237
}
195238

0 commit comments

Comments
 (0)