Skip to content

Commit 03e583b

Browse files
committed
Add support for AuthProvider in SSEClientTransport
1 parent fceebdc commit 03e583b

File tree

1 file changed

+83
-8
lines changed

1 file changed

+83
-8
lines changed

src/client/sse.ts

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
22
import { Transport } from "../shared/transport.js";
33
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
4+
import { auth, AuthResult, OAuthClientProvider } from "./auth.js";
45

56
export class SseError extends Error {
67
constructor(
@@ -12,6 +13,34 @@ export class SseError extends Error {
1213
}
1314
}
1415

16+
/**
17+
* Configuration options for the `SSEClientTransport`.
18+
*/
19+
export type SSEClientTransportOptions = {
20+
/**
21+
* An OAuth client provider to use for authentication.
22+
*
23+
* If given, the transport will automatically attach an `Authorization` header
24+
* if an access token is available, or begin the authorization flow if not.
25+
*/
26+
authProvider?: OAuthClientProvider;
27+
28+
/**
29+
* Customizes the initial SSE request to the server (the request that begins the stream).
30+
*
31+
* NOTE: Setting this property will prevent an `Authorization` header from
32+
* being automatically attached to the SSE request, if an `authProvider` is
33+
* also given. This can be worked around by setting the `Authorization` header
34+
* manually.
35+
*/
36+
eventSourceInit?: EventSourceInit;
37+
38+
/**
39+
* Customizes recurring POST requests to the server.
40+
*/
41+
requestInit?: RequestInit;
42+
};
43+
1544
/**
1645
* Client transport for SSE: this will connect to a server using Server-Sent Events for receiving
1746
* messages and make separate POST requests for sending messages.
@@ -23,35 +52,70 @@ export class SSEClientTransport implements Transport {
2352
private _url: URL;
2453
private _eventSourceInit?: EventSourceInit;
2554
private _requestInit?: RequestInit;
55+
private _authProvider?: OAuthClientProvider;
2656

2757
onclose?: () => void;
2858
onerror?: (error: Error) => void;
2959
onmessage?: (message: JSONRPCMessage) => void;
3060

3161
constructor(
3262
url: URL,
33-
opts?: { eventSourceInit?: EventSourceInit; requestInit?: RequestInit },
63+
opts?: SSEClientTransportOptions,
3464
) {
3565
this._url = url;
3666
this._eventSourceInit = opts?.eventSourceInit;
3767
this._requestInit = opts?.requestInit;
68+
this._authProvider = opts?.authProvider;
3869
}
3970

40-
start(): Promise<void> {
41-
if (this._eventSource) {
42-
throw new Error(
43-
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.",
44-
);
71+
private async _authThenStart(): Promise<void> {
72+
if (!this._authProvider) {
73+
throw new Error("No auth provider");
4574
}
4675

76+
let result: AuthResult;
77+
try {
78+
result = await auth(this._authProvider, { serverUrl: this._url });
79+
} catch (error) {
80+
this.onerror?.(error as Error);
81+
throw error;
82+
}
83+
84+
if (result !== "AUTHORIZED") {
85+
throw new Error("Unauthorized");
86+
}
87+
88+
return await this._startOrAuth();
89+
}
90+
91+
private async _commonHeaders(): Promise<HeadersInit> {
92+
const headers: HeadersInit = {};
93+
if (this._authProvider) {
94+
const tokens = await this._authProvider.tokens();
95+
if (tokens) {
96+
headers["Authorization"] = `Bearer ${tokens.access_token}`;
97+
}
98+
}
99+
100+
return headers;
101+
}
102+
103+
private _startOrAuth(): Promise<void> {
47104
return new Promise((resolve, reject) => {
48105
this._eventSource = new EventSource(
49106
this._url.href,
50-
this._eventSourceInit,
107+
this._eventSourceInit ?? {
108+
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, { ...init, headers })),
109+
},
51110
);
52111
this._abortController = new AbortController();
53112

54113
this._eventSource.onerror = (event) => {
114+
if (event.code === 401 && this._authProvider) {
115+
this._authThenStart().then(resolve, reject);
116+
return;
117+
}
118+
55119
const error = new SseError(event.code, event.message, event);
56120
reject(error);
57121
this.onerror?.(error);
@@ -97,6 +161,16 @@ export class SSEClientTransport implements Transport {
97161
});
98162
}
99163

164+
async start() {
165+
if (this._eventSource) {
166+
throw new Error(
167+
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.",
168+
);
169+
}
170+
171+
return await this._startOrAuth();
172+
}
173+
100174
async close(): Promise<void> {
101175
this._abortController?.abort();
102176
this._eventSource?.close();
@@ -109,7 +183,8 @@ export class SSEClientTransport implements Transport {
109183
}
110184

111185
try {
112-
const headers = new Headers(this._requestInit?.headers);
186+
const commonHeaders = await this._commonHeaders();
187+
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
113188
headers.set("content-type", "application/json");
114189
const init = {
115190
...this._requestInit,

0 commit comments

Comments
 (0)