Skip to content

Commit d47d88b

Browse files
committed
Auth tests for SSEClientTransport
1 parent 3e2dd35 commit d47d88b

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

src/client/sse.test.ts

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
22
import { AddressInfo } from "net";
33
import { JSONRPCMessage } from "../types.js";
44
import { SSEClientTransport } from "./sse.js";
5+
import { auth, OAuthClientProvider } from "./auth.js";
56

67
describe("SSEClientTransport", () => {
78
let server: Server;
@@ -284,4 +285,180 @@ describe("SSEClientTransport", () => {
284285
expect(calledHeaders.get("content-type")).toBe("application/json");
285286
});
286287
});
288+
289+
describe("auth handling", () => {
290+
let mockAuthProvider: jest.Mocked<OAuthClientProvider>;
291+
292+
beforeEach(() => {
293+
mockAuthProvider = {
294+
get redirectUrl() { return "http://localhost/callback"; },
295+
get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; },
296+
clientInformation: jest.fn(() => ({ client_id: "test-client-id" })),
297+
tokens: jest.fn(),
298+
saveTokens: jest.fn(),
299+
redirectToAuthorization: jest.fn(),
300+
saveCodeVerifier: jest.fn(),
301+
codeVerifier: jest.fn(),
302+
};
303+
});
304+
305+
it("attaches auth header from provider on SSE connection", async () => {
306+
mockAuthProvider.tokens.mockResolvedValue({
307+
access_token: "test-token",
308+
token_type: "Bearer"
309+
});
310+
311+
transport = new SSEClientTransport(baseUrl, {
312+
authProvider: mockAuthProvider,
313+
});
314+
315+
await transport.start();
316+
317+
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
318+
expect(mockAuthProvider.tokens).toHaveBeenCalled();
319+
});
320+
321+
it("attaches auth header from provider on POST requests", async () => {
322+
mockAuthProvider.tokens.mockResolvedValue({
323+
access_token: "test-token",
324+
token_type: "Bearer"
325+
});
326+
327+
transport = new SSEClientTransport(baseUrl, {
328+
authProvider: mockAuthProvider,
329+
});
330+
331+
await transport.start();
332+
333+
const message: JSONRPCMessage = {
334+
jsonrpc: "2.0",
335+
id: "1",
336+
method: "test",
337+
params: {},
338+
};
339+
340+
await transport.send(message);
341+
342+
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
343+
expect(mockAuthProvider.tokens).toHaveBeenCalled();
344+
});
345+
346+
it("attempts auth flow on 401 during SSE connection", async () => {
347+
// Create server that returns 401s
348+
server.close();
349+
await new Promise(resolve => server.on("close", resolve));
350+
351+
server = createServer((req, res) => {
352+
lastServerRequest = req;
353+
if (req.url !== "/") {
354+
res.writeHead(404).end();
355+
} else {
356+
res.writeHead(401).end();
357+
}
358+
});
359+
360+
await new Promise<void>(resolve => {
361+
server.listen(0, "127.0.0.1", () => {
362+
const addr = server.address() as AddressInfo;
363+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
364+
resolve();
365+
});
366+
});
367+
368+
transport = new SSEClientTransport(baseUrl, {
369+
authProvider: mockAuthProvider,
370+
});
371+
372+
await expect(() => transport.start()).rejects.toThrow("Unauthorized");
373+
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
374+
});
375+
376+
it("attempts auth flow on 401 during POST request", async () => {
377+
// Create server that accepts SSE but returns 401 on POST
378+
server.close();
379+
await new Promise(resolve => server.on("close", resolve));
380+
381+
server = createServer((req, res) => {
382+
lastServerRequest = req;
383+
384+
switch (req.method) {
385+
case "GET":
386+
if (req.url !== "/") {
387+
res.writeHead(404).end();
388+
return;
389+
}
390+
391+
res.writeHead(200, {
392+
"Content-Type": "text/event-stream",
393+
"Cache-Control": "no-cache",
394+
Connection: "keep-alive",
395+
});
396+
res.write("event: endpoint\n");
397+
res.write(`data: ${baseUrl.href}\n\n`);
398+
break;
399+
400+
case "POST":
401+
res.writeHead(401);
402+
res.end();
403+
break;
404+
}
405+
});
406+
407+
await new Promise<void>(resolve => {
408+
server.listen(0, "127.0.0.1", () => {
409+
const addr = server.address() as AddressInfo;
410+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
411+
resolve();
412+
});
413+
});
414+
415+
transport = new SSEClientTransport(baseUrl, {
416+
authProvider: mockAuthProvider,
417+
});
418+
419+
await transport.start();
420+
421+
const message: JSONRPCMessage = {
422+
jsonrpc: "2.0",
423+
id: "1",
424+
method: "test",
425+
params: {},
426+
};
427+
428+
await expect(() => transport.send(message)).rejects.toThrow("Unauthorized");
429+
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
430+
});
431+
432+
it("respects custom headers when using auth provider", async () => {
433+
mockAuthProvider.tokens.mockResolvedValue({
434+
access_token: "test-token",
435+
token_type: "Bearer"
436+
});
437+
438+
const customHeaders = {
439+
"X-Custom-Header": "custom-value",
440+
};
441+
442+
transport = new SSEClientTransport(baseUrl, {
443+
authProvider: mockAuthProvider,
444+
requestInit: {
445+
headers: customHeaders,
446+
},
447+
});
448+
449+
await transport.start();
450+
451+
const message: JSONRPCMessage = {
452+
jsonrpc: "2.0",
453+
id: "1",
454+
method: "test",
455+
params: {},
456+
};
457+
458+
await transport.send(message);
459+
460+
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
461+
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
462+
});
463+
});
287464
});

0 commit comments

Comments
 (0)