Skip to content

Commit f64601b

Browse files
committed
optimization: When using the custom SSE request, the Authorization header can still be automatically attached to the SSE request.
1 parent bced33d commit f64601b

File tree

2 files changed

+68
-16
lines changed

2 files changed

+68
-16
lines changed

src/client/sse.test.ts

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { createServer, type IncomingMessage, type Server } from "http";
1+
import { createServer, IncomingMessage, Server, ServerResponse } from "http";
22
import { AddressInfo } from "net";
33
import { JSONRPCMessage } from "../types.js";
44
import { SSEClientTransport } from "./sse.js";
@@ -10,8 +10,21 @@ describe("SSEClientTransport", () => {
1010
let transport: SSEClientTransport;
1111
let baseUrl: URL;
1212
let lastServerRequest: IncomingMessage;
13+
const serverRequests: Record<string, IncomingMessage[]> = {};
1314
let sendServerMessage: ((message: string) => void) | null = null;
1415

16+
const recordServerRequest = (req: IncomingMessage, res: ServerResponse) => {
17+
lastServerRequest = req;
18+
19+
const key = `${req.method} ${req.url}`;
20+
serverRequests[key] = serverRequests[key] || [];
21+
serverRequests[key].push(req);
22+
23+
res.on('finish', () => {
24+
console.log(`[server] ${req.method} ${req.url} -> ${res.statusCode} ${res.statusMessage}`);
25+
});
26+
};
27+
1528
beforeEach((done) => {
1629
// Reset state
1730
lastServerRequest = null as unknown as IncomingMessage;
@@ -487,7 +500,7 @@ describe("SSEClientTransport", () => {
487500

488501
let connectionAttempts = 0;
489502
server = createServer((req, res) => {
490-
lastServerRequest = req;
503+
recordServerRequest(req, res);
491504

492505
if (req.url === "/token" && req.method === "POST") {
493506
// Handle token refresh request
@@ -496,7 +509,7 @@ describe("SSEClientTransport", () => {
496509
req.on("end", () => {
497510
const params = new URLSearchParams(body);
498511
if (params.get("grant_type") === "refresh_token" &&
499-
params.get("refresh_token") === "refresh-token" &&
512+
params.get("refresh_token")?.includes("refresh-token") &&
500513
params.get("client_id") === "test-client-id" &&
501514
params.get("client_secret") === "test-client-secret") {
502515
res.writeHead(200, { "Content-Type": "application/json" });
@@ -531,6 +544,7 @@ describe("SSEClientTransport", () => {
531544
});
532545
res.write("event: endpoint\n");
533546
res.write(`data: ${baseUrl.href}\n\n`);
547+
res.end();
534548
connectionAttempts++;
535549
return;
536550
}
@@ -548,6 +562,14 @@ describe("SSEClientTransport", () => {
548562

549563
transport = new SSEClientTransport(baseUrl, {
550564
authProvider: mockAuthProvider,
565+
eventSourceInit: {
566+
fetch: (url, init) => {
567+
return fetch(url, { ...init, headers: {
568+
...init?.headers,
569+
'X-Custom-Header': 'custom-value'
570+
} });
571+
}
572+
},
551573
});
552574

553575
await transport.start();
@@ -559,6 +581,9 @@ describe("SSEClientTransport", () => {
559581
});
560582
expect(connectionAttempts).toBe(1);
561583
expect(lastServerRequest.headers.authorization).toBe("Bearer new-token");
584+
expect(serverRequests["GET /"]).toHaveLength(2);
585+
expect(serverRequests["GET /"]
586+
.every(req => req.headers["x-custom-header"] === "custom-value")).toBe(true);
562587
});
563588

564589
it("refreshes expired token during POST request", async () => {

src/client/sse.ts

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ export class SSEClientTransport implements Transport {
9696
return await this._startOrAuth();
9797
}
9898

99-
private async _commonHeaders(): Promise<HeadersInit> {
99+
private async _commonHeaders(): Promise<Record<string, string>> {
100100
const headers: HeadersInit = {};
101101
if (this._authProvider) {
102102
const tokens = await this._authProvider.tokens();
@@ -110,18 +110,7 @@ export class SSEClientTransport implements Transport {
110110

111111
private _startOrAuth(): Promise<void> {
112112
return new Promise((resolve, reject) => {
113-
this._eventSource = new EventSource(
114-
this._url.href,
115-
this._eventSourceInit ?? {
116-
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
117-
...init,
118-
headers: {
119-
...headers,
120-
Accept: "text/event-stream"
121-
}
122-
})),
123-
},
124-
);
113+
this._eventSource = new EventSource(this._url.href, this._getEventSourceInit());
125114
this._abortController = new AbortController();
126115

127116
this._eventSource.onerror = (event) => {
@@ -175,6 +164,44 @@ export class SSEClientTransport implements Transport {
175164
});
176165
}
177166

167+
private _getEventSourceInit(): EventSourceInit {
168+
let eventSourceInit: EventSourceInit;
169+
170+
if (this._eventSourceInit) {
171+
const originalFetch = this._eventSourceInit.fetch;
172+
173+
if (originalFetch && this._authProvider) {
174+
// merge the new headers with the existing headers
175+
eventSourceInit = {
176+
...this._eventSourceInit,
177+
fetch: async (url, init) => {
178+
const newHeaders: Record<string, string> = await this._commonHeaders();
179+
return originalFetch(url, {
180+
...init,
181+
headers: {
182+
...newHeaders,
183+
...init?.headers
184+
}
185+
});
186+
}
187+
};
188+
} else {
189+
eventSourceInit = this._eventSourceInit;
190+
}
191+
} else {
192+
eventSourceInit = {
193+
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
194+
...init,
195+
headers: {
196+
...headers,
197+
Accept: "text/event-stream"
198+
}
199+
})),
200+
};
201+
}
202+
return eventSourceInit;
203+
}
204+
178205
async start() {
179206
if (this._eventSource) {
180207
throw new Error(

0 commit comments

Comments
 (0)