@@ -2,6 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
2
2
import { AddressInfo } from "net" ;
3
3
import { JSONRPCMessage } from "../types.js" ;
4
4
import { SSEClientTransport } from "./sse.js" ;
5
+ import { auth , OAuthClientProvider } from "./auth.js" ;
5
6
6
7
describe ( "SSEClientTransport" , ( ) => {
7
8
let server : Server ;
@@ -284,4 +285,180 @@ describe("SSEClientTransport", () => {
284
285
expect ( calledHeaders . get ( "content-type" ) ) . toBe ( "application/json" ) ;
285
286
} ) ;
286
287
} ) ;
288
+
289
+ describe ( "auth handling" , ( ) => {
290
+ let mockAuthProvider : jest . Mocked < OAuthClientProvider > ;
291
+
292
+ beforeEach ( ( ) => {
293
+ mockAuthProvider = {
294
+ get redirectUrl ( ) { return "http://localhost/callback" ; } ,
295
+ get clientMetadata ( ) { return { redirect_uris : [ "http://localhost/callback" ] } ; } ,
296
+ clientInformation : jest . fn ( ( ) => ( { client_id : "test-client-id" } ) ) ,
297
+ tokens : jest . fn ( ) ,
298
+ saveTokens : jest . fn ( ) ,
299
+ redirectToAuthorization : jest . fn ( ) ,
300
+ saveCodeVerifier : jest . fn ( ) ,
301
+ codeVerifier : jest . fn ( ) ,
302
+ } ;
303
+ } ) ;
304
+
305
+ it ( "attaches auth header from provider on SSE connection" , async ( ) => {
306
+ mockAuthProvider . tokens . mockResolvedValue ( {
307
+ access_token : "test-token" ,
308
+ token_type : "Bearer"
309
+ } ) ;
310
+
311
+ transport = new SSEClientTransport ( baseUrl , {
312
+ authProvider : mockAuthProvider ,
313
+ } ) ;
314
+
315
+ await transport . start ( ) ;
316
+
317
+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
318
+ expect ( mockAuthProvider . tokens ) . toHaveBeenCalled ( ) ;
319
+ } ) ;
320
+
321
+ it ( "attaches auth header from provider on POST requests" , async ( ) => {
322
+ mockAuthProvider . tokens . mockResolvedValue ( {
323
+ access_token : "test-token" ,
324
+ token_type : "Bearer"
325
+ } ) ;
326
+
327
+ transport = new SSEClientTransport ( baseUrl , {
328
+ authProvider : mockAuthProvider ,
329
+ } ) ;
330
+
331
+ await transport . start ( ) ;
332
+
333
+ const message : JSONRPCMessage = {
334
+ jsonrpc : "2.0" ,
335
+ id : "1" ,
336
+ method : "test" ,
337
+ params : { } ,
338
+ } ;
339
+
340
+ await transport . send ( message ) ;
341
+
342
+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
343
+ expect ( mockAuthProvider . tokens ) . toHaveBeenCalled ( ) ;
344
+ } ) ;
345
+
346
+ it ( "attempts auth flow on 401 during SSE connection" , async ( ) => {
347
+ // Create server that returns 401s
348
+ server . close ( ) ;
349
+ await new Promise ( resolve => server . on ( "close" , resolve ) ) ;
350
+
351
+ server = createServer ( ( req , res ) => {
352
+ lastServerRequest = req ;
353
+ if ( req . url !== "/" ) {
354
+ res . writeHead ( 404 ) . end ( ) ;
355
+ } else {
356
+ res . writeHead ( 401 ) . end ( ) ;
357
+ }
358
+ } ) ;
359
+
360
+ await new Promise < void > ( resolve => {
361
+ server . listen ( 0 , "127.0.0.1" , ( ) => {
362
+ const addr = server . address ( ) as AddressInfo ;
363
+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
364
+ resolve ( ) ;
365
+ } ) ;
366
+ } ) ;
367
+
368
+ transport = new SSEClientTransport ( baseUrl , {
369
+ authProvider : mockAuthProvider ,
370
+ } ) ;
371
+
372
+ await expect ( ( ) => transport . start ( ) ) . rejects . toThrow ( "Unauthorized" ) ;
373
+ expect ( mockAuthProvider . redirectToAuthorization . mock . calls ) . toHaveLength ( 1 ) ;
374
+ } ) ;
375
+
376
+ it ( "attempts auth flow on 401 during POST request" , async ( ) => {
377
+ // Create server that accepts SSE but returns 401 on POST
378
+ server . close ( ) ;
379
+ await new Promise ( resolve => server . on ( "close" , resolve ) ) ;
380
+
381
+ server = createServer ( ( req , res ) => {
382
+ lastServerRequest = req ;
383
+
384
+ switch ( req . method ) {
385
+ case "GET" :
386
+ if ( req . url !== "/" ) {
387
+ res . writeHead ( 404 ) . end ( ) ;
388
+ return ;
389
+ }
390
+
391
+ res . writeHead ( 200 , {
392
+ "Content-Type" : "text/event-stream" ,
393
+ "Cache-Control" : "no-cache" ,
394
+ Connection : "keep-alive" ,
395
+ } ) ;
396
+ res . write ( "event: endpoint\n" ) ;
397
+ res . write ( `data: ${ baseUrl . href } \n\n` ) ;
398
+ break ;
399
+
400
+ case "POST" :
401
+ res . writeHead ( 401 ) ;
402
+ res . end ( ) ;
403
+ break ;
404
+ }
405
+ } ) ;
406
+
407
+ await new Promise < void > ( resolve => {
408
+ server . listen ( 0 , "127.0.0.1" , ( ) => {
409
+ const addr = server . address ( ) as AddressInfo ;
410
+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
411
+ resolve ( ) ;
412
+ } ) ;
413
+ } ) ;
414
+
415
+ transport = new SSEClientTransport ( baseUrl , {
416
+ authProvider : mockAuthProvider ,
417
+ } ) ;
418
+
419
+ await transport . start ( ) ;
420
+
421
+ const message : JSONRPCMessage = {
422
+ jsonrpc : "2.0" ,
423
+ id : "1" ,
424
+ method : "test" ,
425
+ params : { } ,
426
+ } ;
427
+
428
+ await expect ( ( ) => transport . send ( message ) ) . rejects . toThrow ( "Unauthorized" ) ;
429
+ expect ( mockAuthProvider . redirectToAuthorization . mock . calls ) . toHaveLength ( 1 ) ;
430
+ } ) ;
431
+
432
+ it ( "respects custom headers when using auth provider" , async ( ) => {
433
+ mockAuthProvider . tokens . mockResolvedValue ( {
434
+ access_token : "test-token" ,
435
+ token_type : "Bearer"
436
+ } ) ;
437
+
438
+ const customHeaders = {
439
+ "X-Custom-Header" : "custom-value" ,
440
+ } ;
441
+
442
+ transport = new SSEClientTransport ( baseUrl , {
443
+ authProvider : mockAuthProvider ,
444
+ requestInit : {
445
+ headers : customHeaders ,
446
+ } ,
447
+ } ) ;
448
+
449
+ await transport . start ( ) ;
450
+
451
+ const message : JSONRPCMessage = {
452
+ jsonrpc : "2.0" ,
453
+ id : "1" ,
454
+ method : "test" ,
455
+ params : { } ,
456
+ } ;
457
+
458
+ await transport . send ( message ) ;
459
+
460
+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
461
+ expect ( lastServerRequest . headers [ "x-custom-header" ] ) . toBe ( "custom-value" ) ;
462
+ } ) ;
463
+ } ) ;
287
464
} ) ;
0 commit comments