@@ -132,25 +132,40 @@ func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpCo
132132	test (httpCtx )
133133}
134134
135- func  NewOidcTestServer (t  * testing.T ) (privateKey  * rsa.PrivateKey , oidcProvider  * oidc.Provider , httpServer  * httptest.Server ) {
135+ type  OidcTestServer  struct  {
136+ 	* rsa.PrivateKey 
137+ 	* oidc.Provider 
138+ 	* httptest.Server 
139+ 	TokenEndpointHandler  http.HandlerFunc 
140+ }
141+ 
142+ func  NewOidcTestServer (t  * testing.T ) (oidcTestServer  * OidcTestServer ) {
136143	t .Helper ()
137- 	privateKey , err  :=  rsa .GenerateKey (rand .Reader , 2048 )
144+ 	var  err  error 
145+ 	oidcTestServer  =  & OidcTestServer {}
146+ 	oidcTestServer .PrivateKey , err  =  rsa .GenerateKey (rand .Reader , 2048 )
138147	if  err  !=  nil  {
139148		t .Fatalf ("failed to generate private key for oidc: %v" , err )
140149	}
141150	oidcServer  :=  & oidctest.Server {
142151		Algorithms : []string {oidc .RS256 , oidc .ES256 },
143152		PublicKeys : []oidctest.PublicKey {
144153			{
145- 				PublicKey : privateKey .Public (),
154+ 				PublicKey : oidcTestServer .Public (),
146155				KeyID :     "test-oidc-key-id" ,
147156				Algorithm : oidc .RS256 ,
148157			},
149158		},
150159	}
151- 	httpServer  =  httptest .NewServer (oidcServer )
152- 	oidcServer .SetIssuer (httpServer .URL )
153- 	oidcProvider , err  =  oidc .NewProvider (t .Context (), httpServer .URL )
160+ 	oidcTestServer .Server  =  httptest .NewServer (http .HandlerFunc (func (w  http.ResponseWriter , r  * http.Request ) {
161+ 		if  r .URL .Path  ==  "/token"  &&  oidcTestServer .TokenEndpointHandler  !=  nil  {
162+ 			oidcTestServer .TokenEndpointHandler .ServeHTTP (w , r )
163+ 			return 
164+ 		}
165+ 		oidcServer .ServeHTTP (w , r )
166+ 	}))
167+ 	oidcServer .SetIssuer (oidcTestServer .URL )
168+ 	oidcTestServer .Provider , err  =  oidc .NewProvider (t .Context (), oidcTestServer .URL )
154169	if  err  !=  nil  {
155170		t .Fatalf ("failed to create OIDC provider: %v" , err )
156171	}
@@ -520,9 +535,9 @@ func TestAuthorizationUnauthorized(t *testing.T) {
520535		})
521536	})
522537	// Failed OIDC validation 
523- 	key ,  oidcProvider ,  httpServer  :=  NewOidcTestServer (t )
524- 	t .Cleanup (httpServer .Close )
525- 	testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcProvider }, func (ctx  * httpContext ) {
538+ 	oidcTestServer  :=  NewOidcTestServer (t )
539+ 	t .Cleanup (oidcTestServer .Close )
540+ 	testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcTestServer . Provider }, func (ctx  * httpContext ) {
526541		req , err  :=  http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
527542		if  err  !=  nil  {
528543			t .Fatalf ("Failed to create request: %v" , err )
@@ -554,12 +569,12 @@ func TestAuthorizationUnauthorized(t *testing.T) {
554569	})
555570	// Failed Kubernetes TokenReview 
556571	rawClaims  :=  `{ 
557- 		"iss": "`  +  httpServer .URL  +  `", 
572+ 		"iss": "`  +  oidcTestServer .URL  +  `", 
558573		"exp": `  +  strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) +  `, 
559574		"aud": "mcp-server" 
560575	}` 
561- 	validOidcToken  :=  oidctest .SignIDToken (key , "test-oidc-key-id" , oidc .RS256 , rawClaims )
562- 	testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcProvider }, func (ctx  * httpContext ) {
576+ 	validOidcToken  :=  oidctest .SignIDToken (oidcTestServer . PrivateKey , "test-oidc-key-id" , oidc .RS256 , rawClaims )
577+ 	testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcTestServer . Provider }, func (ctx  * httpContext ) {
563578		req , err  :=  http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
564579		if  err  !=  nil  {
565580			t .Fatalf ("Failed to create request: %v" , err )
@@ -591,7 +606,6 @@ func TestAuthorizationUnauthorized(t *testing.T) {
591606	})
592607}
593608
594- // TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required. 
595609func  TestAuthorizationRequireOAuthFalse (t  * testing.T ) {
596610	testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : false }}, func (ctx  * httpContext ) {
597611		resp , err  :=  http .Get (fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ))
@@ -657,17 +671,17 @@ func TestAuthorizationRawToken(t *testing.T) {
657671}
658672
659673func  TestAuthorizationOidcToken (t  * testing.T ) {
660- 	key ,  oidcProvider ,  httpServer  :=  NewOidcTestServer (t )
661- 	t .Cleanup (httpServer .Close )
674+ 	oidcTestServer  :=  NewOidcTestServer (t )
675+ 	t .Cleanup (oidcTestServer .Close )
662676	rawClaims  :=  `{ 
663- 		"iss": "`  +  httpServer .URL  +  `", 
677+ 		"iss": "`  +  oidcTestServer .URL  +  `", 
664678		"exp": `  +  strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) +  `, 
665679		"aud": "mcp-server" 
666680	}` 
667- 	validOidcToken  :=  oidctest .SignIDToken (key , "test-oidc-key-id" , oidc .RS256 , rawClaims )
681+ 	validOidcToken  :=  oidctest .SignIDToken (oidcTestServer . PrivateKey , "test-oidc-key-id" , oidc .RS256 , rawClaims )
668682	cases  :=  []bool {false , true }
669683	for  _ , validateToken  :=  range  cases  {
670- 		testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : validateToken }, OidcProvider : oidcProvider }, func (ctx  * httpContext ) {
684+ 		testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : validateToken }, OidcProvider : oidcTestServer . Provider }, func (ctx  * httpContext ) {
671685			tokenReviewed  :=  false 
672686			ctx .mockServer .Handle (http .HandlerFunc (func (w  http.ResponseWriter , req  * http.Request ) {
673687				if  req .URL .EscapedPath () ==  "/apis/authentication.k8s.io/v1/tokenreviews"  {
@@ -701,6 +715,69 @@ func TestAuthorizationOidcToken(t *testing.T) {
701715				}
702716			})
703717		})
718+ 	}
719+ }
704720
721+ func  TestAuthorizationOidcTokenExchange (t  * testing.T ) {
722+ 	oidcTestServer  :=  NewOidcTestServer (t )
723+ 	t .Cleanup (oidcTestServer .Close )
724+ 	rawClaims  :=  `{ 
725+ 		"iss": "`  +  oidcTestServer .URL  +  `", 
726+ 		"exp": `  +  strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) +  `, 
727+ 		"aud": "%s" 
728+ 	}` 
729+ 	validOidcClientToken  :=  oidctest .SignIDToken (oidcTestServer .PrivateKey , "test-oidc-key-id" , oidc .RS256 ,
730+ 		fmt .Sprintf (rawClaims , "mcp-server" ))
731+ 	validOidcBackendToken  :=  oidctest .SignIDToken (oidcTestServer .PrivateKey , "test-oidc-key-id" , oidc .RS256 ,
732+ 		fmt .Sprintf (rawClaims , "backend-audience" ))
733+ 	oidcTestServer .TokenEndpointHandler  =  func (w  http.ResponseWriter , r  * http.Request ) {
734+ 		w .Header ().Set ("Content-Type" , "application/json" )
735+ 		_ , _  =  fmt .Fprintf (w , `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}` , validOidcBackendToken )
736+ 	}
737+ 	cases  :=  []bool {false , true }
738+ 	for  _ , validateToken  :=  range  cases  {
739+ 		staticConfig  :=  & config.StaticConfig {
740+ 			RequireOAuth :    true ,
741+ 			OAuthAudience :   "mcp-server" ,
742+ 			ValidateToken :   validateToken ,
743+ 			StsClientId :     "test-sts-client-id" ,
744+ 			StsClientSecret : "test-sts-client-secret" ,
745+ 			StsAudience :     "backend-audience" ,
746+ 			StsScopes :       []string {"backend-scope" },
747+ 		}
748+ 		testCaseWithContext (t , & httpContext {StaticConfig : staticConfig , OidcProvider : oidcTestServer .Provider }, func (ctx  * httpContext ) {
749+ 			tokenReviewed  :=  false 
750+ 			ctx .mockServer .Handle (http .HandlerFunc (func (w  http.ResponseWriter , req  * http.Request ) {
751+ 				if  req .URL .EscapedPath () ==  "/apis/authentication.k8s.io/v1/tokenreviews"  {
752+ 					w .Header ().Set ("Content-Type" , "application/json" )
753+ 					_ , _  =  w .Write ([]byte (tokenReviewSuccessful ))
754+ 					tokenReviewed  =  true 
755+ 					return 
756+ 				}
757+ 			}))
758+ 			req , err  :=  http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
759+ 			if  err  !=  nil  {
760+ 				t .Fatalf ("Failed to create request: %v" , err )
761+ 			}
762+ 			req .Header .Set ("Authorization" , "Bearer " + validOidcClientToken )
763+ 			resp , err  :=  http .DefaultClient .Do (req )
764+ 			if  err  !=  nil  {
765+ 				t .Fatalf ("Failed to get protected endpoint: %v" , err )
766+ 			}
767+ 			t .Cleanup (func () { _  =  resp .Body .Close () })
768+ 			t .Run (fmt .Sprintf ("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK" , validateToken ), func (t  * testing.T ) {
769+ 				if  resp .StatusCode  !=  http .StatusOK  {
770+ 					t .Errorf ("Expected HTTP 200 OK, got %d" , resp .StatusCode )
771+ 				}
772+ 			})
773+ 			t .Run (fmt .Sprintf ("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly" , validateToken ), func (t  * testing.T ) {
774+ 				if  tokenReviewed  ==  true  &&  ! validateToken  {
775+ 					t .Errorf ("Expected token review to be skipped when validate-token is false, but it was performed" )
776+ 				}
777+ 				if  tokenReviewed  ==  false  &&  validateToken  {
778+ 					t .Errorf ("Expected token review to be performed when validate-token is true, but it was skipped" )
779+ 				}
780+ 			})
781+ 		})
705782	}
706783}
0 commit comments