Skip to content

Commit b1f4b65

Browse files
committed
Improve auth docs and add finishAuth convenience method
1 parent 44a4408 commit b1f4b65

File tree

3 files changed

+47
-19
lines changed

3 files changed

+47
-19
lines changed

src/client/auth.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ export interface OAuthClientProvider {
147147

148148
export type AuthResult = "AUTHORIZED" | "REDIRECT";
149149

150+
export class UnauthorizedError extends Error {
151+
constructor(message?: string) {
152+
super(message ?? "Unauthorized");
153+
}
154+
}
155+
150156
/**
151157
* Orchestrates the full auth flow with a server.
152158
*

src/client/sse.test.ts

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
22
import { AddressInfo } from "net";
33
import { JSONRPCMessage } from "../types.js";
44
import { SSEClientTransport } from "./sse.js";
5-
import { OAuthClientProvider, OAuthTokens } from "./auth.js";
5+
import { OAuthClientProvider, OAuthTokens, UnauthorizedError } from "./auth.js";
66

77
describe("SSEClientTransport", () => {
88
let server: Server;
@@ -376,7 +376,7 @@ describe("SSEClientTransport", () => {
376376
authProvider: mockAuthProvider,
377377
});
378378

379-
await expect(() => transport.start()).rejects.toThrow("Unauthorized");
379+
await expect(() => transport.start()).rejects.toThrow(UnauthorizedError);
380380
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
381381
});
382382

@@ -431,7 +431,7 @@ describe("SSEClientTransport", () => {
431431
params: {},
432432
};
433433

434-
await expect(() => transport.send(message)).rejects.toThrow("Unauthorized");
434+
await expect(() => transport.send(message)).rejects.toThrow(UnauthorizedError);
435435
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
436436
});
437437

@@ -485,17 +485,17 @@ describe("SSEClientTransport", () => {
485485
let connectionAttempts = 0;
486486
server = createServer((req, res) => {
487487
lastServerRequest = req;
488-
488+
489489
if (req.url === "/token" && req.method === "POST") {
490490
// Handle token refresh request
491491
let body = "";
492492
req.on("data", chunk => { body += chunk; });
493493
req.on("end", () => {
494494
const params = new URLSearchParams(body);
495-
if (params.get("grant_type") === "refresh_token" &&
496-
params.get("refresh_token") === "refresh-token" &&
497-
params.get("client_id") === "test-client-id" &&
498-
params.get("client_secret") === "test-client-secret") {
495+
if (params.get("grant_type") === "refresh_token" &&
496+
params.get("refresh_token") === "refresh-token" &&
497+
params.get("client_id") === "test-client-id" &&
498+
params.get("client_secret") === "test-client-secret") {
499499
res.writeHead(200, { "Content-Type": "application/json" });
500500
res.end(JSON.stringify({
501501
access_token: "new-token",
@@ -583,10 +583,10 @@ describe("SSEClientTransport", () => {
583583
req.on("data", chunk => { body += chunk; });
584584
req.on("end", () => {
585585
const params = new URLSearchParams(body);
586-
if (params.get("grant_type") === "refresh_token" &&
587-
params.get("refresh_token") === "refresh-token" &&
588-
params.get("client_id") === "test-client-id" &&
589-
params.get("client_secret") === "test-client-secret") {
586+
if (params.get("grant_type") === "refresh_token" &&
587+
params.get("refresh_token") === "refresh-token" &&
588+
params.get("client_id") === "test-client-id" &&
589+
params.get("client_secret") === "test-client-secret") {
590590
res.writeHead(200, { "Content-Type": "application/json" });
591591
res.end(JSON.stringify({
592592
access_token: "new-token",
@@ -715,7 +715,7 @@ describe("SSEClientTransport", () => {
715715
authProvider: mockAuthProvider,
716716
});
717717

718-
await expect(transport.start()).rejects.toThrow("Unauthorized");
718+
await expect(() => transport.start()).rejects.toThrow(UnauthorizedError);
719719
expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled();
720720
});
721721
});

src/client/sse.ts

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
22
import { Transport } from "../shared/transport.js";
33
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
4-
import { auth, AuthResult, OAuthClientProvider } from "./auth.js";
4+
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
55

66
export class SseError extends Error {
77
constructor(
@@ -20,8 +20,16 @@ export type SSEClientTransportOptions = {
2020
/**
2121
* An OAuth client provider to use for authentication.
2222
*
23-
* If given, the transport will automatically attach an `Authorization` header
24-
* if an access token is available, or begin the authorization flow if not.
23+
* When an `authProvider` is specified and the SSE connection is started:
24+
* 1. The connection is attempted with any existing access token from the `authProvider`.
25+
* 2. If the access token has expired, the `authProvider` is used to refresh the token.
26+
* 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`.
27+
*
28+
* After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection.
29+
*
30+
* If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown.
31+
*
32+
* `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected.
2533
*/
2634
authProvider?: OAuthClientProvider;
2735

@@ -70,7 +78,7 @@ export class SSEClientTransport implements Transport {
7078

7179
private async _authThenStart(): Promise<void> {
7280
if (!this._authProvider) {
73-
throw new Error("No auth provider");
81+
throw new UnauthorizedError("No auth provider");
7482
}
7583

7684
let result: AuthResult;
@@ -82,7 +90,7 @@ export class SSEClientTransport implements Transport {
8290
}
8391

8492
if (result !== "AUTHORIZED") {
85-
throw new Error("Unauthorized");
93+
throw new UnauthorizedError();
8694
}
8795

8896
return await this._startOrAuth();
@@ -177,6 +185,20 @@ export class SSEClientTransport implements Transport {
177185
return await this._startOrAuth();
178186
}
179187

188+
/**
189+
* Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth.
190+
*/
191+
async finishAuth(authorizationCode: string): Promise<void> {
192+
if (!this._authProvider) {
193+
throw new UnauthorizedError("No auth provider");
194+
}
195+
196+
const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode });
197+
if (result !== "AUTHORIZED") {
198+
throw new UnauthorizedError("Failed to authorize");
199+
}
200+
}
201+
180202
async close(): Promise<void> {
181203
this._abortController?.abort();
182204
this._eventSource?.close();
@@ -205,7 +227,7 @@ export class SSEClientTransport implements Transport {
205227
if (response.status === 401 && this._authProvider) {
206228
const result = await auth(this._authProvider, { serverUrl: this._url });
207229
if (result !== "AUTHORIZED") {
208-
throw new Error("Unauthorized");
230+
throw new UnauthorizedError();
209231
}
210232

211233
// Purposely _not_ awaited, so we don't call onerror twice

0 commit comments

Comments
 (0)