Skip to content

Commit 90d4bb0

Browse files
authored
feat(auth): token exchange auth workflow (#255)
Signed-off-by: Marc Nuri <[email protected]>
1 parent 58c47dc commit 90d4bb0

File tree

5 files changed

+187
-26
lines changed

5 files changed

+187
-26
lines changed

pkg/config/config.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,27 @@ type StaticConfig struct {
2222
DisableDestructive bool `toml:"disable_destructive,omitempty"`
2323
EnabledTools []string `toml:"enabled_tools,omitempty"`
2424
DisabledTools []string `toml:"disabled_tools,omitempty"`
25-
RequireOAuth bool `toml:"require_oauth,omitempty"`
2625

27-
//Authorization related fields
26+
// Authorization-related fields
27+
// RequireOAuth indicates whether the server requires OAuth for authentication.
28+
RequireOAuth bool `toml:"require_oauth,omitempty"`
2829
// OAuthAudience is the valid audience for the OAuth tokens, used for offline JWT claim validation.
2930
OAuthAudience string `toml:"oauth_audience,omitempty"`
3031
// ValidateToken indicates whether the server should validate the token against the Kubernetes API Server using TokenReview.
3132
ValidateToken bool `toml:"validate_token,omitempty"`
3233
// AuthorizationURL is the URL of the OIDC authorization server.
3334
// It is used for token validation and for STS token exchange.
34-
AuthorizationURL string `toml:"authorization_url,omitempty"`
35-
CertificateAuthority string `toml:"certificate_authority,omitempty"`
36-
ServerURL string `toml:"server_url,omitempty"`
35+
AuthorizationURL string `toml:"authorization_url,omitempty"`
36+
// StsClientId is the OAuth client ID used for backend token exchange
37+
StsClientId string `toml:"sts_client_id,omitempty"`
38+
// StsClientSecret is the OAuth client secret used for backend token exchange
39+
StsClientSecret string `toml:"sts_client_secret,omitempty"`
40+
// StsAudience is the audience for the STS token exchange.
41+
StsAudience string `toml:"sts_audience,omitempty"`
42+
// StsScopes is the scopes for the STS token exchange.
43+
StsScopes []string `toml:"sts_scopes,omitempty"`
44+
CertificateAuthority string `toml:"certificate_authority,omitempty"`
45+
ServerURL string `toml:"server_url,omitempty"`
3746
}
3847

3948
type GroupVersionKind struct {

pkg/http/authorization.go

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ import (
66
"net/http"
77
"strings"
88

9-
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
109
"github.com/coreos/go-oidc/v3/oidc"
1110
"github.com/go-jose/go-jose/v4"
1211
"github.com/go-jose/go-jose/v4/jwt"
12+
"golang.org/x/oauth2"
1313
authenticationapiv1 "k8s.io/api/authentication/v1"
1414
"k8s.io/klog/v2"
1515
"k8s.io/utils/strings/slices"
1616

1717
"github.com/containers/kubernetes-mcp-server/pkg/config"
18+
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
1819
)
1920

2021
type KubernetesApiTokenVerifier interface {
@@ -26,7 +27,7 @@ type KubernetesApiTokenVerifier interface {
2627
//
2728
// The flow is skipped for unprotected resources, such as health checks and well-known endpoints.
2829
//
29-
// There are several auth scenarios
30+
// There are several auth scenarios supported by this middleware:
3031
//
3132
// 1. requireOAuth is false:
3233
//
@@ -42,13 +43,25 @@ type KubernetesApiTokenVerifier interface {
4243
// - If OAuthAudience is set, the token is validated against the audience.
4344
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
4445
//
46+
// see TestAuthorizationRawToken
47+
//
4548
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
4649
// - The token is validated offline for basic sanity checks (audience and expiration).
4750
// - If OAuthAudience is set, the token is validated against the audience.
4851
// - The token is then validated against the OIDC Provider.
4952
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
5053
//
51-
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
54+
// see TestAuthorizationOidcToken
55+
//
56+
// 2.3. OIDC Token Exchange (oidcProvider is not nil, StsClientId and StsAudience are set):
57+
// - The token is validated offline for basic sanity checks (audience and expiration).
58+
// - If OAuthAudience is set, the token is validated against the audience.
59+
// - The token is then validated against the OIDC Provider.
60+
// - If the token is valid, an external account token exchange is performed using
61+
// the OIDC Provider to obtain a new token with the specified audience and scopes.
62+
// - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview.
63+
//
64+
// see TestAuthorizationOidcTokenExchange
5265
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
5366
return func(next http.Handler) http.Handler {
5467
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -96,6 +109,22 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
96109
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
97110
r = r.WithContext(context.WithValue(r.Context(), mcp.TokenScopesContextKey, scopes))
98111
}
112+
// Token exchange with OIDC provider
113+
sts := NewFromConfig(staticConfig, oidcProvider)
114+
if err == nil && sts.IsEnabled() {
115+
var exchangedToken *oauth2.Token
116+
// If the token is valid, we can exchange it for a new token with the specified audience and scopes.
117+
exchangedToken, err = sts.ExternalAccountTokenExchange(r.Context(), &oauth2.Token{
118+
AccessToken: claims.Token,
119+
TokenType: "Bearer",
120+
})
121+
if err == nil {
122+
// Replace the original token with the exchanged token
123+
token = exchangedToken.AccessToken
124+
claims, err = ParseJWTClaims(token)
125+
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) // TODO: Implement test to verify, THIS IS A CRITICAL PART
126+
}
127+
}
99128
// Kubernetes API Server TokenReview validation
100129
if err == nil && staticConfig.ValidateToken {
101130
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, verifier)

pkg/http/http_test.go

Lines changed: 95 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -132,25 +132,40 @@ func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpCo
132132
test(httpCtx)
133133
}
134134

135-
func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *oidc.Provider, httpServer *httptest.Server) {
135+
type OidcTestServer struct {
136+
*rsa.PrivateKey
137+
*oidc.Provider
138+
*httptest.Server
139+
TokenEndpointHandler http.HandlerFunc
140+
}
141+
142+
func NewOidcTestServer(t *testing.T) (oidcTestServer *OidcTestServer) {
136143
t.Helper()
137-
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
144+
var err error
145+
oidcTestServer = &OidcTestServer{}
146+
oidcTestServer.PrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
138147
if err != nil {
139148
t.Fatalf("failed to generate private key for oidc: %v", err)
140149
}
141150
oidcServer := &oidctest.Server{
142151
Algorithms: []string{oidc.RS256, oidc.ES256},
143152
PublicKeys: []oidctest.PublicKey{
144153
{
145-
PublicKey: privateKey.Public(),
154+
PublicKey: oidcTestServer.Public(),
146155
KeyID: "test-oidc-key-id",
147156
Algorithm: oidc.RS256,
148157
},
149158
},
150159
}
151-
httpServer = httptest.NewServer(oidcServer)
152-
oidcServer.SetIssuer(httpServer.URL)
153-
oidcProvider, err = oidc.NewProvider(t.Context(), httpServer.URL)
160+
oidcTestServer.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
161+
if r.URL.Path == "/token" && oidcTestServer.TokenEndpointHandler != nil {
162+
oidcTestServer.TokenEndpointHandler.ServeHTTP(w, r)
163+
return
164+
}
165+
oidcServer.ServeHTTP(w, r)
166+
}))
167+
oidcServer.SetIssuer(oidcTestServer.URL)
168+
oidcTestServer.Provider, err = oidc.NewProvider(t.Context(), oidcTestServer.URL)
154169
if err != nil {
155170
t.Fatalf("failed to create OIDC provider: %v", err)
156171
}
@@ -520,9 +535,9 @@ func TestAuthorizationUnauthorized(t *testing.T) {
520535
})
521536
})
522537
// Failed OIDC validation
523-
key, oidcProvider, httpServer := NewOidcTestServer(t)
524-
t.Cleanup(httpServer.Close)
525-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
538+
oidcTestServer := NewOidcTestServer(t)
539+
t.Cleanup(oidcTestServer.Close)
540+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
526541
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
527542
if err != nil {
528543
t.Fatalf("Failed to create request: %v", err)
@@ -554,12 +569,12 @@ func TestAuthorizationUnauthorized(t *testing.T) {
554569
})
555570
// Failed Kubernetes TokenReview
556571
rawClaims := `{
557-
"iss": "` + httpServer.URL + `",
572+
"iss": "` + oidcTestServer.URL + `",
558573
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
559574
"aud": "mcp-server"
560575
}`
561-
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
562-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
576+
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
577+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
563578
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
564579
if err != nil {
565580
t.Fatalf("Failed to create request: %v", err)
@@ -591,7 +606,6 @@ func TestAuthorizationUnauthorized(t *testing.T) {
591606
})
592607
}
593608

