Skip to content

Commit f2f030e

Browse files
authored
Merge branch 'main' into update-protocol-version
2 parents cada5a4 + 7e18c70 commit f2f030e

File tree

5 files changed

+60
-17
lines changed

5 files changed

+60
-17
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,13 @@ app.post('/mcp', async (req: Request, res: Response) => {
321321
const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({
322322
sessionIdGenerator: undefined,
323323
});
324-
await server.connect(transport);
325-
await transport.handleRequest(req, res, req.body);
326324
res.on('close', () => {
327325
console.log('Request closed');
328326
transport.close();
329327
server.close();
330328
});
329+
await server.connect(transport);
330+
await transport.handleRequest(req, res, req.body);
331331
} catch (error) {
332332
console.error('Error handling MCP request:', error);
333333
if (!res.headersSent) {

src/client/streamableHttp.test.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js";
2+
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
23
import { JSONRPCMessage } from "../types.js";
34

45

56
describe("StreamableHTTPClientTransport", () => {
67
let transport: StreamableHTTPClientTransport;
8+
let mockAuthProvider: jest.Mocked<OAuthClientProvider>;
79

810
beforeEach(() => {
9-
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"));
11+
mockAuthProvider = {
12+
get redirectUrl() { return "http://localhost/callback"; },
13+
get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; },
14+
clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })),
15+
tokens: jest.fn(),
16+
saveTokens: jest.fn(),
17+
redirectToAuthorization: jest.fn(),
18+
saveCodeVerifier: jest.fn(),
19+
codeVerifier: jest.fn(),
20+
};
21+
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { authProvider: mockAuthProvider });
1022
jest.spyOn(global, "fetch");
1123
});
1224

@@ -497,4 +509,27 @@ describe("StreamableHTTPClientTransport", () => {
497509
expect(getDelay(10)).toBe(5000);
498510
});
499511

512+
it("attempts auth flow on 401 during POST request", async () => {
513+
const message: JSONRPCMessage = {
514+
jsonrpc: "2.0",
515+
method: "test",
516+
params: {},
517+
id: "test-id"
518+
};
519+
520+
(global.fetch as jest.Mock)
521+
.mockResolvedValueOnce({
522+
ok: false,
523+
status: 401,
524+
statusText: "Unauthorized",
525+
headers: new Headers()
526+
})
527+
.mockResolvedValue({
528+
ok: false,
529+
status: 404
530+
});
531+
532+
await expect(transport.send(message)).rejects.toThrow(UnauthorizedError);
533+
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
534+
});
500535
});

src/server/streamableHttp.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ export interface StreamableHTTPServerTransportOptions {
9393
* - State is maintained in-memory (connections, message history)
9494
*
9595
* In stateless mode:
96-
* - Session ID is only included in initialization responses
96+
* - No Session ID is included in any responses
9797
* - No session validation is performed
9898
*/
9999
export class StreamableHTTPServerTransport implements Transport {
@@ -166,7 +166,7 @@ export class StreamableHTTPServerTransport implements Transport {
166166
}
167167

168168
// If an Mcp-Session-Id is returned by the server during initialization,
169-
// clients using the Streamable HTTP transport MUST include it
169+
// clients using the Streamable HTTP transport MUST include it
170170
// in the Mcp-Session-Id header on all of their subsequent HTTP requests.
171171
if (!this.validateSession(req, res)) {
172172
return;
@@ -180,7 +180,7 @@ export class StreamableHTTPServerTransport implements Transport {
180180
}
181181
}
182182

183-
// The server MUST either return Content-Type: text/event-stream in response to this HTTP GET,
183+
// The server MUST either return Content-Type: text/event-stream in response to this HTTP GET,
184184
// or else return HTTP 405 Method Not Allowed
185185
const headers: Record<string, string> = {
186186
"Content-Type": "text/event-stream",
@@ -587,7 +587,7 @@ export class StreamableHTTPServerTransport implements Transport {
587587
}
588588
}
589589

590-
if (isJSONRPCResponse(message)) {
590+
if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
591591
this._requestResponseMap.set(requestId, message);
592592
const relatedIds = Array.from(this._requestToStreamMapping.entries())
593593
.filter(([_, streamId]) => this._streamMapping.get(streamId) === response)

src/shared/protocol.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import {
2121
RequestId,
2222
Result,
2323
ServerCapabilities,
24+
RequestMeta,
2425
} from "../types.js";
2526
import { Transport, TransportSendOptions } from "./transport.js";
2627
import { AuthInfo } from "../server/auth/types.js";
@@ -115,6 +116,11 @@ export type RequestHandlerExtra<SendRequestT extends Request,
115116
*/
116117
sessionId?: string;
117118

119+
/**
120+
* Metadata from the original request.
121+
*/
122+
_meta?: RequestMeta;
123+
118124
/**
119125
* The JSON-RPC ID of the request being handled.
120126
* This can be useful for tracking or logging purposes.
@@ -361,6 +367,7 @@ export abstract class Protocol<
361367
const fullExtra: RequestHandlerExtra<SendRequestT, SendNotificationT> = {
362368
signal: abortController.signal,
363369
sessionId: this._transport?.sessionId,
370+
_meta: request.params?._meta,
364371
sendNotification:
365372
(notification) =>
366373
this.notification(notification, { relatedRequestId: request.id }),

src/types.ts

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]);
2020
*/
2121
export const CursorSchema = z.string();
2222

23+
const RequestMetaSchema = z
24+
.object({
25+
/**
26+
* If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications.
27+
*/
28+
progressToken: z.optional(ProgressTokenSchema),
29+
})
30+
.passthrough();
31+
2332
const BaseRequestParamsSchema = z
2433
.object({
25-
_meta: z.optional(
26-
z
27-
.object({
28-
/**
29-
* If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications.
30-
*/
31-
progressToken: z.optional(ProgressTokenSchema),
32-
})
33-
.passthrough(),
34-
),
34+
_meta: z.optional(RequestMetaSchema),
3535
})
3636
.passthrough();
3737

@@ -1241,6 +1241,7 @@ type Infer<Schema extends ZodTypeAny> = Flatten<z.infer<Schema>>;
12411241
export type ProgressToken = Infer<typeof ProgressTokenSchema>;
12421242
export type Cursor = Infer<typeof CursorSchema>;
12431243
export type Request = Infer<typeof RequestSchema>;
1244+
export type RequestMeta = Infer<typeof RequestMetaSchema>;
12441245
export type Notification = Infer<typeof NotificationSchema>;
12451246
export type Result = Infer<typeof ResultSchema>;
12461247
export type RequestId = Infer<typeof RequestIdSchema>;

0 commit comments

Comments
 (0)