|
| 1 | +package http |
| 2 | + |
| 3 | +import ( |
| 4 | + "bufio" |
| 5 | + "bytes" |
| 6 | + "context" |
| 7 | + "fmt" |
| 8 | + "net" |
| 9 | + "net/http" |
| 10 | + "os" |
| 11 | + "path/filepath" |
| 12 | + "strings" |
| 13 | + "testing" |
| 14 | + "time" |
| 15 | + |
| 16 | + "golang.org/x/sync/errgroup" |
| 17 | + "k8s.io/client-go/tools/clientcmd" |
| 18 | + "k8s.io/client-go/tools/clientcmd/api" |
| 19 | + "k8s.io/klog/v2" |
| 20 | + "k8s.io/klog/v2/textlogger" |
| 21 | + |
| 22 | + "github.com/manusa/kubernetes-mcp-server/pkg/config" |
| 23 | + "github.com/manusa/kubernetes-mcp-server/pkg/mcp" |
| 24 | +) |
| 25 | + |
| 26 | +type httpContext struct { |
| 27 | + t *testing.T |
| 28 | + klogState klog.State |
| 29 | + logBuffer bytes.Buffer |
| 30 | + httpAddress string // HTTP server address |
| 31 | + timeoutCancel context.CancelFunc // Release resources if test completes before the timeout |
| 32 | + stopServer context.CancelFunc |
| 33 | + waitForShutdown func() error |
| 34 | +} |
| 35 | + |
| 36 | +func (c *httpContext) beforeEach() { |
| 37 | + http.DefaultClient.Timeout = 10 * time.Second |
| 38 | + // Fake Kubernetes configuration |
| 39 | + fakeConfig := api.NewConfig() |
| 40 | + fakeConfig.Clusters["fake"] = api.NewCluster() |
| 41 | + fakeConfig.Clusters["fake"].Server = "https://example.com" |
| 42 | + fakeConfig.Contexts["fake-context"] = api.NewContext() |
| 43 | + fakeConfig.Contexts["fake-context"].Cluster = "fake" |
| 44 | + fakeConfig.CurrentContext = "fake-context" |
| 45 | + kubeConfig := filepath.Join(c.t.TempDir(), "config") |
| 46 | + _ = clientcmd.WriteToFile(*fakeConfig, kubeConfig) |
| 47 | + _ = os.Setenv("KUBECONFIG", kubeConfig) |
| 48 | + // Capture logging |
| 49 | + c.klogState = klog.CaptureState() |
| 50 | + klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(1), textlogger.Output(&c.logBuffer)))) |
| 51 | + // Start server in random port |
| 52 | + ln, err := net.Listen("tcp", "0.0.0.0:0") |
| 53 | + if err != nil { |
| 54 | + c.t.Fatalf("Failed to find random port for HTTP server: %v", err) |
| 55 | + } |
| 56 | + c.httpAddress = ln.Addr().String() |
| 57 | + if randomPortErr := ln.Close(); randomPortErr != nil { |
| 58 | + c.t.Fatalf("Failed to close random port listener: %v", randomPortErr) |
| 59 | + } |
| 60 | + staticConfig := &config.StaticConfig{Port: fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port)} |
| 61 | + mcpServer, err := mcp.NewServer(mcp.Configuration{ |
| 62 | + Profile: mcp.Profiles[0], |
| 63 | + StaticConfig: staticConfig, |
| 64 | + }) |
| 65 | + if err != nil { |
| 66 | + c.t.Fatalf("Failed to create MCP server: %v", err) |
| 67 | + } |
| 68 | + var timeoutCtx, cancelCtx context.Context |
| 69 | + timeoutCtx, c.timeoutCancel = context.WithTimeout(c.t.Context(), 10*time.Second) |
| 70 | + group, gc := errgroup.WithContext(timeoutCtx) |
| 71 | + cancelCtx, c.stopServer = context.WithCancel(gc) |
| 72 | + group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig) }) |
| 73 | + c.waitForShutdown = group.Wait |
| 74 | + // Wait for HTTP server to start (using net) |
| 75 | + for i := 0; i < 10; i++ { |
| 76 | + conn, err := net.Dial("tcp", c.httpAddress) |
| 77 | + if err == nil { |
| 78 | + _ = conn.Close() |
| 79 | + break |
| 80 | + } |
| 81 | + time.Sleep(50 * time.Millisecond) // Wait before retrying |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +func (c *httpContext) afterEach() { |
| 86 | + c.stopServer() |
| 87 | + err := c.waitForShutdown() |
| 88 | + if err != nil { |
| 89 | + c.t.Errorf("HTTP server did not shut down gracefully: %v", err) |
| 90 | + } |
| 91 | + c.timeoutCancel() |
| 92 | + c.klogState.Restore() |
| 93 | + _ = os.Setenv("KUBECONFIG", "") |
| 94 | +} |
| 95 | + |
| 96 | +func testCase(t *testing.T, test func(c *httpContext)) { |
| 97 | + ctx := &httpContext{t: t} |
| 98 | + ctx.beforeEach() |
| 99 | + t.Cleanup(ctx.afterEach) |
| 100 | + test(ctx) |
| 101 | +} |
| 102 | + |
| 103 | +func TestGracefulShutdown(t *testing.T) { |
| 104 | + testCase(t, func(ctx *httpContext) { |
| 105 | + ctx.stopServer() |
| 106 | + err := ctx.waitForShutdown() |
| 107 | + t.Run("Stops gracefully", func(t *testing.T) { |
| 108 | + if err != nil { |
| 109 | + t.Errorf("Expected graceful shutdown, but got error: %v", err) |
| 110 | + } |
| 111 | + }) |
| 112 | + t.Run("Stops on context cancel", func(t *testing.T) { |
| 113 | + if !strings.Contains(ctx.logBuffer.String(), "Context cancelled, initiating graceful shutdown") { |
| 114 | + t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.logBuffer.String()) |
| 115 | + } |
| 116 | + }) |
| 117 | + t.Run("Starts server shutdown", func(t *testing.T) { |
| 118 | + if !strings.Contains(ctx.logBuffer.String(), "Shutting down HTTP server gracefully") { |
| 119 | + t.Errorf("Expected graceful shutdown log, got: %s", ctx.logBuffer.String()) |
| 120 | + } |
| 121 | + }) |
| 122 | + t.Run("Server shutdown completes", func(t *testing.T) { |
| 123 | + if !strings.Contains(ctx.logBuffer.String(), "HTTP server shutdown complete") { |
| 124 | + t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.logBuffer.String()) |
| 125 | + } |
| 126 | + }) |
| 127 | + }) |
| 128 | +} |
| 129 | + |
| 130 | +func TestSseTransport(t *testing.T) { |
| 131 | + testCase(t, func(ctx *httpContext) { |
| 132 | + sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.httpAddress)) |
| 133 | + t.Cleanup(func() { _ = sseResp.Body.Close() }) |
| 134 | + t.Run("Exposes SSE endpoint at /sse", func(t *testing.T) { |
| 135 | + if sseErr != nil { |
| 136 | + t.Fatalf("Failed to get SSE endpoint: %v", sseErr) |
| 137 | + } |
| 138 | + if sseResp.StatusCode != http.StatusOK { |
| 139 | + t.Errorf("Expected HTTP 200 OK, got %d", sseResp.StatusCode) |
| 140 | + } |
| 141 | + }) |
| 142 | + t.Run("SSE endpoint returns text/event-stream content type", func(t *testing.T) { |
| 143 | + if sseResp.Header.Get("Content-Type") != "text/event-stream" { |
| 144 | + t.Errorf("Expected Content-Type text/event-stream, got %s", sseResp.Header.Get("Content-Type")) |
| 145 | + } |
| 146 | + }) |
| 147 | + responseReader := bufio.NewReader(sseResp.Body) |
| 148 | + event, eventErr := responseReader.ReadString('\n') |
| 149 | + endpoint, endpointErr := responseReader.ReadString('\n') |
| 150 | + t.Run("SSE endpoint returns stream with messages endpoint", func(t *testing.T) { |
| 151 | + if eventErr != nil { |
| 152 | + t.Fatalf("Failed to read SSE response body (event): %v", eventErr) |
| 153 | + } |
| 154 | + if event != "event: endpoint\n" { |
| 155 | + t.Errorf("Expected SSE event 'endpoint', got %s", event) |
| 156 | + } |
| 157 | + if endpointErr != nil { |
| 158 | + t.Fatalf("Failed to read SSE response body (endpoint): %v", endpointErr) |
| 159 | + } |
| 160 | + if !strings.HasPrefix(endpoint, "data: /message?sessionId=") { |
| 161 | + t.Errorf("Expected SSE data: '/message', got %s", endpoint) |
| 162 | + } |
| 163 | + }) |
| 164 | + messageResp, messageErr := http.Post( |
| 165 | + fmt.Sprintf("http://%s/message?sessionId=%s", ctx.httpAddress, strings.TrimSpace(endpoint[25:])), |
| 166 | + "application/json", |
| 167 | + bytes.NewBufferString("{}"), |
| 168 | + ) |
| 169 | + t.Cleanup(func() { _ = messageResp.Body.Close() }) |
| 170 | + t.Run("Exposes message endpoint at /message", func(t *testing.T) { |
| 171 | + if messageErr != nil { |
| 172 | + t.Fatalf("Failed to get message endpoint: %v", messageErr) |
| 173 | + } |
| 174 | + if messageResp.StatusCode != http.StatusAccepted { |
| 175 | + t.Errorf("Expected HTTP 202 OK, got %d", messageResp.StatusCode) |
| 176 | + } |
| 177 | + }) |
| 178 | + }) |
| 179 | +} |
| 180 | + |
| 181 | +func TestStreamableHttpTransport(t *testing.T) { |
| 182 | + testCase(t, func(ctx *httpContext) { |
| 183 | + mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.httpAddress)) |
| 184 | + t.Cleanup(func() { _ = mcpGetResp.Body.Close() }) |
| 185 | + t.Run("Exposes MCP GET endpoint at /mcp", func(t *testing.T) { |
| 186 | + if mcpGetErr != nil { |
| 187 | + t.Fatalf("Failed to get MCP endpoint: %v", mcpGetErr) |
| 188 | + } |
| 189 | + if mcpGetResp.StatusCode != http.StatusOK { |
| 190 | + t.Errorf("Expected HTTP 200 OK, got %d", mcpGetResp.StatusCode) |
| 191 | + } |
| 192 | + }) |
| 193 | + t.Run("MCP GET endpoint returns text/event-stream content type", func(t *testing.T) { |
| 194 | + if mcpGetResp.Header.Get("Content-Type") != "text/event-stream" { |
| 195 | + t.Errorf("Expected Content-Type text/event-stream (GET), got %s", mcpGetResp.Header.Get("Content-Type")) |
| 196 | + } |
| 197 | + }) |
| 198 | + mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.httpAddress), "application/json", bytes.NewBufferString("{}")) |
| 199 | + t.Cleanup(func() { _ = mcpPostResp.Body.Close() }) |
| 200 | + t.Run("Exposes MCP POST endpoint at /mcp", func(t *testing.T) { |
| 201 | + if mcpPostErr != nil { |
| 202 | + t.Fatalf("Failed to post to MCP endpoint: %v", mcpPostErr) |
| 203 | + } |
| 204 | + if mcpPostResp.StatusCode != http.StatusOK { |
| 205 | + t.Errorf("Expected HTTP 200 OK, got %d", mcpPostResp.StatusCode) |
| 206 | + } |
| 207 | + }) |
| 208 | + t.Run("MCP POST endpoint returns application/json content type", func(t *testing.T) { |
| 209 | + if mcpPostResp.Header.Get("Content-Type") != "application/json" { |
| 210 | + t.Errorf("Expected Content-Type application/json (POST), got %s", mcpPostResp.Header.Get("Content-Type")) |
| 211 | + } |
| 212 | + }) |
| 213 | + }) |
| 214 | +} |
| 215 | + |
| 216 | +func TestHealthCheck(t *testing.T) { |
| 217 | + testCase(t, func(ctx *httpContext) { |
| 218 | + t.Run("Exposes health check endpoint at /healthz", func(t *testing.T) { |
| 219 | + resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.httpAddress)) |
| 220 | + if err != nil { |
| 221 | + t.Fatalf("Failed to get health check endpoint: %v", err) |
| 222 | + } |
| 223 | + t.Cleanup(func() { _ = resp.Body.Close }) |
| 224 | + if resp.StatusCode != http.StatusOK { |
| 225 | + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) |
| 226 | + } |
| 227 | + }) |
| 228 | + }) |
| 229 | +} |
| 230 | + |
| 231 | +func TestWellKnownOAuthProtectedResource(t *testing.T) { |
| 232 | + testCase(t, func(ctx *httpContext) { |
| 233 | + resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress)) |
| 234 | + t.Cleanup(func() { _ = resp.Body.Close() }) |
| 235 | + t.Run("Exposes .well-known/oauth-protected-resource endpoint", func(t *testing.T) { |
| 236 | + if err != nil { |
| 237 | + t.Fatalf("Failed to get .well-known/oauth-protected-resource endpoint: %v", err) |
| 238 | + } |
| 239 | + if resp.StatusCode != http.StatusOK { |
| 240 | + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) |
| 241 | + } |
| 242 | + }) |
| 243 | + t.Run(".well-known/oauth-protected-resource returns application/json content type", func(t *testing.T) { |
| 244 | + if resp.Header.Get("Content-Type") != "application/json" { |
| 245 | + t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type")) |
| 246 | + } |
| 247 | + }) |
| 248 | + }) |
| 249 | +} |
0 commit comments