Skip to content

Commit 49afbad

Browse files
authored
feat(http): add custom CA certificate support for OIDC providers
1 parent 7f4edfd commit 49afbad

File tree

5 files changed

+20
-10
lines changed

5 files changed

+20
-10
lines changed

pkg/http/authorization.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func write401(w http.ResponseWriter, wwwAuthenticateHeader, errorType, message s
108108
// - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview.
109109
//
110110
// see TestAuthorizationOidcTokenExchange
111-
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
111+
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier, httpClient *http.Client) func(http.Handler) http.Handler {
112112
return func(next http.Handler) http.Handler {
113113
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
114114
if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) {
@@ -159,7 +159,11 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
159159
if err == nil && sts.IsEnabled() {
160160
var exchangedToken *oauth2.Token
161161
// If the token is valid, we can exchange it for a new token with the specified audience and scopes.
162-
exchangedToken, err = sts.ExternalAccountTokenExchange(r.Context(), &oauth2.Token{
162+
ctx := r.Context()
163+
if httpClient != nil {
164+
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
165+
}
166+
exchangedToken, err = sts.ExternalAccountTokenExchange(ctx, &oauth2.Token{
163167
AccessToken: claims.Token,
164168
TokenType: "Bearer",
165169
})

pkg/http/http.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ const (
2424
sseMessageEndpoint = "/message"
2525
)
2626

27-
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider) error {
27+
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, httpClient *http.Client) error {
2828
mux := http.NewServeMux()
2929

3030
wrappedMux := RequestMiddleware(
31-
AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer)(mux),
31+
AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer, httpClient)(mux),
3232
)
3333

3434
httpServer := &http.Server{
@@ -44,7 +44,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
4444
mux.HandleFunc(healthEndpoint, func(w http.ResponseWriter, r *http.Request) {
4545
w.WriteHeader(http.StatusOK)
4646
})
47-
mux.Handle("/.well-known/", WellKnownHandler(staticConfig))
47+
mux.Handle("/.well-known/", WellKnownHandler(staticConfig, httpClient))
4848

4949
ctx, cancel := context.WithCancel(ctx)
5050
defer cancel()

pkg/http/http_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (c *httpContext) beforeEach(t *testing.T) {
8989
timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second)
9090
group, gc := errgroup.WithContext(timeoutCtx)
9191
cancelCtx, c.StopServer = context.WithCancel(gc)
92-
group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider) })
92+
group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider, nil) })
9393
c.WaitForShutdown = group.Wait
9494
// Wait for HTTP server to start (using net)
9595
for i := 0; i < 10; i++ {

pkg/http/wellknown.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,24 @@ type WellKnown struct {
2525
authorizationUrl string
2626
scopesSupported []string
2727
disableDynamicClientRegistration bool
28+
httpClient *http.Client
2829
}
2930

3031
var _ http.Handler = &WellKnown{}
3132

32-
func WellKnownHandler(staticConfig *config.StaticConfig) http.Handler {
33+
func WellKnownHandler(staticConfig *config.StaticConfig, httpClient *http.Client) http.Handler {
3334
authorizationUrl := staticConfig.AuthorizationURL
3435
if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") {
3536
authorizationUrl = strings.TrimSuffix(authorizationUrl, "/")
3637
}
38+
if httpClient == nil {
39+
httpClient = http.DefaultClient
40+
}
3741
return &WellKnown{
3842
authorizationUrl: authorizationUrl,
3943
disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration,
4044
scopesSupported: staticConfig.OAuthScopes,
45+
httpClient: httpClient,
4146
}
4247
}
4348

@@ -51,7 +56,7 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request)
5156
http.Error(writer, "Failed to create request: "+err.Error(), http.StatusInternalServerError)
5257
return
5358
}
54-
resp, err := http.DefaultClient.Do(req.WithContext(request.Context()))
59+
resp, err := w.httpClient.Do(req.WithContext(request.Context()))
5560
if err != nil {
5661
http.Error(writer, "Failed to perform request: "+err.Error(), http.StatusInternalServerError)
5762
return

pkg/kubernetes-mcp-server/cmd/root.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,11 @@ func (m *MCPServerOptions) Run() error {
301301
}
302302

303303
var oidcProvider *oidc.Provider
304+
var httpClient *http.Client
304305
if m.StaticConfig.AuthorizationURL != "" {
305306
ctx := context.Background()
306307
if m.StaticConfig.CertificateAuthority != "" {
307-
httpClient := &http.Client{}
308+
httpClient = &http.Client{}
308309
caCert, err := os.ReadFile(m.StaticConfig.CertificateAuthority)
309310
if err != nil {
310311
return fmt.Errorf("failed to read CA certificate from %s: %w", m.StaticConfig.CertificateAuthority, err)
@@ -341,7 +342,7 @@ func (m *MCPServerOptions) Run() error {
341342

342343
if m.StaticConfig.Port != "" {
343344
ctx := context.Background()
344-
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider)
345+
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider, httpClient)
345346
}
346347

347348
if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) {

0 commit comments

Comments
 (0)