594-
// TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required.
595609
func TestAuthorizationRequireOAuthFalse(t *testing.T) {
596610
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
597611
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
@@ -657,17 +671,17 @@ func TestAuthorizationRawToken(t *testing.T) {
657671
}
658672

659673
func TestAuthorizationOidcToken(t *testing.T) {
660-
key, oidcProvider, httpServer := NewOidcTestServer(t)
661-
t.Cleanup(httpServer.Close)
674+
oidcTestServer := NewOidcTestServer(t)
675+
t.Cleanup(oidcTestServer.Close)
662676
rawClaims := `{
663-
"iss": "` + httpServer.URL + `",
677+
"iss": "` + oidcTestServer.URL + `",
664678
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
665679
"aud": "mcp-server"
666680
}`
667-
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
681+
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
668682
cases := []bool{false, true}
669683
for _, validateToken := range cases {
670-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
684+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
671685
tokenReviewed := false
672686
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
673687
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
@@ -701,6 +715,69 @@ func TestAuthorizationOidcToken(t *testing.T) {
701715
}
702716
})
703717
})
718+
}
719+
}
704720

721+
func TestAuthorizationOidcTokenExchange(t *testing.T) {
722+
oidcTestServer := NewOidcTestServer(t)
723+
t.Cleanup(oidcTestServer.Close)
724+
rawClaims := `{
725+
"iss": "` + oidcTestServer.URL + `",
726+
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
727+
"aud": "%s"
728+
}`
729+
validOidcClientToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256,
730+
fmt.Sprintf(rawClaims, "mcp-server"))
731+
validOidcBackendToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256,
732+
fmt.Sprintf(rawClaims, "backend-audience"))
733+
oidcTestServer.TokenEndpointHandler = func(w http.ResponseWriter, r *http.Request) {
734+
w.Header().Set("Content-Type", "application/json")
735+
_, _ = fmt.Fprintf(w, `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}`, validOidcBackendToken)
736+
}
737+
cases := []bool{false, true}
738+
for _, validateToken := range cases {
739+
staticConfig := &config.StaticConfig{
740+
RequireOAuth: true,
741+
OAuthAudience: "mcp-server",
742+
ValidateToken: validateToken,
743+
StsClientId: "test-sts-client-id",
744+
StsClientSecret: "test-sts-client-secret",
745+
StsAudience: "backend-audience",
746+
StsScopes: []string{"backend-scope"},
747+
}
748+
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
749+
tokenReviewed := false
750+
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
751+
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
752+
w.Header().Set("Content-Type", "application/json")
753+
_, _ = w.Write([]byte(tokenReviewSuccessful))
754+
tokenReviewed = true
755+
return
756+
}
757+
}))
758+
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
759+
if err != nil {
760+
t.Fatalf("Failed to create request: %v", err)
761+
}
762+
req.Header.Set("Authorization", "Bearer "+validOidcClientToken)
763+
resp, err := http.DefaultClient.Do(req)
764+
if err != nil {
765+
t.Fatalf("Failed to get protected endpoint: %v", err)
766+
}
767+
t.Cleanup(func() { _ = resp.Body.Close() })
768+
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK", validateToken), func(t *testing.T) {
769+
if resp.StatusCode != http.StatusOK {
770+
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
771+
}
772+
})
773+
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly", validateToken), func(t *testing.T) {
774+
if tokenReviewed == true && !validateToken {
775+
t.Errorf("Expected token review to be skipped when validate-token is false, but it was performed")
776+
}
777+
if tokenReviewed == false && validateToken {
778+
t.Errorf("Expected token review to be performed when validate-token is true, but it was skipped")
779+
}
780+
})
781+
})
705782
}
706783
}

