Skip to content

Commit f80d8df

Browse files
authored
feat(auth): authorize user from custom SSE header (96)
feat(auth): Authorize user from custom SSE header PoC to show how we can propagate an Authorization Bearer token from the MCP client up to the Kubernetes API by passing a custom header (Kubernetes-Authorization-Bearer-Token). A new Derived client is necessary for each request due to the incompleteness of some of the client-go clients. This might add some overhead for each prompt. Ideally, the issue with the discoveryclient and others should be fixed to allow reading the authorization header from the request context. To use the feature, the MCP Server still needs to be started with a basic configuration (either provided InCluster by a service account or locally by a .kube/config file) so that it's able to infer the server settings. --- test(auth): added tests to verify header propagation --- refactor(auth): minor improvements for derived client
1 parent 9830e22 commit f80d8df

File tree

10 files changed

+182
-22
lines changed

10 files changed

+182
-22
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package kubernetes
2+
3+
import "net/http"
4+
5+
type impersonateRoundTripper struct {
6+
delegate http.RoundTripper
7+
}
8+
9+
func (irt *impersonateRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
10+
// TODO: Solution won't work with discoveryclient which uses context.TODO() instead of the passed-in context
11+
if v, ok := req.Context().Value(AuthorizationHeader).(string); ok {
12+
req.Header.Set("Authorization", v)
13+
}
14+
return irt.delegate.RoundTrip(req)
15+
}

pkg/kubernetes/kubernetes.go

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package kubernetes
22

