@@ -132,25 +132,40 @@ func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpCo
132
132
test (httpCtx )
133
133
}
134
134
135
- func NewOidcTestServer (t * testing.T ) (privateKey * rsa.PrivateKey , oidcProvider * oidc.Provider , httpServer * httptest.Server ) {
135
+ type OidcTestServer struct {
136
+ * rsa.PrivateKey
137
+ * oidc.Provider
138
+ * httptest.Server
139
+ TokenEndpointHandler http.HandlerFunc
140
+ }
141
+
142
+ func NewOidcTestServer (t * testing.T ) (oidcTestServer * OidcTestServer ) {
136
143
t .Helper ()
137
- privateKey , err := rsa .GenerateKey (rand .Reader , 2048 )
144
+ var err error
145
+ oidcTestServer = & OidcTestServer {}
146
+ oidcTestServer .PrivateKey , err = rsa .GenerateKey (rand .Reader , 2048 )
138
147
if err != nil {
139
148
t .Fatalf ("failed to generate private key for oidc: %v" , err )
140
149
}
141
150
oidcServer := & oidctest.Server {
142
151
Algorithms : []string {oidc .RS256 , oidc .ES256 },
143
152
PublicKeys : []oidctest.PublicKey {
144
153
{
145
- PublicKey : privateKey .Public (),
154
+ PublicKey : oidcTestServer .Public (),
146
155
KeyID : "test-oidc-key-id" ,
147
156
Algorithm : oidc .RS256 ,
148
157
},
149
158
},
150
159
}
151
- httpServer = httptest .NewServer (oidcServer )
152
- oidcServer .SetIssuer (httpServer .URL )
153
- oidcProvider , err = oidc .NewProvider (t .Context (), httpServer .URL )
160
+ oidcTestServer .Server = httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
161
+ if r .URL .Path == "/token" && oidcTestServer .TokenEndpointHandler != nil {
162
+ oidcTestServer .TokenEndpointHandler .ServeHTTP (w , r )
163
+ return
164
+ }
165
+ oidcServer .ServeHTTP (w , r )
166
+ }))
167
+ oidcServer .SetIssuer (oidcTestServer .URL )
168
+ oidcTestServer .Provider , err = oidc .NewProvider (t .Context (), oidcTestServer .URL )
154
169
if err != nil {
155
170
t .Fatalf ("failed to create OIDC provider: %v" , err )
156
171
}
@@ -520,9 +535,9 @@ func TestAuthorizationUnauthorized(t *testing.T) {
520
535
})
521
536
})
522
537
// Failed OIDC validation
523
- key , oidcProvider , httpServer := NewOidcTestServer (t )
524
- t .Cleanup (httpServer .Close )
525
- testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcProvider }, func (ctx * httpContext ) {
538
+ oidcTestServer := NewOidcTestServer (t )
539
+ t .Cleanup (oidcTestServer .Close )
540
+ testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcTestServer . Provider }, func (ctx * httpContext ) {
526
541
req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
527
542
if err != nil {
528
543
t .Fatalf ("Failed to create request: %v" , err )
@@ -554,12 +569,12 @@ func TestAuthorizationUnauthorized(t *testing.T) {
554
569
})
555
570
// Failed Kubernetes TokenReview
556
571
rawClaims := `{
557
- "iss": "` + httpServer .URL + `",
572
+ "iss": "` + oidcTestServer .URL + `",
558
573
"exp": ` + strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) + `,
559
574
"aud": "mcp-server"
560
575
}`
561
- validOidcToken := oidctest .SignIDToken (key , "test-oidc-key-id" , oidc .RS256 , rawClaims )
562
- testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcProvider }, func (ctx * httpContext ) {
576
+ validOidcToken := oidctest .SignIDToken (oidcTestServer . PrivateKey , "test-oidc-key-id" , oidc .RS256 , rawClaims )
577
+ testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcTestServer . Provider }, func (ctx * httpContext ) {
563
578
req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
564
579
if err != nil {
565
580
t .Fatalf ("Failed to create request: %v" , err )
@@ -591,7 +606,6 @@ func TestAuthorizationUnauthorized(t *testing.T) {
591
606
})
592
607
}
593
608
594
- // TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required.
595
609
func TestAuthorizationRequireOAuthFalse (t * testing.T ) {
596
610
testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : false }}, func (ctx * httpContext ) {
597
611
resp , err := http .Get (fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ))
@@ -657,17 +671,17 @@ func TestAuthorizationRawToken(t *testing.T) {
657
671
}
658
672
659
673
func TestAuthorizationOidcToken (t * testing.T ) {
660
- key , oidcProvider , httpServer := NewOidcTestServer (t )
661
- t .Cleanup (httpServer .Close )
674
+ oidcTestServer := NewOidcTestServer (t )
675
+ t .Cleanup (oidcTestServer .Close )
662
676
rawClaims := `{
663
- "iss": "` + httpServer .URL + `",
677
+ "iss": "` + oidcTestServer .URL + `",
664
678
"exp": ` + strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) + `,
665
679
"aud": "mcp-server"
666
680
}`
667
- validOidcToken := oidctest .SignIDToken (key , "test-oidc-key-id" , oidc .RS256 , rawClaims )
681
+ validOidcToken := oidctest .SignIDToken (oidcTestServer . PrivateKey , "test-oidc-key-id" , oidc .RS256 , rawClaims )
668
682
cases := []bool {false , true }
669
683
for _ , validateToken := range cases {
670
- testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : validateToken }, OidcProvider : oidcProvider }, func (ctx * httpContext ) {
684
+ testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : validateToken }, OidcProvider : oidcTestServer . Provider }, func (ctx * httpContext ) {
671
685
tokenReviewed := false
672
686
ctx .mockServer .Handle (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
673
687
if req .URL .EscapedPath () == "/apis/authentication.k8s.io/v1/tokenreviews" {
@@ -701,6 +715,69 @@ func TestAuthorizationOidcToken(t *testing.T) {
701
715
}
702
716
})
703
717
})
718
+ }
719
+ }
704
720
721
+ func TestAuthorizationOidcTokenExchange (t * testing.T ) {
722
+ oidcTestServer := NewOidcTestServer (t )
723
+ t .Cleanup (oidcTestServer .Close )
724
+ rawClaims := `{
725
+ "iss": "` + oidcTestServer .URL + `",
726
+ "exp": ` + strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) + `,
727
+ "aud": "%s"
728
+ }`
729
+ validOidcClientToken := oidctest .SignIDToken (oidcTestServer .PrivateKey , "test-oidc-key-id" , oidc .RS256 ,
730
+ fmt .Sprintf (rawClaims , "mcp-server" ))
731
+ validOidcBackendToken := oidctest .SignIDToken (oidcTestServer .PrivateKey , "test-oidc-key-id" , oidc .RS256 ,
732
+ fmt .Sprintf (rawClaims , "backend-audience" ))
733
+ oidcTestServer .TokenEndpointHandler = func (w http.ResponseWriter , r * http.Request ) {
734
+ w .Header ().Set ("Content-Type" , "application/json" )
735
+ _ , _ = fmt .Fprintf (w , `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}` , validOidcBackendToken )
736
+ }
737
+ cases := []bool {false , true }
738
+ for _ , validateToken := range cases {
739
+ staticConfig := & config.StaticConfig {
740
+ RequireOAuth : true ,
741
+ OAuthAudience : "mcp-server" ,
742
+ ValidateToken : validateToken ,
743
+ StsClientId : "test-sts-client-id" ,
744
+ StsClientSecret : "test-sts-client-secret" ,
745
+ StsAudience : "backend-audience" ,
746
+ StsScopes : []string {"backend-scope" },
747
+ }
748
+ testCaseWithContext (t , & httpContext {StaticConfig : staticConfig , OidcProvider : oidcTestServer .Provider }, func (ctx * httpContext ) {
749
+ tokenReviewed := false
750
+ ctx .mockServer .Handle (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
751
+ if req .URL .EscapedPath () == "/apis/authentication.k8s.io/v1/tokenreviews" {
752
+ w .Header ().Set ("Content-Type" , "application/json" )
753
+ _ , _ = w .Write ([]byte (tokenReviewSuccessful ))
754
+ tokenReviewed = true
755
+ return
756
+ }
757
+ }))
758
+ req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
759
+ if err != nil {
760
+ t .Fatalf ("Failed to create request: %v" , err )
761
+ }
762
+ req .Header .Set ("Authorization" , "Bearer " + validOidcClientToken )
763
+ resp , err := http .DefaultClient .Do (req )
764
+ if err != nil {
765
+ t .Fatalf ("Failed to get protected endpoint: %v" , err )
766
+ }
767
+ t .Cleanup (func () { _ = resp .Body .Close () })
768
+ t .Run (fmt .Sprintf ("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK" , validateToken ), func (t * testing.T ) {
769
+ if resp .StatusCode != http .StatusOK {
770
+ t .Errorf ("Expected HTTP 200 OK, got %d" , resp .StatusCode )
771
+ }
772
+ })
773
+ t .Run (fmt .Sprintf ("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly" , validateToken ), func (t * testing.T ) {
774
+ if tokenReviewed == true && ! validateToken {
775
+ t .Errorf ("Expected token review to be skipped when validate-token is false, but it was performed" )
776
+ }
777
+ if tokenReviewed == false && validateToken {
778
+ t .Errorf ("Expected token review to be performed when validate-token is true, but it was skipped" )
779
+ }
780
+ })
781
+ })
705
782
}
706
783
}
0 commit comments