pkg/http/sts.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"github.com/coreos/go-oidc/v3/oidc"
77
"golang.org/x/oauth2"
88
"golang.org/x/oauth2/google/externalaccount"
9+
10+
"github.com/containers/kubernetes-mcp-server/pkg/config"
911
)
1012

1113
type staticSubjectTokenSupplier struct {
@@ -26,6 +28,20 @@ type SecurityTokenService struct {
2628
ExternalAccountScopes []string
2729
}
2830

31+
func NewFromConfig(config *config.StaticConfig, provider *oidc.Provider) *SecurityTokenService {
32+
return &SecurityTokenService{
33+
Provider: provider,
34+
ClientId: config.StsClientId,
35+
ClientSecret: config.StsClientSecret,
36+
ExternalAccountAudience: config.StsAudience,
37+
ExternalAccountScopes: config.StsScopes,
38+
}
39+
}
40+
41+
func (sts *SecurityTokenService) IsEnabled() bool {
42+
return sts.Provider != nil && sts.ClientId != "" && sts.ExternalAccountAudience != ""
43+
}
44+
2945
func (sts *SecurityTokenService) ExternalAccountTokenExchange(ctx context.Context, originalToken *oauth2.Token) (*oauth2.Token, error) {
3046
ts, err := externalaccount.NewTokenSource(ctx, externalaccount.Config{
3147
TokenURL: sts.Endpoint().TokenURL,

pkg/http/sts_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,36 @@ import (
1212
"golang.org/x/oauth2"
1313
)
1414

15+
func TestIsEnabled(t *testing.T) {
16+
disabledCases := []SecurityTokenService{
17+
{},
18+
{Provider: nil},
19+
{Provider: &oidc.Provider{}},
20+
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ClientSecret: "test-client-secret"},
21+
{ClientId: "test-client-id", ClientSecret: "test-client-secret", ExternalAccountAudience: "test-audience"},
22+
{Provider: &oidc.Provider{}, ClientSecret: "test-client-secret", ExternalAccountAudience: "test-audience"},
23+
}
24+
for _, sts := range disabledCases {
25+
t.Run(fmt.Sprintf("SecurityTokenService{%+v}.IsEnabled() = false", sts), func(t *testing.T) {
26+
if sts.IsEnabled() {
27+
t.Errorf("SecurityTokenService{%+v}.IsEnabled() = true; want false", sts)
28+
}
29+
})
30+
}
31+
enabledCases := []SecurityTokenService{
32+
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience"},
33+
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience", ClientSecret: "test-client-secret"},
34+
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience", ClientSecret: "test-client-secret", ExternalAccountScopes: []string{"test-scope"}},
35+
}
36+
for _, sts := range enabledCases {
37+
t.Run(fmt.Sprintf("SecurityTokenService{%+v}.IsEnabled() = true", sts), func(t *testing.T) {
38+
if !sts.IsEnabled() {
39+
t.Errorf("SecurityTokenService{%+v}.IsEnabled() = false; want true", sts)
40+
}
41+
})
42+
}
43+
}
44+
1545
func TestExternalAccountTokenExchange(t *testing.T) {
1646
mockServer := test.NewMockServer()
1747
authServer := mockServer.Config().Host

0 commit comments

Comments
 (0)