33
import (
4+
"context"
45
"github.com/fsnotify/fsnotify"
56
"github.com/manusa/kubernetes-mcp-server/pkg/helm"
67
v1 "k8s.io/api/core/v1"
@@ -15,9 +16,15 @@ import (
1516
"k8s.io/client-go/rest"
1617
"k8s.io/client-go/restmapper"
1718
"k8s.io/client-go/tools/clientcmd"
19+
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
1820
"sigs.k8s.io/yaml"
1921
)
2022

23+
const (
24+
AuthorizationHeader = "Kubernetes-Authorization"
25+
AuthorizationBearerTokenHeader = "kubernetes-authorization-bearer-token"
26+
)
27+
2128
type CloseWatchKubeConfig func() error
2229

2330
type Kubernetes struct {
@@ -42,6 +49,10 @@ func NewKubernetes(kubeconfig string) (*Kubernetes, error) {
4249
if err := resolveKubernetesConfigurations(k8s); err != nil {
4350
return nil, err
4451
}
52+
// TODO: Won't work because not all client-go clients use the shared context (e.g. discovery client uses context.TODO())
53+
//k8s.cfg.Wrap(func(original http.RoundTripper) http.RoundTripper {
54+
// return &impersonateRoundTripper{original}
55+
//})
4556
var err error
4657
k8s.clientSet, err = kubernetes.NewForConfig(k8s.cfg)
4758
if err != nil {
@@ -52,7 +63,7 @@ func NewKubernetes(kubeconfig string) (*Kubernetes, error) {
5263
return nil, err
5364
}
5465
k8s.discoveryClient = memory.NewMemCacheClient(discoveryClient)
55-
k8s.deferredDiscoveryRESTMapper = restmapper.NewDeferredDiscoveryRESTMapper(memory.NewMemCacheClient(k8s.discoveryClient))
66+
k8s.deferredDiscoveryRESTMapper = restmapper.NewDeferredDiscoveryRESTMapper(k8s.discoveryClient)
5667
k8s.dynamicClient, err = dynamic.NewForConfig(k8s.cfg)
5768
if err != nil {
5869
return nil, err
@@ -116,6 +127,50 @@ func (k *Kubernetes) ToRESTMapper() (meta.RESTMapper, error) {
116127
return k.deferredDiscoveryRESTMapper, nil
117128
}
118129

130+
func (k *Kubernetes) Derived(ctx context.Context) *Kubernetes {
131+
bearerToken, ok := ctx.Value(AuthorizationBearerTokenHeader).(string)
132+
if !ok {
133+
return k
134+
}
135+
derivedCfg := rest.CopyConfig(k.cfg)
136+
derivedCfg.BearerToken = bearerToken
137+
derivedCfg.BearerTokenFile = ""
138+
derivedCfg.Username = ""
139+
derivedCfg.Password = ""
140+
derivedCfg.AuthProvider = nil
141+
derivedCfg.AuthConfigPersister = nil
142+
derivedCfg.ExecProvider = nil
143+
derivedCfg.Impersonate = rest.ImpersonationConfig{}
144+
clientCmdApiConfig, err := k.clientCmdConfig.RawConfig()
145+
if err != nil {
146+
return k
147+
}
148+
clientCmdApiConfig.AuthInfos = make(map[string]*clientcmdapi.AuthInfo)
149+
derived := &Kubernetes{
150+
Kubeconfig: k.Kubeconfig,
151+
clientCmdConfig: clientcmd.NewDefaultClientConfig(clientCmdApiConfig, nil),
152+
cfg: derivedCfg,
153+
scheme: k.scheme,
154+
parameterCodec: k.parameterCodec,
155+
}
156+
derived.clientSet, err = kubernetes.NewForConfig(derived.cfg)
157+
if err != nil {
158+
return k
159+
}
160+
discoveryClient, err := discovery.NewDiscoveryClientForConfig(derived.cfg)
161+
if err != nil {
162+
return k
163+
}
164+
derived.discoveryClient = memory.NewMemCacheClient(discoveryClient)
165+
derived.deferredDiscoveryRESTMapper = restmapper.NewDeferredDiscoveryRESTMapper(derived.discoveryClient)
166+
derived.dynamicClient, err = dynamic.NewForConfig(derived.cfg)
167+
if err != nil {
168+
return k
169+
}
170+
derived.Helm = helm.NewHelm(derived)
171+
return derived
172+
}
173+
119174
func marshal(v any) (string, error) {
120175
switch t := v.(type) {
121176
case []unstructured.Unstructured:

pkg/mcp/common_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"github.com/mark3labs/mcp-go/client"
8+
"github.com/mark3labs/mcp-go/client/transport"
89
"github.com/mark3labs/mcp-go/mcp"
910
"github.com/mark3labs/mcp-go/server"
1011
"github.com/pkg/errors"
@@ -97,6 +98,7 @@ type mcpContext struct {
9798
profile Profile
9899
readOnly bool
99100
disableDestructive bool
101+
clientOptions []transport.ClientOption
100102
before func(*mcpContext)
101103
after func(*mcpContext)
102104
ctx context.Context
@@ -124,8 +126,8 @@ func (c *mcpContext) beforeEach(t *testing.T) {
124126
t.Fatal(err)
125127
return
126128
}
127-
c.mcpHttpServer = server.NewTestServer(c.mcpServer.server)
128-
if c.mcpClient, err = client.NewSSEMCPClient(c.mcpHttpServer.URL + "/sse"); err != nil {
129+
c.mcpHttpServer = server.NewTestServer(c.mcpServer.server, server.WithSSEContextFunc(contextFunc))
130+
if c.mcpClient, err = client.NewSSEMCPClient(c.mcpHttpServer.URL+"/sse", c.clientOptions...); err != nil {
129131
t.Fatal(err)
130132
return
131133
}

pkg/mcp/events.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func (s *Server) eventsList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
2727
if namespace == nil {
2828
namespace = ""
2929
}
30-
ret, err := s.k.EventsList(ctx, namespace.(string))
30+
ret, err := s.k.Derived(ctx).EventsList(ctx, namespace.(string))
3131
if err != nil {
3232
return NewTextResult("", fmt.Errorf("failed to list events in all namespaces: %v", err)), nil
3333
}

pkg/mcp/helm.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ func (s *Server) helmInstall(ctx context.Context, ctr mcp.CallToolRequest) (*mcp
6464
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
6565
namespace = v
6666
}
67-
ret, err := s.k.Helm.Install(ctx, chart, values, name, namespace)
67+
ret, err := s.k.Derived(ctx).Helm.Install(ctx, chart, values, name, namespace)
6868
if err != nil {
6969
return NewTextResult("", fmt.Errorf("failed to install helm chart '%s': %w", chart, err)), nil
7070
}
7171
return NewTextResult(ret, err), nil
7272
}
7373

74-
func (s *Server) helmList(_ context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
74+
func (s *Server) helmList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
7575
allNamespaces := false
7676
if v, ok := ctr.GetArguments()["all_namespaces"].(bool); ok {
7777
allNamespaces = v
@@ -80,14 +80,14 @@ func (s *Server) helmList(_ context.Context, ctr mcp.CallToolRequest) (*mcp.Call
8080
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
8181
namespace = v
8282
}
83-
ret, err := s.k.Helm.List(namespace, allNamespaces)
83+
ret, err := s.k.Derived(ctx).Helm.List(namespace, allNamespaces)
8484
if err != nil {
8585
return NewTextResult("", fmt.Errorf("failed to list helm releases in namespace '%s': %w", namespace, err)), nil
8686
}
8787
return NewTextResult(ret, err), nil
8888
}
8989

90-
func (s *Server) helmUninstall(_ context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
90+
func (s *Server) helmUninstall(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
9191
var name string
9292
ok := false
9393
if name, ok = ctr.GetArguments()["name"].(string); !ok {
@@ -97,7 +97,7 @@ func (s *Server) helmUninstall(_ context.Context, ctr mcp.CallToolRequest) (*mcp
9797
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
9898
namespace = v
9999
}
100-
ret, err := s.k.Helm.Uninstall(name, namespace)
100+
ret, err := s.k.Derived(ctx).Helm.Uninstall(name, namespace)
101101
if err != nil {
102102
return NewTextResult("", fmt.Errorf("failed to uninstall helm chart '%s': %w", name, err)), nil
103103
}

pkg/mcp/mcp.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package mcp
22

33
import (
4+
"context"
45
"github.com/manusa/kubernetes-mcp-server/pkg/kubernetes"
56
"github.com/manusa/kubernetes-mcp-server/pkg/version"
67
"github.com/mark3labs/mcp-go/mcp"
78
"github.com/mark3labs/mcp-go/server"
9+
"net/http"
810
)
911

1012
type Configuration struct {
@@ -71,6 +73,7 @@ func (s *Server) ServeStdio() error {
7173

7274
func (s *Server) ServeSse(baseUrl string) *server.SSEServer {
7375
options := make([]server.SSEOption, 0)
76+
options = append(options, server.WithSSEContextFunc(contextFunc))
7477
if baseUrl != "" {
7578
options = append(options, server.WithBaseURL(baseUrl))
7679
}
@@ -104,3 +107,8 @@ func NewTextResult(content string, err error) *mcp.CallToolResult {
104107
},
105108
}
106109
}
110+
111+
func contextFunc(ctx context.Context, r *http.Request) context.Context {
112+
//return context.WithValue(ctx, kubernetes.AuthorizationHeader, r.Header.Get(kubernetes.AuthorizationHeader))
113+
return context.WithValue(ctx, kubernetes.AuthorizationBearerTokenHeader, r.Header.Get(kubernetes.AuthorizationBearerTokenHeader))
114+
}

pkg/mcp/mcp_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package mcp
22

33
import (
44
"context"
5+
"github.com/mark3labs/mcp-go/client"
56
"github.com/mark3labs/mcp-go/mcp"
7+
"net/http"
68
"os"
79
"path/filepath"
810
"runtime"
@@ -88,3 +90,81 @@ func TestDisableDestructive(t *testing.T) {
8890
})
8991
})
9092
}
93+
94+
func TestSseHeaders(t *testing.T) {
95+
mockServer := NewMockServer()
96+
defer mockServer.Close()
97+
before := func(c *mcpContext) {
98+
c.withKubeConfig(mockServer.config)
99+
c.clientOptions = append(c.clientOptions, client.WithHeaders(map[string]string{"kubernetes-authorization-bearer-token": "a-token-from-mcp-client"}))
100+
}
101+
pathHeaders := make(map[string]http.Header, 0)
102+
mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
103+
pathHeaders[req.URL.Path] = req.Header.Clone()
104+
// Request Performed by DiscoveryClient to Kube API (Get API Groups legacy -core-)
105+
if req.URL.Path == "/api" {
106+
w.Header().Set("Content-Type", "application/json")
107+
_, _ = w.Write([]byte(`{"kind":"APIVersions","versions":["v1"],"serverAddressByClientCIDRs":[{"clientCIDR":"0.0.0.0/0"}]}`))
108+
return
109+
}
110+
// Request Performed by DiscoveryClient to Kube API (Get API Groups)
111+
if req.URL.Path == "/apis" {
112+
w.Header().Set("Content-Type", "application/json")
113+
//w.Write([]byte(`{"kind":"APIGroupList","apiVersion":"v1","groups":[{"name":"apps","versions":[{"groupVersion":"apps/v1","version":"v1"}],"preferredVersion":{"groupVersion":"apps/v1","version":"v1"}}]}`))
114+
_, _ = w.Write([]byte(`{"kind":"APIGroupList","apiVersion":"v1","groups":[]}`))
115+
return
116+
}
117+
// Request Performed by DiscoveryClient to Kube API (Get API Resources)
118+
if req.URL.Path == "/api/v1" {
119+
w.Header().Set("Content-Type", "application/json")
120+
_, _ = w.Write([]byte(`{"kind":"APIResourceList","apiVersion":"v1","resources":[{"name":"pods","singularName":"","namespaced":true,"kind":"Pod","verbs":["get","list","watch","create","update","patch","delete"]}]}`))
121+
return
122+
}
123+
// Request Performed by DynamicClient
124+
if req.URL.Path == "/api/v1/namespaces/default/pods" {
125+
w.Header().Set("Content-Type", "application/json")
126+
_, _ = w.Write([]byte(`{"kind":"PodList","apiVersion":"v1","items":[]}`))
127+
return
128+
}
129+
// Request Performed by kubernetes.Interface
130+
if req.URL.Path == "/api/v1/namespaces/default/pods/a-pod-to-delete" {
131+
w.WriteHeader(200)
132+
return
133+
}
134+
w.WriteHeader(404)
135+
}))
136+
testCaseWithContext(t, &mcpContext{before: before}, func(c *mcpContext) {
137+
c.callTool("pods_list", map[string]interface{}{})
138+
t.Run("DiscoveryClient propagates headers to Kube API", func(t *testing.T) {
139+
if len(pathHeaders) == 0 {
140+
t.Fatalf("No requests were made to Kube API")
141+
}
142+
if pathHeaders["/api"] == nil || pathHeaders["/api"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
143+
t.Fatalf("Overridden header Authorization not found in request to /api")
144+
}
145+
if pathHeaders["/apis"] == nil || pathHeaders["/apis"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
146+
t.Fatalf("Overridden header Authorization not found in request to /apis")
147+
}
148+
if pathHeaders["/api/v1"] == nil || pathHeaders["/api/v1"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
149+
t.Fatalf("Overridden header Authorization not found in request to /api/v1")
150+
}
151+
})
152+
t.Run("DynamicClient propagates headers to Kube API", func(t *testing.T) {
153+
if len(pathHeaders) == 0 {
154+
t.Fatalf("No requests were made to Kube API")
155+
}
156+
if pathHeaders["/api/v1/namespaces/default/pods"] == nil || pathHeaders["/api/v1/namespaces/default/pods"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
157+
t.Fatalf("Overridden header Authorization not found in request to /api/v1/namespaces/default/pods")
158+
}
159+
})
160+
c.callTool("pods_delete", map[string]interface{}{"name": "a-pod-to-delete"})
161+
t.Run("kubernetes.Interface propagates headers to Kube API", func(t *testing.T) {
162+
if len(pathHeaders) == 0 {
163+
t.Fatalf("No requests were made to Kube API")
164+
}
165+
if pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"] == nil || pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
166+
t.Fatalf("Overridden header Authorization not found in request to /api/v1/namespaces/default/pods/a-pod-to-delete")
167+
}
168+
})
169+
})
170+
}

pkg/mcp/namespaces.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ func (s *Server) initNamespaces() []server.ServerTool {
3535
}
3636

3737
func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
38-
ret, err := s.k.NamespacesList(ctx)
38+
ret, err := s.k.Derived(ctx).NamespacesList(ctx)
3939
if err != nil {
4040
err = fmt.Errorf("failed to list namespaces: %v", err)
4141
}
4242
return NewTextResult(ret, err), nil
4343
}
4444

4545
func (s *Server) projectsList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
46-
ret, err := s.k.ProjectsList(ctx)
46+
ret, err := s.k.Derived(ctx).ProjectsList(ctx)
4747
if err != nil {
4848
err = fmt.Errorf("failed to list projects: %v", err)
4949
}

pkg/mcp/pods.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (s *Server) podsListInAllNamespaces(ctx context.Context, ctr mcp.CallToolRe
110110
selector = labelSelector.(string)
111111
}
112112

113-
ret, err := s.k.PodsListInAllNamespaces(ctx, selector)
113+
ret, err := s.k.Derived(ctx).PodsListInAllNamespaces(ctx, selector)
114114
if err != nil {
115115
return NewTextResult("", fmt.Errorf("failed to list pods in all namespaces: %v", err)), nil
116116
}
@@ -127,7 +127,7 @@ func (s *Server) podsListInNamespace(ctx context.Context, ctr mcp.CallToolReques
127127
if labelSelector != nil {
128128
selector = labelSelector.(string)
129129
}
130-
ret, err := s.k.PodsListInNamespace(ctx, ns.(string), selector)
130+
ret, err := s.k.Derived(ctx).PodsListInNamespace(ctx, ns.(string), selector)
131131
if err != nil {
132132
return NewTextResult("", fmt.Errorf("failed to list pods in namespace %s: %v", ns, err)), nil
133133
}
@@ -143,7 +143,7 @@ func (s *Server) podsGet(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
143143
if name == nil {
144144
return NewTextResult("", errors.New("failed to get pod, missing argument name")), nil
145145
}
146-
ret, err := s.k.PodsGet(ctx, ns.(string), name.(string))
146+
ret, err := s.k.Derived(ctx).PodsGet(ctx, ns.(string), name.(string))
147147
if err != nil {
148148
return NewTextResult("", fmt.Errorf("failed to get pod %s in namespace %s: %v", name, ns, err)), nil
149149
}
@@ -159,7 +159,7 @@ func (s *Server) podsDelete(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
159159
if name == nil {
160160
return NewTextResult("", errors.New("failed to delete pod, missing argument name")), nil
161161
}
162-
ret, err := s.k.PodsDelete(ctx, ns.(string), name.(string))
162+
ret, err := s.k.Derived(ctx).PodsDelete(ctx, ns.(string), name.(string))
163163
if err != nil {
164164
return NewTextResult("", fmt.Errorf("failed to delete pod %s in namespace %s: %v", name, ns, err)), nil
165165
}
@@ -190,7 +190,7 @@ func (s *Server) podsExec(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Ca
190190
} else {
191191
return NewTextResult("", errors.New("failed to exec in pod, invalid command argument")), nil
192192
}
193-
ret, err := s.k.PodsExec(ctx, ns.(string), name.(string), container.(string), command)
193+
ret, err := s.k.Derived(ctx).PodsExec(ctx, ns.(string), name.(string), container.(string), command)
194194
if err != nil {
195195
return NewTextResult("", fmt.Errorf("failed to exec in pod %s in namespace %s: %v", name, ns, err)), nil
196196
} else if ret == "" {
@@ -212,7 +212,7 @@ func (s *Server) podsLog(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
212212
if container == nil {
213213
container = ""
214214
}
215-
ret, err := s.k.PodsLog(ctx, ns.(string), name.(string), container.(string))
215+
ret, err := s.k.Derived(ctx).PodsLog(ctx, ns.(string), name.(string), container.(string))
216216
if err != nil {
217217
return NewTextResult("", fmt.Errorf("failed to get pod %s log in namespace %s: %v", name, ns, err)), nil
218218
} else if ret == "" {
@@ -238,7 +238,7 @@ func (s *Server) podsRun(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
238238
if port == nil {
239239
port = float64(0)
240240
}
241-
ret, err := s.k.PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
241+
ret, err := s.k.Derived(ctx).PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
242242
if err != nil {
243243
return NewTextResult("", fmt.Errorf("failed to get pod %s log in namespace %s: %v", name, ns, err)), nil
244244
}

0 commit comments

Comments
 (0)