@@ -40,6 +40,7 @@ jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({
40
40
registerClient : jest . fn ( ) ,
41
41
startAuthorization : jest . fn ( ) ,
42
42
exchangeAuthorization : jest . fn ( ) ,
43
+ discoverOAuthProtectedResourceMetadata : jest . fn ( ) ,
43
44
} ) ) ;
44
45
45
46
// Import the functions to get their types
@@ -49,6 +50,7 @@ import {
49
50
startAuthorization ,
50
51
exchangeAuthorization ,
51
52
auth ,
53
+ discoverOAuthProtectedResourceMetadata ,
52
54
} from "@modelcontextprotocol/sdk/client/auth.js" ;
53
55
import { OAuthMetadata } from "@modelcontextprotocol/sdk/shared/auth.js" ;
54
56
import { EMPTY_DEBUGGER_STATE } from "@/lib/auth-types" ;
@@ -67,6 +69,10 @@ const mockExchangeAuthorization = exchangeAuthorization as jest.MockedFunction<
67
69
typeof exchangeAuthorization
68
70
> ;
69
71
const mockAuth = auth as jest . MockedFunction < typeof auth > ;
72
+ const mockDiscoverOAuthProtectedResourceMetadata =
73
+ discoverOAuthProtectedResourceMetadata as jest . MockedFunction <
74
+ typeof discoverOAuthProtectedResourceMetadata
75
+ > ;
70
76
71
77
const sessionStorageMock = {
72
78
getItem : jest . fn ( ) ,
@@ -100,6 +106,7 @@ describe("AuthDebugger", () => {
100
106
101
107
mockDiscoverOAuthMetadata . mockResolvedValue ( mockOAuthMetadata ) ;
102
108
mockRegisterClient . mockResolvedValue ( mockOAuthClientInfo ) ;
109
+ mockDiscoverOAuthProtectedResourceMetadata . mockResolvedValue ( null ) ;
103
110
mockStartAuthorization . mockImplementation ( async ( _sseUrl , options ) => {
104
111
const authUrl = new URL ( "https://oauth.example.com/authorize" ) ;
105
112
@@ -421,4 +428,63 @@ describe("AuthDebugger", () => {
421
428
} ) ;
422
429
} ) ;
423
430
} ) ;
431
+
432
+ describe ( "OAuth State Persistence" , ( ) => {
433
+ it ( "should store auth state to sessionStorage before redirect in Quick OAuth Flow" , async ( ) => {
434
+ const updateAuthState = jest . fn ( ) ;
435
+
436
+ // Mock window.location.href setter
437
+ delete ( window as any ) . location ;
438
+ window . location = { href : "" } as any ;
439
+
440
+ // Setup mocks for OAuth flow
441
+ mockStartAuthorization . mockResolvedValue ( {
442
+ authorizationUrl : new URL (
443
+ "https://oauth.example.com/authorize?client_id=test_client_id&redirect_uri=http%3A%2F%2Flocalhost%3A3000%2Foauth%2Fcallback%2Fdebug" ,
444
+ ) ,
445
+ codeVerifier : "test_verifier" ,
446
+ } ) ;
447
+
448
+ await act ( async ( ) => {
449
+ renderAuthDebugger ( {
450
+ updateAuthState,
451
+ authState : { ...defaultAuthState , loading : false } ,
452
+ } ) ;
453
+ } ) ;
454
+
455
+ // Click Quick OAuth Flow
456
+ await act ( async ( ) => {
457
+ fireEvent . click ( screen . getByText ( "Quick OAuth Flow" ) ) ;
458
+ } ) ;
459
+
460
+ // Wait for the flow to reach the authorization step
461
+ await waitFor ( ( ) => {
462
+ expect ( sessionStorage . setItem ) . toHaveBeenCalledWith (
463
+ SESSION_KEYS . AUTH_DEBUGGER_STATE ,
464
+ expect . stringContaining ( '"oauthStep":"authorization_code"' ) ,
465
+ ) ;
466
+ } ) ;
467
+
468
+ // Verify the stored state includes all the accumulated data
469
+ const storedStateCall = (
470
+ sessionStorage . setItem as jest . Mock
471
+ ) . mock . calls . find ( ( call ) => call [ 0 ] === SESSION_KEYS . AUTH_DEBUGGER_STATE ) ;
472
+
473
+ expect ( storedStateCall ) . toBeDefined ( ) ;
474
+ const storedState = JSON . parse ( storedStateCall ! [ 1 ] ) ;
475
+
476
+ expect ( storedState ) . toMatchObject ( {
477
+ oauthStep : "authorization_code" ,
478
+ authorizationUrl : expect . stringMatching (
479
+ / ^ h t t p s : \/ \/ o a u t h \. e x a m p l e \. c o m \/ a u t h o r i z e / ,
480
+ ) ,
481
+ oauthMetadata : expect . objectContaining ( {
482
+ token_endpoint : "https://oauth.example.com/token" ,
483
+ } ) ,
484
+ oauthClientInfo : expect . objectContaining ( {
485
+ client_id : "test_client_id" ,
486
+ } ) ,
487
+ } ) ;
488
+ } ) ;
489
+ } ) ;
424
490
} ) ;
0 commit comments