Skip to content
23 changes: 23 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,29 @@ describe("SSEClientTransport", () => {
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});

it("attaches custom header from provider on initial SSE connection", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
token_type: "Bearer"
});
const customHeaders = {
"X-Custom-Header": "custom-value",
};

transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
requestInit: {
headers: customHeaders,
},
});

await transport.start();

expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});

it("attaches auth header from provider on POST requests", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
Expand Down
18 changes: 11 additions & 7 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ export class SSEClientTransport implements Transport {
this._eventSource = new EventSource(
this._url.href,
this._eventSourceInit ?? {
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
...init,
headers: {
...headers,
Accept: "text/event-stream"
}
})),
fetch: async (url, init) => {
const commonHeaders = await this._commonHeaders();
const allHeaders = { ...commonHeaders, ...this._requestInit?.headers};
return fetch(url, {
...init,
headers: {
...allHeaders,
Accept: "text/event-stream"
}
})
}
},
);
this._abortController = new AbortController();
Expand Down