From 8d9a749a01af10548fa52167d957d52113c58d93 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Sun, 25 May 2025 16:48:29 -0400 Subject: [PATCH 01/15] feat(publisher): integrate GitHub Client ID into device flow authentication --- internal/model/model.go | 9 +++-- tools/publisher/main.go | 75 ++++++++++++++++++++++++----------------- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/internal/model/model.go b/internal/model/model.go index 7090ff55..4c0d4476 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -20,8 +20,7 @@ type Authentication struct { // PublishRequest represents a request to publish a server to the registry type PublishRequest struct { ServerDetail `json:",inline"` - Authentication Authentication `json:"-"` // Now provided via Authorization header - AuthStatusToken string `json:"-"` // Used internally for device flows + AuthStatusToken string `json:"-"` // Used internally for device flows } // Repository represents a source code repository as defined in the spec @@ -121,7 +120,7 @@ type Server struct { // ServerDetail represents detailed server information as defined in the spec type ServerDetail struct { - Server `json:",inline" bson:",inline"` - Packages []Package `json:"packages,omitempty" bson:"packages,omitempty"` - Remotes []Remote `json:"remotes,omitempty" bson:"remotes,omitempty"` + Server `json:",inline" bson:",inline"` + Packages []Package `json:"packages,omitempty" bson:"packages,omitempty"` + Remotes []Remote `json:"remotes,omitempty" bson:"remotes,omitempty"` } diff --git a/tools/publisher/main.go b/tools/publisher/main.go index 638076f0..9ed544ed 100644 --- a/tools/publisher/main.go +++ b/tools/publisher/main.go @@ -18,9 +18,6 @@ const ( // GitHub OAuth URLs GitHubDeviceCodeURL = "https://github.com/login/device/code" GitHubAccessTokenURL = "https://github.com/login/oauth/access_token" - - // Environment variable for GitHub Client ID - EnvGithubClientID = "MCP_REGISTRY_GITHUB_CLIENT_ID" ) // DeviceCodeResponse represents the response from GitHub's device code endpoint @@ -40,6 +37,11 @@ type AccessTokenResponse struct { Error string `json:"error,omitempty"` } +type ServerHealthResponse struct { + Status string `json:"status"` + GitHubClientId string `json:"github_client_id"` +} + func main() { var registryURL string var mcpFilePath string @@ -58,19 +60,32 @@ func main() { return } - // Check for GitHub client ID in environment if we're going to need it for authentication - if providedToken == "" && os.Getenv(EnvGithubClientID) == "" { - fmt.Printf("Warning: Environment variable %s is not set. This is required for GitHub authentication.\n", EnvGithubClientID) - fmt.Println("You can set it with: export " + EnvGithubClientID + "=your_github_client_id") - fmt.Println("Or provide a token directly with the --token flag.") - - // Only return if we'll need to do GitHub auth - _, statErr := os.Stat(tokenFilePath) - if forceLogin || os.IsNotExist(statErr) { - return - } + // get the clientID from the server's health endpoint + healthURL := registryURL + "/v0/health" + resp, err := http.Get(healthURL) + if err != nil { + fmt.Printf("Error fetching health endpoint: %s\n", err.Error()) + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + fmt.Printf("Health endpoint returned status %d: %s\n", resp.StatusCode, body) + return + } + var healthResponse ServerHealthResponse + err = json.NewDecoder(resp.Body).Decode(&healthResponse) + if err != nil { + fmt.Printf("Error decoding health response: %s\n", err.Error()) + return + } + if healthResponse.GitHubClientId == "" { + fmt.Println("GitHub Client ID is not set in the server's health response.") + return } + githubClientID := healthResponse.GitHubClientId + var token string // If a token is provided via the command line, use it @@ -80,7 +95,7 @@ func main() { // Check if token exists or force login is requested _, statErr := os.Stat(tokenFilePath) if forceLogin || os.IsNotExist(statErr) { - err := performDeviceFlowLogin() + err := performDeviceFlowLogin(githubClientID) if err != nil { fmt.Printf("Failed to perform device flow login: %s\n", err.Error()) return @@ -113,15 +128,15 @@ func main() { fmt.Println("Successfully published to registry!") } -func performDeviceFlowLogin() error { - // Check if the environment variable is set - if os.Getenv(EnvGithubClientID) == "" { - return fmt.Errorf("environment variable %s must be set for GitHub authentication", EnvGithubClientID) +func performDeviceFlowLogin(githubClientID string) error { + + if githubClientID == "" { + return fmt.Errorf("GitHub Client ID is required for device flow login") } // Device flow login logic using GitHub's device flow // First, request a device code - deviceCode, userCode, verificationURI, err := requestDeviceCode() + deviceCode, userCode, verificationURI, err := requestDeviceCode(githubClientID) if err != nil { return fmt.Errorf("error requesting device code: %w", err) } @@ -134,7 +149,7 @@ func performDeviceFlowLogin() error { // Poll for the token fmt.Println("Waiting for authorization...") - token, err := pollForToken(deviceCode) + token, err := pollForToken(deviceCode, githubClientID) if err != nil { return fmt.Errorf("error polling for token: %w", err) } @@ -150,14 +165,13 @@ func performDeviceFlowLogin() error { } // requestDeviceCode initiates the device authorization flow -func requestDeviceCode() (string, string, string, error) { - clientID := os.Getenv(EnvGithubClientID) - if clientID == "" { - return "", "", "", fmt.Errorf("environment variable %s is not set", EnvGithubClientID) +func requestDeviceCode(githubClientID string) (string, string, string, error) { + if githubClientID == "" { + return "", "", "", fmt.Errorf("GitHub Client ID is required for device flow login") } payload := map[string]string{ - "client_id": clientID, + "client_id": githubClientID, "scope": "read:org read:user", } @@ -199,14 +213,13 @@ func requestDeviceCode() (string, string, string, error) { } // pollForToken polls for access token after user completes authorization -func pollForToken(deviceCode string) (string, error) { - clientID := os.Getenv(EnvGithubClientID) - if clientID == "" { - return "", fmt.Errorf("environment variable %s is not set", EnvGithubClientID) +func pollForToken(deviceCode, githubClientID string) (string, error) { + if githubClientID == "" { + return "", fmt.Errorf("GitHub Client ID is required for device flow login") } payload := map[string]string{ - "client_id": clientID, + "client_id": githubClientID, "device_code": deviceCode, "grant_type": "urn:ietf:params:oauth:grant-type:device_code", } From 8c18cb9a1e7ad83681abeb36c606d5fa34e756f1 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Sun, 25 May 2025 22:05:18 -0400 Subject: [PATCH 02/15] fix(publish): handle inline model corectly on publish --- internal/api/handlers/v0/publish.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/internal/api/handlers/v0/publish.go b/internal/api/handlers/v0/publish.go index 8d8f066a..9e999d5a 100644 --- a/internal/api/handlers/v0/publish.go +++ b/internal/api/handlers/v0/publish.go @@ -38,8 +38,13 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) } // Get server details from the request - serverDetail := publishReq.ServerDetail + var serverDetail model.ServerDetail + err = json.Unmarshal(body, &serverDetail) + if err != nil { + http.Error(w, "Invalid server detail payload: "+err.Error(), http.StatusBadRequest) + return + } // Validate required fields if serverDetail.Name == "" { http.Error(w, "Name is required", http.StatusBadRequest) From feff57f336e0bae60932d618d0d190565aa42308 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Sun, 25 May 2025 22:47:16 -0400 Subject: [PATCH 03/15] feat(tests): add comprehensive tests for health and publish handlers --- go.mod | 4 + go.sum | 2 + internal/api/handlers/v0/health_test.go | 113 ++++++ internal/api/handlers/v0/publish_test.go | 479 +++++++++++++++++++++++ 4 files changed, 598 insertions(+) create mode 100644 internal/api/handlers/v0/health_test.go create mode 100644 internal/api/handlers/v0/publish_test.go diff --git a/go.mod b/go.mod index 8a59528d..73f43120 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/caarlos0/env/v11 v11.3.1 github.com/google/uuid v1.6.0 + github.com/stretchr/testify v1.10.0 github.com/swaggo/files v1.0.1 github.com/swaggo/http-swagger v1.3.4 go.mongodb.org/mongo-driver v1.17.3 @@ -12,6 +13,7 @@ require ( require ( github.com/KyleBanks/depth v1.2.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-openapi/jsonpointer v0.21.1 // indirect github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/spec v0.21.0 // indirect @@ -21,6 +23,8 @@ require ( github.com/klauspost/compress v1.16.7 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/montanaflynn/stats v0.7.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/swaggo/swag v1.16.4 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect diff --git a/go.sum b/go.sum index daa00519..5c7a5a65 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= diff --git a/internal/api/handlers/v0/health_test.go b/internal/api/handlers/v0/health_test.go new file mode 100644 index 00000000..d0ddece0 --- /dev/null +++ b/internal/api/handlers/v0/health_test.go @@ -0,0 +1,113 @@ +package v0 + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/modelcontextprotocol/registry/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestHealthHandler(t *testing.T) { + // Test cases + testCases := []struct { + name string + config *config.Config + expectedStatus int + expectedBody HealthResponse + }{ + { + name: "returns health status with github client id", + config: &config.Config{ + GithubClientID: "test-github-client-id", + }, + expectedStatus: http.StatusOK, + expectedBody: HealthResponse{ + Status: "ok", + GitHubClientId: "test-github-client-id", + }, + }, + { + name: "works with empty github client id", + config: &config.Config{ + GithubClientID: "", + }, + expectedStatus: http.StatusOK, + expectedBody: HealthResponse{ + Status: "ok", + GitHubClientId: "", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create handler with the test config + handler := HealthHandler(tc.config) + + // Create request + req, err := http.NewRequest("GET", "/health", nil) + if err != nil { + t.Fatal(err) + } + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + // Check content type + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse response body + var resp HealthResponse + err = json.NewDecoder(rr.Body).Decode(&resp) + assert.NoError(t, err) + + // Check the response body + assert.Equal(t, tc.expectedBody, resp) + }) + } +} + +// TestHealthHandlerIntegration tests the handler with actual HTTP requests +func TestHealthHandlerIntegration(t *testing.T) { + // Create test server + cfg := &config.Config{ + GithubClientID: "integration-test-client-id", + } + + server := httptest.NewServer(HealthHandler(cfg)) + defer server.Close() + + // Send request to the test server + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check content type + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Parse response body + var healthResp HealthResponse + err = json.NewDecoder(resp.Body).Decode(&healthResp) + assert.NoError(t, err) + + // Check the response body + expectedResp := HealthResponse{ + Status: "ok", + GitHubClientId: "integration-test-client-id", + } + assert.Equal(t, expectedResp, healthResp) +} diff --git a/internal/api/handlers/v0/publish_test.go b/internal/api/handlers/v0/publish_test.go new file mode 100644 index 00000000..508bc36c --- /dev/null +++ b/internal/api/handlers/v0/publish_test.go @@ -0,0 +1,479 @@ +package v0 + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockRegistryService is a mock implementation of the RegistryService interface +type MockRegistryService struct { + mock.Mock +} + +func (m *MockRegistryService) List(cursor string, limit int) ([]model.Server, string, error) { + args := m.Called(cursor, limit) + return args.Get(0).([]model.Server), args.String(1), args.Error(2) +} + +func (m *MockRegistryService) GetByID(id string) (*model.ServerDetail, error) { + args := m.Called(id) + return args.Get(0).(*model.ServerDetail), args.Error(1) +} + +func (m *MockRegistryService) Publish(serverDetail *model.ServerDetail) error { + args := m.Called(serverDetail) + return args.Error(0) +} + +// MockAuthService is a mock implementation of the auth.Service interface +type MockAuthService struct { + mock.Mock +} + +func (m *MockAuthService) StartAuthFlow(ctx context.Context, method model.AuthMethod, repoRef string) (map[string]string, string, error) { + args := m.Called(ctx, method, repoRef) + return args.Get(0).(map[string]string), args.String(1), args.Error(2) +} + +func (m *MockAuthService) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { + args := m.Called(ctx, statusToken) + return args.String(0), args.Error(1) +} + +func (m *MockAuthService) ValidateAuth(ctx context.Context, authentication model.Authentication) (bool, error) { + args := m.Called(ctx, authentication) + return args.Bool(0), args.Error(1) +} + +func TestPublishHandler(t *testing.T) { + testCases := []struct { + name string + method string + requestBody interface{} + authHeader string + setupMocks func(*MockRegistryService, *MockAuthService) + expectedStatus int + expectedResponse map[string]string + expectedError string + }{ + { + name: "successful publish with GitHub auth", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "io.github.example/test-server", + Description: "A test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server", + Source: "github", + ID: "example/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer github_token_123", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.On("ValidateAuth", mock.Anything, model.Authentication{ + Method: model.AuthMethodGitHub, + Token: "github_token_123", + RepoRef: "io.github.example/test-server", + }).Return(true, nil) + registry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id", + }, + }, + { + name: "successful publish with no auth (AuthMethodNone)", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id-2", + Name: "example/test-server", + Description: "A test server without auth", + Repository: model.Repository{ + URL: "https://example.com/test-server", + Source: "example", + ID: "example/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer some_token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.On("ValidateAuth", mock.Anything, model.Authentication{ + Method: model.AuthMethodNone, + Token: "some_token", + RepoRef: "example/test-server", + }).Return(true, nil) + registry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id-2", + }, + }, + { + name: "method not allowed", + method: http.MethodGet, + requestBody: nil, + authHeader: "", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + expectedStatus: http.StatusMethodNotAllowed, + expectedError: "Method not allowed", + }, + { + name: "missing request body", + method: http.MethodPost, + requestBody: "", + authHeader: "", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid request payload:", + }, + { + name: "missing server name", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "", // Missing name + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Name is required", + }, + { + name: "missing version", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "", // Missing version + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Version is required", + }, + { + name: "missing authorization header", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "", // Missing auth header + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + expectedStatus: http.StatusUnauthorized, + expectedError: "Authorization header is required", + }, + { + name: "authentication required error", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, auth.ErrAuthRequired) + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Authentication is required for publishing", + }, + { + name: "authentication failed", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer invalid_token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, nil) + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Invalid authentication credentials", + }, + { + name: "registry service error", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.On("ValidateAuth", mock.Anything, mock.Anything).Return(true, nil) + registry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(assert.AnError) + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "Failed to publish server details:", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create mocks + mockRegistry := new(MockRegistryService) + mockAuthService := new(MockAuthService) + + // Setup mocks + tc.setupMocks(mockRegistry, mockAuthService) + + // Create handler + handler := PublishHandler(mockRegistry, mockAuthService) + + // Prepare request body + var requestBody []byte + if tc.requestBody != nil { + var err error + requestBody, err = json.Marshal(tc.requestBody) + assert.NoError(t, err) + } + + // Create request + req, err := http.NewRequest(tc.method, "/publish", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + + // Set auth header if provided + if tc.authHeader != "" { + req.Header.Set("Authorization", tc.authHeader) + } + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedResponse != nil { + // Check content type for successful responses + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse and verify response body + var response map[string]string + err = json.NewDecoder(rr.Body).Decode(&response) + assert.NoError(t, err) + assert.Equal(t, tc.expectedResponse, response) + } + + if tc.expectedError != "" { + // Check that the error message is contained in the response + assert.Contains(t, rr.Body.String(), tc.expectedError) + } + + // Assert that all expectations were met + mockRegistry.AssertExpectations(t) + mockAuthService.AssertExpectations(t) + }) + } +} + +func TestPublishHandlerBearerTokenParsing(t *testing.T) { + testCases := []struct { + name string + authHeader string + expectedToken string + }{ + { + name: "bearer token with Bearer prefix", + authHeader: "Bearer github_token_123", + expectedToken: "github_token_123", + }, + { + name: "bearer token with bearer prefix (lowercase)", + authHeader: "bearer github_token_123", + expectedToken: "github_token_123", + }, + { + name: "token without Bearer prefix", + authHeader: "github_token_123", + expectedToken: "github_token_123", + }, + { + name: "mixed case Bearer prefix", + authHeader: "BeArEr github_token_123", + expectedToken: "github_token_123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockRegistry := new(MockRegistryService) + mockAuthService := new(MockAuthService) + + // Setup mock to capture the actual token passed + mockAuthService.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + return auth.Token == tc.expectedToken + })).Return(true, nil) + mockRegistry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + + handler := PublishHandler(mockRegistry, mockAuthService) + + serverDetail := model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + } + + requestBody, err := json.Marshal(serverDetail) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + req.Header.Set("Authorization", tc.authHeader) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusCreated, rr.Code) + mockAuthService.AssertExpectations(t) + }) + } +} + +func TestPublishHandlerAuthMethodSelection(t *testing.T) { + testCases := []struct { + name string + serverName string + expectedAuthMethod model.AuthMethod + }{ + { + name: "GitHub prefix triggers GitHub auth", + serverName: "io.github.example/test-server", + expectedAuthMethod: model.AuthMethodGitHub, + }, + { + name: "non-GitHub prefix uses no auth", + serverName: "example.com/test-server", + expectedAuthMethod: model.AuthMethodNone, + }, + { + name: "empty prefix uses no auth", + serverName: "test-server", + expectedAuthMethod: model.AuthMethodNone, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockRegistry := new(MockRegistryService) + mockAuthService := new(MockAuthService) + + // Setup mock to capture the auth method + mockAuthService.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + return auth.Method == tc.expectedAuthMethod + })).Return(true, nil) + mockRegistry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + + handler := PublishHandler(mockRegistry, mockAuthService) + + serverDetail := model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: tc.serverName, + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + } + + requestBody, err := json.Marshal(serverDetail) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + req.Header.Set("Authorization", "Bearer test_token") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusCreated, rr.Code) + mockAuthService.AssertExpectations(t) + }) + } +} From 7c38a66aad09cfca8dada185c1d6b590c7b35bbe Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Mon, 26 May 2025 12:00:25 -0400 Subject: [PATCH 04/15] feat(tests): add integration tests for publish functionality and enhance error handling --- integrationtests/README.md | 98 +++ integrationtests/publish_integration_test.go | 736 +++++++++++++++++++ integrationtests/run_tests.sh | 24 + internal/api/handlers/v0/publish.go | 7 + internal/database/database.go | 9 +- internal/database/memory.go | 109 ++- 6 files changed, 956 insertions(+), 27 deletions(-) create mode 100644 integrationtests/README.md create mode 100644 integrationtests/publish_integration_test.go create mode 100755 integrationtests/run_tests.sh diff --git a/integrationtests/README.md b/integrationtests/README.md new file mode 100644 index 00000000..3fa2fd5f --- /dev/null +++ b/integrationtests/README.md @@ -0,0 +1,98 @@ +# Integration Tests + +This directory contains integration tests for the MCP Registry API using the fake service implementation. + +## Overview + +The integration tests are designed to test the complete flow of the publish endpoint using real service implementations (fake service) rather than mocks. This provides confidence that the entire request/response cycle works correctly. + +## Test Structure + +### `publish_integration_test.go` + +Contains comprehensive integration tests for the publish endpoint: + +- **TestPublishIntegration**: Tests various scenarios for publishing servers + - Successful publish with GitHub authentication + - Successful publish without authentication (for non-GitHub servers) + - Error cases: missing name, missing version, missing auth header, invalid JSON, unsupported HTTP methods + - Duplicate package handling: fails when same name+version, succeeds with different versions + +- **TestPublishIntegrationWithComplexPackages**: Tests publishing servers with complex package configurations + - Multiple runtime arguments (named and positional) + - Package arguments + - Environment variables (including secrets) + - Multiple remotes with different transport types + - Headers for HTTP remotes + +- **TestPublishIntegrationEndToEnd**: Tests the complete end-to-end flow + - Publishes a server and verifies it can be retrieved + - Checks that the server appears in the registry list + - Verifies count consistency + +## Mock Services + +### MockAuthService + +A simple mock implementation of the `auth.Service` interface that: +- Accepts any non-empty token for GitHub authentication +- Always allows authentication for `AuthMethodNone` +- Provides realistic responses for auth flow methods + +## Running the Tests + +From the project root directory: + +```bash +# Run all integration tests +go test ./integrationtests/... + +# Run with verbose output +go test -v ./integrationtests/... + +# Run a specific test +go test -v ./integrationtests/ -run TestPublishIntegration + +# Run tests with race detection +go test -race ./integrationtests/... + +# Use the convenient test runner script +./integrationtests/run_tests.sh +``` + +## Test Data + +The tests use the fake service which comes pre-populated with sample data: +- 3 sample MCP servers with different configurations +- Uses in-memory database for isolation between tests +- Each test creates unique server instances with UUIDs + +## Benefits of Integration Tests + +1. **Real Flow Testing**: Tests the actual HTTP request/response cycle +2. **Service Integration**: Validates that handlers work correctly with service implementations +3. **Data Persistence**: Verifies that published data can be retrieved +4. **Error Handling**: Tests complete error scenarios end-to-end +5. **Complex Scenarios**: Tests realistic server configurations with packages and remotes + +## Dependencies + +These tests use: +- `testify/assert` and `testify/require` for assertions +- `httptest` for HTTP testing utilities +- The fake service implementation for realistic data operations +- Standard Go testing package + +## Test Coverage + +The integration tests cover: +- ✅ Successful publish scenarios +- ✅ Authentication validation +- ✅ Input validation +- ✅ Duplicate package handling +- ✅ Complex package configurations +- ✅ Multiple remotes +- ✅ Error handling +- ✅ End-to-end data flow +- ✅ HTTP method validation +- ✅ JSON parsing errors diff --git a/integrationtests/publish_integration_test.go b/integrationtests/publish_integration_test.go new file mode 100644 index 00000000..574d75ad --- /dev/null +++ b/integrationtests/publish_integration_test.go @@ -0,0 +1,736 @@ +package integrationtests + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/modelcontextprotocol/registry/internal/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockAuthService implements a simple auth service for testing +type MockAuthService struct{} + +func (m *MockAuthService) StartAuthFlow(ctx context.Context, method model.AuthMethod, repoRef string) (map[string]string, string, error) { + return map[string]string{ + "device_code": "mock_device_code", + "user_code": "ABCD-1234", + "verification_uri": "https://github.com/login/device", + }, "mock_status_token", nil +} + +func (m *MockAuthService) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { + if statusToken == "mock_status_token" { + return "mock_access_token", nil + } + return "", fmt.Errorf("invalid status token") +} + +func (m *MockAuthService) ValidateAuth(ctx context.Context, authentication model.Authentication) (bool, error) { + // Simple validation: for testing purposes, accept any non-empty token + switch authentication.Method { + case model.AuthMethodGitHub: + return authentication.Token != "", nil + case model.AuthMethodNone: + return true, nil + default: + return false, auth.ErrUnsupportedAuthMethod + } +} + +// TestPublishIntegration tests the complete flow of publishing a server using the fake service +func TestPublishIntegration(t *testing.T) { + // Setup fake service and auth service + registryService := service.NewFakeRegistryService() + authService := &MockAuthService{} + + // Create the publish handler + handler := v0.PublishHandler(registryService, authService) + + t.Run("successful publish with GitHub auth", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.testuser/test-mcp-server", + Description: "A test MCP server for integration testing", + Repository: model.Repository{ + URL: "https://github.com/testuser/test-mcp-server", + Source: "github", + ID: "testuser/test-mcp-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + Packages: []model.Package{ + { + RegistryName: "npm", + Name: "test-mcp-server", + Version: "1.0.0", + RunTimeHint: "node", + RuntimeArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "config", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Configuration file path", + Format: model.FormatFilePath, + IsRequired: true, + }, + }, + }, + }, + }, + }, + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "http://localhost:3000/mcp", + }, + }, + } + + // Marshal the server detail to JSON + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + // Create a request + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_test_token_123") + + // Create a response recorder + recorder := httptest.NewRecorder() + + // Call the handler + handler(recorder, req) + + // Check the response + assert.Equal(t, http.StatusCreated, recorder.Code) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "Server publication successful", response["message"]) + assert.Equal(t, serverDetail.ID, response["id"]) + + // Verify the server was actually published by retrieving it + publishedServer, err := registryService.GetByID(serverDetail.ID) + require.NoError(t, err) + assert.Equal(t, serverDetail.Name, publishedServer.Name) + assert.Equal(t, serverDetail.Description, publishedServer.Description) + assert.Equal(t, serverDetail.VersionDetail.Version, publishedServer.VersionDetail.Version) + assert.Len(t, publishedServer.Packages, 1) + assert.Len(t, publishedServer.Remotes, 1) + }) + + t.Run("successful publish without auth (no prefix)", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "custom-mcp-server", + Description: "A custom MCP server without auth", + Repository: model.Repository{ + URL: "https://example.com/custom-server", + Source: "custom", + ID: "custom/custom-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "dummy_token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "Server publication successful", response["message"]) + assert.Equal(t, serverDetail.ID, response["id"]) + }) + + t.Run("publish fails with missing name", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "", // Missing name + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Name is required") + }) + + t.Run("publish fails with missing version", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "", // Missing version + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Version is required") + }) + + t.Run("publish fails with missing authorization header", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + // No Authorization header + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Authorization header is required") + }) + + t.Run("publish fails with invalid JSON", func(t *testing.T) { + invalidJSON := `{"name": "test", "version": ` + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBufferString(invalidJSON)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Invalid") + }) + + t.Run("publish fails with unsupported HTTP method", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/publish", nil) + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Method not allowed") + }) + + t.Run("publish fails with duplicate name and version", func(t *testing.T) { + // First, publish a server successfully + firstServerDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.duplicate/test-server", + Description: "First server for duplicate test", + Repository: model.Repository{ + URL: "https://github.com/duplicate/test-server", + Source: "github", + ID: "duplicate/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(firstServerDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_token_first") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code, "First publish should succeed") + + // Now try to publish another server with the same name and version + duplicateServerDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), // Different ID + Name: "io.github.duplicate/test-server", // Same name + Description: "Duplicate server attempt", + Repository: model.Repository{ + URL: "https://github.com/duplicate/test-server-fork", + Source: "github", + ID: "duplicate/test-server-fork", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", // Same version + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + duplicateJsonData, err := json.Marshal(duplicateServerDetail) + require.NoError(t, err) + + duplicateReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(duplicateJsonData)) + duplicateReq.Header.Set("Content-Type", "application/json") + duplicateReq.Header.Set("Authorization", "Bearer github_token_duplicate") + + duplicateRecorder := httptest.NewRecorder() + handler(duplicateRecorder, duplicateReq) + + // The duplicate should fail + assert.Equal(t, http.StatusBadRequest, duplicateRecorder.Code) + assert.Contains(t, duplicateRecorder.Body.String(), "Failed to publish server details") + + // Verify that only the first server was actually stored + retrievedServer, err := registryService.GetByID(firstServerDetail.ID) + require.NoError(t, err) + assert.Equal(t, firstServerDetail.Name, retrievedServer.Name) + assert.Equal(t, firstServerDetail.Description, retrievedServer.Description) + + // Try to get the duplicate - it should not exist + _, err = registryService.GetByID(duplicateServerDetail.ID) + assert.Error(t, err, "Duplicate server should not have been stored") + }) + + t.Run("publish succeeds with same name but different version", func(t *testing.T) { + // Publish first version + firstVersionDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.versioned/test-server", + Description: "First version of the server", + Repository: model.Repository{ + URL: "https://github.com/versioned/test-server", + Source: "github", + ID: "versioned/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: false, + }, + }, + } + + jsonData, err := json.Marshal(firstVersionDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_token_v1") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code, "First version should succeed") + + // Publish second version with same name but different version + secondVersionDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.versioned/test-server", // Same name + Description: "Second version of the server", + Repository: model.Repository{ + URL: "https://github.com/versioned/test-server", + Source: "github", + ID: "versioned/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", // Different version + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + secondJsonData, err := json.Marshal(secondVersionDetail) + require.NoError(t, err) + + secondReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(secondJsonData)) + secondReq.Header.Set("Content-Type", "application/json") + secondReq.Header.Set("Authorization", "Bearer github_token_v2") + + secondRecorder := httptest.NewRecorder() + handler(secondRecorder, secondReq) + + // The second version should succeed + assert.Equal(t, http.StatusCreated, secondRecorder.Code) + + // Verify both versions exist + firstRetrieved, err := registryService.GetByID(firstVersionDetail.ID) + require.NoError(t, err) + assert.Equal(t, "1.0.0", firstRetrieved.VersionDetail.Version) + + secondRetrieved, err := registryService.GetByID(secondVersionDetail.ID) + require.NoError(t, err) + assert.Equal(t, "2.0.0", secondRetrieved.VersionDetail.Version) + }) + + t.Run("publish fails when trying to publish older version after newer version", func(t *testing.T) { + // First, publish a newer version (2.0.0) + newerVersionDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.versioning/version-order-test", + Description: "Newer version published first", + Repository: model.Repository{ + URL: "https://github.com/versioning/version-order-test", + Source: "github", + ID: "versioning/version-order-test", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(newerVersionDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_token_newer") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code, "Newer version should be published successfully") + + // Now try to publish an older version (1.0.0) of the same package + olderVersionDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.versioning/version-order-test", // Same name + Description: "Older version published after newer", + Repository: model.Repository{ + URL: "https://github.com/versioning/version-order-test", + Source: "github", + ID: "versioning/version-order-test", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", // Older version + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: false, + }, + }, + } + + olderJsonData, err := json.Marshal(olderVersionDetail) + require.NoError(t, err) + + olderReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(olderJsonData)) + olderReq.Header.Set("Content-Type", "application/json") + olderReq.Header.Set("Authorization", "Bearer github_token_older") + + olderRecorder := httptest.NewRecorder() + handler(olderRecorder, olderReq) + + // This should fail - we shouldn't allow publishing older versions after newer ones + assert.Equal(t, http.StatusBadRequest, olderRecorder.Code, "Publishing older version should fail") + assert.Contains(t, olderRecorder.Body.String(), "version", "Error message should mention version") + + // Verify that only the newer version exists + newerRetrieved, err := registryService.GetByID(newerVersionDetail.ID) + require.NoError(t, err) + assert.Equal(t, "2.0.0", newerRetrieved.VersionDetail.Version) + + // Verify the older version was not stored + _, err = registryService.GetByID(olderVersionDetail.ID) + assert.Error(t, err, "Older version should not have been stored") + }) +} + +// TestPublishIntegrationWithComplexPackages tests publishing with complex package configurations +func TestPublishIntegrationWithComplexPackages(t *testing.T) { + registryService := service.NewFakeRegistryService() + authService := &MockAuthService{} + handler := v0.PublishHandler(registryService, authService) + + t.Run("publish with complex package configuration", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.complex/advanced-mcp-server", + Description: "An advanced MCP server with complex configuration", + Repository: model.Repository{ + URL: "https://github.com/complex/advanced-mcp-server", + Source: "github", + ID: "complex/advanced-mcp-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.1.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + Packages: []model.Package{ + { + RegistryName: "npm", + Name: "advanced-mcp-server", + Version: "2.1.0", + RunTimeHint: "node --experimental-modules", + RuntimeArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "config", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Main configuration file", + Format: model.FormatFilePath, + IsRequired: true, + Default: "./config.json", + }, + }, + }, + { + Type: model.ArgumentTypePositional, + Name: "mode", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Operation mode", + Format: model.FormatString, + IsRequired: false, + Default: "production", + Choices: []string{"development", "staging", "production"}, + }, + }, + }, + }, + PackageArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "install-deps", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Install dependencies", + Format: model.FormatBoolean, + Default: "true", + }, + }, + }, + }, + EnvironmentVariables: []model.KeyValueInput{ + { + Name: "LOG_LEVEL", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Logging level", + Format: model.FormatString, + Default: "info", + Choices: []string{"debug", "info", "warn", "error"}, + }, + }, + }, + { + Name: "API_KEY", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "API key for external service", + Format: model.FormatString, + IsRequired: true, + IsSecret: true, + }, + }, + }, + }, + }, + }, + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "http://localhost:8080/mcp", + Headers: []model.Input{ + { + Description: "API Version Header", + Format: model.FormatString, + Value: "v1", + }, + }, + }, + { + TransportType: "websocket", + URL: "ws://localhost:8081/mcp", + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_complex_token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "Server publication successful", response["message"]) + assert.Equal(t, serverDetail.ID, response["id"]) + + // Verify the complex server was published correctly + publishedServer, err := registryService.GetByID(serverDetail.ID) + require.NoError(t, err) + + // Verify package details + require.Len(t, publishedServer.Packages, 1) + pkg := publishedServer.Packages[0] + assert.Equal(t, "npm", pkg.RegistryName) + assert.Equal(t, "advanced-mcp-server", pkg.Name) + assert.Equal(t, "node --experimental-modules", pkg.RunTimeHint) + assert.Len(t, pkg.RuntimeArguments, 2) + assert.Len(t, pkg.PackageArguments, 1) + assert.Len(t, pkg.EnvironmentVariables, 2) + + // Verify remotes + require.Len(t, publishedServer.Remotes, 2) + assert.Equal(t, "http", publishedServer.Remotes[0].TransportType) + assert.Equal(t, "websocket", publishedServer.Remotes[1].TransportType) + assert.Len(t, publishedServer.Remotes[0].Headers, 1) + }) +} + +// TestPublishIntegrationEndToEnd tests the complete end-to-end flow +func TestPublishIntegrationEndToEnd(t *testing.T) { + registryService := service.NewFakeRegistryService() + authService := &MockAuthService{} + handler := v0.PublishHandler(registryService, authService) + + t.Run("end-to-end publish and retrieve flow", func(t *testing.T) { + // Step 1: Get initial count of servers + initialServers, _, err := registryService.List("", 100) + require.NoError(t, err) + initialCount := len(initialServers) + + // Step 2: Publish a new server + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: uuid.New().String(), + Name: "io.github.e2e/end-to-end-server", + Description: "End-to-end test server", + Repository: model.Repository{ + URL: "https://github.com/e2e/end-to-end-server", + Source: "github", + ID: "e2e/end-to-end-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: time.Now().Format(time.RFC3339), + IsLatest: true, + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_e2e_token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + require.Equal(t, http.StatusCreated, recorder.Code) + + // Step 3: Verify the count increased + updatedServers, _, err := registryService.List("", 100) + require.NoError(t, err) + assert.Equal(t, initialCount+1, len(updatedServers)) + + // Step 4: Verify the server can be retrieved by ID + retrievedServer, err := registryService.GetByID(serverDetail.ID) + require.NoError(t, err) + assert.Equal(t, serverDetail.Name, retrievedServer.Name) + assert.Equal(t, serverDetail.Description, retrievedServer.Description) + + // Step 5: Verify the server appears in the list + found := false + for _, server := range updatedServers { + if server.ID == serverDetail.ID { + found = true + assert.Equal(t, serverDetail.Name, server.Name) + break + } + } + assert.True(t, found, "Published server should appear in the list") + }) +} diff --git a/integrationtests/run_tests.sh b/integrationtests/run_tests.sh new file mode 100755 index 00000000..d035bb71 --- /dev/null +++ b/integrationtests/run_tests.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Integration Test Runner for MCP Registry +# This script runs the integration tests for the publish functionality + +echo "Running MCP Registry Integration Tests..." +echo "========================================" + +# Change to the project directory (parent of integrationtests) +cd "$(dirname "$0")/.." + +# Run integration tests with verbose output +echo "Running publish integration tests..." +go test -v ./integrationtests/... + +# Check exit code +if [ $? -eq 0 ]; then + echo "" + echo "✅ All integration tests passed!" +else + echo "" + echo "❌ Some integration tests failed!" + exit 1 +fi diff --git a/internal/api/handlers/v0/publish.go b/internal/api/handlers/v0/publish.go index 9e999d5a..b62ac6e8 100644 --- a/internal/api/handlers/v0/publish.go +++ b/internal/api/handlers/v0/publish.go @@ -3,11 +3,13 @@ package v0 import ( "encoding/json" + "errors" "io" "net/http" "strings" "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/model" "github.com/modelcontextprotocol/registry/internal/service" ) @@ -106,6 +108,11 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) // Call the publish method on the registry service err = registry.Publish(&serverDetail) if err != nil { + // Check for specific error types and return appropriate HTTP status codes + if errors.Is(err, database.ErrInvalidVersion) || errors.Is(err, database.ErrAlreadyExists) { + http.Error(w, "Failed to publish server details: "+err.Error(), http.StatusBadRequest) + return + } http.Error(w, "Failed to publish server details: "+err.Error(), http.StatusInternalServerError) return } diff --git a/internal/database/database.go b/internal/database/database.go index 9c457bce..1d5fc4f8 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -9,10 +9,11 @@ import ( // Common database errors var ( - ErrNotFound = errors.New("record not found") - ErrAlreadyExists = errors.New("record already exists") - ErrInvalidInput = errors.New("invalid input") - ErrDatabase = errors.New("database error") + ErrNotFound = errors.New("record not found") + ErrAlreadyExists = errors.New("record already exists") + ErrInvalidInput = errors.New("invalid input") + ErrDatabase = errors.New("database error") + ErrInvalidVersion = errors.New("invalid version: cannot publish older version after newer version") ) // Database defines the interface for database operations on MCPRegistry entries diff --git a/internal/database/memory.go b/internal/database/memory.go index 6df02a28..fbec1396 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -3,6 +3,8 @@ package database import ( "context" "sort" + "strconv" + "strings" "sync" "github.com/modelcontextprotocol/registry/internal/model" @@ -10,17 +12,75 @@ import ( // MemoryDB is an in-memory implementation of the Database interface type MemoryDB struct { - entries map[string]*model.Server + entries map[string]*model.ServerDetail mu sync.RWMutex } // NewMemoryDB creates a new instance of the in-memory database func NewMemoryDB(e map[string]*model.Server) *MemoryDB { + // Convert Server entries to ServerDetail entries + serverDetails := make(map[string]*model.ServerDetail) + for k, v := range e { + serverDetails[k] = &model.ServerDetail{ + Server: *v, + } + } return &MemoryDB{ - entries: e, + entries: serverDetails, } } +// compareSemanticVersions compares two semantic version strings +// Returns: +// +// -1 if version1 < version2 +// 0 if version1 == version2 +// +1 if version1 > version2 +func compareSemanticVersions(version1, version2 string) int { + // Simple semantic version comparison + // Assumes format: major.minor.patch + + parts1 := strings.Split(version1, ".") + parts2 := strings.Split(version2, ".") + + // Pad with zeros if needed + maxLen := len(parts1) + if len(parts2) > maxLen { + maxLen = len(parts2) + } + + for len(parts1) < maxLen { + parts1 = append(parts1, "0") + } + for len(parts2) < maxLen { + parts2 = append(parts2, "0") + } + + // Compare each part + for i := 0; i < maxLen; i++ { + num1, err1 := strconv.Atoi(parts1[i]) + num2, err2 := strconv.Atoi(parts2[i]) + + // If parsing fails, fall back to string comparison + if err1 != nil || err2 != nil { + if parts1[i] < parts2[i] { + return -1 + } else if parts1[i] > parts2[i] { + return 1 + } + continue + } + + if num1 < num2 { + return -1 + } else if num1 > num2 { + return 1 + } + } + + return 0 +} + // List retrieves all MCPRegistry entries with optional filtering and pagination func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cursor string, limit int) ([]*model.Server, string, error) { if ctx.Err() != nil { @@ -37,8 +97,8 @@ func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cur // Convert all entries to a slice for pagination var allEntries []*model.Server for _, entry := range db.entries { - entryCopy := *entry - allEntries = append(allEntries, &entryCopy) + serverCopy := entry.Server + allEntries = append(allEntries, &serverCopy) } // Simple filtering implementation @@ -124,15 +184,9 @@ func (db *MemoryDB) GetByID(ctx context.Context, id string) (*model.ServerDetail defer db.mu.RUnlock() if entry, exists := db.entries[id]; exists { - return &model.ServerDetail{ - Server: model.Server{ - ID: entry.ID, - Name: entry.Name, - Description: entry.Description, - VersionDetail: entry.VersionDetail, - Repository: entry.Repository, - }, - }, nil + // Return a copy of the ServerDetail + serverDetailCopy := *entry + return &serverDetailCopy, nil } return nil, ErrNotFound @@ -153,24 +207,33 @@ func (db *MemoryDB) Publish(ctx context.Context, serverDetail *model.ServerDetai } // check that the name and the version are unique - + // Also check version ordering - don't allow publishing older versions after newer ones + var latestVersion string for _, entry := range db.entries { - if entry.Name == serverDetail.Name && entry.VersionDetail.Version == serverDetail.VersionDetail.Version { - return ErrAlreadyExists + if entry.Name == serverDetail.Name { + if entry.VersionDetail.Version == serverDetail.VersionDetail.Version { + return ErrAlreadyExists + } + + // Track the latest version for this package name + if latestVersion == "" || compareSemanticVersions(entry.VersionDetail.Version, latestVersion) > 0 { + latestVersion = entry.VersionDetail.Version + } } } + // If we found existing versions, check if the new version is older than the latest + if latestVersion != "" && compareSemanticVersions(serverDetail.VersionDetail.Version, latestVersion) < 0 { + return ErrInvalidVersion + } + if serverDetail.Repository.URL == "" { return ErrInvalidInput } - db.entries[serverDetail.ID] = &model.Server{ - ID: serverDetail.ID, - Name: serverDetail.Name, - Description: serverDetail.Description, - VersionDetail: serverDetail.VersionDetail, - Repository: serverDetail.Repository, - } + // Store a copy of the entire ServerDetail + serverDetailCopy := *serverDetail + db.entries[serverDetail.ID] = &serverDetailCopy return nil } From ec839b203f84609d51df59096f937ef541c9908b Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 10:55:48 -0400 Subject: [PATCH 05/15] update tests and memory implementation to follow real DB implementations as much as possible. Fix the tests to let the DB generate the server uuids rather than setting them up in the test. --- integrationtests/publish_integration_test.go | 240 +++++++++---------- internal/database/memory.go | 3 + 2 files changed, 118 insertions(+), 125 deletions(-) diff --git a/integrationtests/publish_integration_test.go b/integrationtests/publish_integration_test.go index 574d75ad..bcc72420 100644 --- a/integrationtests/publish_integration_test.go +++ b/integrationtests/publish_integration_test.go @@ -8,9 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" - "time" - "github.com/google/uuid" v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" "github.com/modelcontextprotocol/registry/internal/auth" "github.com/modelcontextprotocol/registry/internal/model" @@ -59,53 +57,52 @@ func TestPublishIntegration(t *testing.T) { handler := v0.PublishHandler(registryService, authService) t.Run("successful publish with GitHub auth", func(t *testing.T) { - serverDetail := &model.ServerDetail{ - Server: model.Server{ - ID: uuid.New().String(), - Name: "io.github.testuser/test-mcp-server", - Description: "A test MCP server for integration testing", - Repository: model.Repository{ - URL: "https://github.com/testuser/test-mcp-server", - Source: "github", - ID: "testuser/test-mcp-server", - }, - VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + publishReq := model.PublishRequest{ + ServerDetail: model.ServerDetail{ + Server: model.Server{ + Name: "io.github.testuser/test-mcp-server", + Description: "A test MCP server for integration testing", + Repository: model.Repository{ + URL: "https://github.com/testuser/test-mcp-server", + Source: "github", + ID: "testuser/test-mcp-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, }, - }, - Packages: []model.Package{ - { - RegistryName: "npm", - Name: "test-mcp-server", - Version: "1.0.0", - RunTimeHint: "node", - RuntimeArguments: []model.Argument{ - { - Type: model.ArgumentTypeNamed, - Name: "config", - InputWithVariables: model.InputWithVariables{ - Input: model.Input{ - Description: "Configuration file path", - Format: model.FormatFilePath, - IsRequired: true, + Packages: []model.Package{ + { + RegistryName: "npm", + Name: "test-mcp-server", + Version: "1.0.0", + RunTimeHint: "node", + RuntimeArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "config", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Configuration file path", + Format: model.FormatFilePath, + IsRequired: true, + }, }, }, }, }, }, - }, - Remotes: []model.Remote{ - { - TransportType: "http", - URL: "http://localhost:3000/mcp", + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "http://localhost:3000/mcp", + }, }, }, } // Marshal the server detail to JSON - jsonData, err := json.Marshal(serverDetail) + jsonData, err := json.Marshal(publishReq) require.NoError(t, err) // Create a request @@ -127,38 +124,37 @@ func TestPublishIntegration(t *testing.T) { require.NoError(t, err) assert.Equal(t, "Server publication successful", response["message"]) - assert.Equal(t, serverDetail.ID, response["id"]) + assert.NotEmpty(t, response["id"], "Server ID should be generated") // Verify the server was actually published by retrieving it - publishedServer, err := registryService.GetByID(serverDetail.ID) + publishedServer, err := registryService.GetByID(response["id"]) require.NoError(t, err) - assert.Equal(t, serverDetail.Name, publishedServer.Name) - assert.Equal(t, serverDetail.Description, publishedServer.Description) - assert.Equal(t, serverDetail.VersionDetail.Version, publishedServer.VersionDetail.Version) + assert.Equal(t, publishReq.ServerDetail.Name, publishedServer.Name) + assert.Equal(t, publishReq.ServerDetail.Description, publishedServer.Description) + assert.Equal(t, publishReq.ServerDetail.VersionDetail.Version, publishedServer.VersionDetail.Version) assert.Len(t, publishedServer.Packages, 1) assert.Len(t, publishedServer.Remotes, 1) }) t.Run("successful publish without auth (no prefix)", func(t *testing.T) { - serverDetail := &model.ServerDetail{ - Server: model.Server{ - ID: uuid.New().String(), - Name: "custom-mcp-server", - Description: "A custom MCP server without auth", - Repository: model.Repository{ - URL: "https://example.com/custom-server", - Source: "custom", - ID: "custom/custom-server", - }, - VersionDetail: model.VersionDetail{ - Version: "2.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + publishReq := &model.PublishRequest{ + ServerDetail: model.ServerDetail{ + Server: model.Server{ + Name: "custom-mcp-server", + Description: "A custom MCP server without auth", + Repository: model.Repository{ + URL: "https://example.com/custom-server", + Source: "custom", + ID: "custom/custom-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + }, }, }, } - jsonData, err := json.Marshal(serverDetail) + jsonData, err := json.Marshal(publishReq) require.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) @@ -175,24 +171,23 @@ func TestPublishIntegration(t *testing.T) { require.NoError(t, err) assert.Equal(t, "Server publication successful", response["message"]) - assert.Equal(t, serverDetail.ID, response["id"]) + assert.NotEmpty(t, response["id"], "Server ID should be generated") }) t.Run("publish fails with missing name", func(t *testing.T) { - serverDetail := &model.ServerDetail{ - Server: model.Server{ - ID: uuid.New().String(), - Name: "", // Missing name - Description: "A test server", - VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + publishReq := &model.PublishRequest{ + ServerDetail: model.ServerDetail{ + Server: model.Server{ + Name: "", // Missing name + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, }, }, } - jsonData, err := json.Marshal(serverDetail) + jsonData, err := json.Marshal(publishReq) require.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) @@ -209,13 +204,10 @@ func TestPublishIntegration(t *testing.T) { t.Run("publish fails with missing version", func(t *testing.T) { serverDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "test-server", Description: "A test server", VersionDetail: model.VersionDetail{ - Version: "", // Missing version - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "", // Missing version }, }, } @@ -237,13 +229,10 @@ func TestPublishIntegration(t *testing.T) { t.Run("publish fails with missing authorization header", func(t *testing.T) { serverDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "test-server", Description: "A test server", VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "1.0.0", }, }, } @@ -291,7 +280,6 @@ func TestPublishIntegration(t *testing.T) { // First, publish a server successfully firstServerDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.duplicate/test-server", Description: "First server for duplicate test", Repository: model.Repository{ @@ -300,9 +288,7 @@ func TestPublishIntegration(t *testing.T) { ID: "duplicate/test-server", }, VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "1.0.0", }, }, } @@ -317,12 +303,17 @@ func TestPublishIntegration(t *testing.T) { recorder := httptest.NewRecorder() handler(recorder, req) + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, recorder.Code, "First publish should succeed") + firstServerDetail.ID = response["id"] // Store the ID for later verification + // Now try to publish another server with the same name and version duplicateServerDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), // Different ID Name: "io.github.duplicate/test-server", // Same name Description: "Duplicate server attempt", Repository: model.Repository{ @@ -331,9 +322,7 @@ func TestPublishIntegration(t *testing.T) { ID: "duplicate/test-server-fork", }, VersionDetail: model.VersionDetail{ - Version: "1.0.0", // Same version - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "1.0.0", // Same version }, }, } @@ -358,16 +347,12 @@ func TestPublishIntegration(t *testing.T) { assert.Equal(t, firstServerDetail.Name, retrievedServer.Name) assert.Equal(t, firstServerDetail.Description, retrievedServer.Description) - // Try to get the duplicate - it should not exist - _, err = registryService.GetByID(duplicateServerDetail.ID) - assert.Error(t, err, "Duplicate server should not have been stored") }) t.Run("publish succeeds with same name but different version", func(t *testing.T) { // Publish first version firstVersionDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.versioned/test-server", Description: "First version of the server", Repository: model.Repository{ @@ -376,9 +361,7 @@ func TestPublishIntegration(t *testing.T) { ID: "versioned/test-server", }, VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: false, + Version: "1.0.0", }, }, } @@ -393,12 +376,17 @@ func TestPublishIntegration(t *testing.T) { recorder := httptest.NewRecorder() handler(recorder, req) + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + firstVersionDetail.ID = response["id"] // Store the ID for later verification + assert.Equal(t, http.StatusCreated, recorder.Code, "First version should succeed") + require.NotEmpty(t, firstVersionDetail.ID, "Server ID should be generated") // Publish second version with same name but different version secondVersionDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.versioned/test-server", // Same name Description: "Second version of the server", Repository: model.Repository{ @@ -407,9 +395,7 @@ func TestPublishIntegration(t *testing.T) { ID: "versioned/test-server", }, VersionDetail: model.VersionDetail{ - Version: "2.0.0", // Different version - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "2.0.0", // Different version }, }, } @@ -424,8 +410,14 @@ func TestPublishIntegration(t *testing.T) { secondRecorder := httptest.NewRecorder() handler(secondRecorder, secondReq) + var secondResponse map[string]string + err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondResponse) + require.NoError(t, err) + secondVersionDetail.ID = secondResponse["id"] // Store the ID for later verification + // The second version should succeed assert.Equal(t, http.StatusCreated, secondRecorder.Code) + require.NotEmpty(t, secondVersionDetail.ID, "Server ID for second version should be generated") // Verify both versions exist firstRetrieved, err := registryService.GetByID(firstVersionDetail.ID) @@ -441,7 +433,6 @@ func TestPublishIntegration(t *testing.T) { // First, publish a newer version (2.0.0) newerVersionDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.versioning/version-order-test", Description: "Newer version published first", Repository: model.Repository{ @@ -450,9 +441,7 @@ func TestPublishIntegration(t *testing.T) { ID: "versioning/version-order-test", }, VersionDetail: model.VersionDetail{ - Version: "2.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "2.0.0", }, }, } @@ -467,12 +456,17 @@ func TestPublishIntegration(t *testing.T) { recorder := httptest.NewRecorder() handler(recorder, req) + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + newerVersionDetail.ID = response["id"] // Store the ID for later verification + assert.Equal(t, http.StatusCreated, recorder.Code, "Newer version should be published successfully") + require.NotEmpty(t, newerVersionDetail.ID, "Server ID for newer version should be generated") // Now try to publish an older version (1.0.0) of the same package olderVersionDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.versioning/version-order-test", // Same name Description: "Older version published after newer", Repository: model.Repository{ @@ -481,9 +475,7 @@ func TestPublishIntegration(t *testing.T) { ID: "versioning/version-order-test", }, VersionDetail: model.VersionDetail{ - Version: "1.0.0", // Older version - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: false, + Version: "1.0.0", // Older version }, }, } @@ -522,7 +514,6 @@ func TestPublishIntegrationWithComplexPackages(t *testing.T) { t.Run("publish with complex package configuration", func(t *testing.T) { serverDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.complex/advanced-mcp-server", Description: "An advanced MCP server with complex configuration", Repository: model.Repository{ @@ -531,18 +522,20 @@ func TestPublishIntegrationWithComplexPackages(t *testing.T) { ID: "complex/advanced-mcp-server", }, VersionDetail: model.VersionDetail{ - Version: "2.1.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "2.1.0", }, }, Packages: []model.Package{ { RegistryName: "npm", - Name: "advanced-mcp-server", - Version: "2.1.0", - RunTimeHint: "node --experimental-modules", + Name: "@example/advanced-mcp-server", + Version: "43.1.0", + RunTimeHint: "node", RuntimeArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "experimental-modules", + }, { Type: model.ArgumentTypeNamed, Name: "config", @@ -620,10 +613,6 @@ func TestPublishIntegrationWithComplexPackages(t *testing.T) { }, }, }, - { - TransportType: "websocket", - URL: "ws://localhost:8081/mcp", - }, }, } @@ -643,8 +632,9 @@ func TestPublishIntegrationWithComplexPackages(t *testing.T) { err = json.Unmarshal(recorder.Body.Bytes(), &response) require.NoError(t, err) + serverDetail.ID = response["id"] // Store the ID for later verification assert.Equal(t, "Server publication successful", response["message"]) - assert.Equal(t, serverDetail.ID, response["id"]) + assert.NotEmpty(t, response["id"], "Server ID should be generated") // Verify the complex server was published correctly publishedServer, err := registryService.GetByID(serverDetail.ID) @@ -654,16 +644,14 @@ func TestPublishIntegrationWithComplexPackages(t *testing.T) { require.Len(t, publishedServer.Packages, 1) pkg := publishedServer.Packages[0] assert.Equal(t, "npm", pkg.RegistryName) - assert.Equal(t, "advanced-mcp-server", pkg.Name) - assert.Equal(t, "node --experimental-modules", pkg.RunTimeHint) - assert.Len(t, pkg.RuntimeArguments, 2) + assert.Equal(t, "@example/advanced-mcp-server", pkg.Name) + assert.Len(t, pkg.RuntimeArguments, 3) assert.Len(t, pkg.PackageArguments, 1) assert.Len(t, pkg.EnvironmentVariables, 2) // Verify remotes - require.Len(t, publishedServer.Remotes, 2) + require.Len(t, publishedServer.Remotes, 1) assert.Equal(t, "http", publishedServer.Remotes[0].TransportType) - assert.Equal(t, "websocket", publishedServer.Remotes[1].TransportType) assert.Len(t, publishedServer.Remotes[0].Headers, 1) }) } @@ -683,7 +671,6 @@ func TestPublishIntegrationEndToEnd(t *testing.T) { // Step 2: Publish a new server serverDetail := &model.ServerDetail{ Server: model.Server{ - ID: uuid.New().String(), Name: "io.github.e2e/end-to-end-server", Description: "End-to-end test server", Repository: model.Repository{ @@ -692,9 +679,7 @@ func TestPublishIntegrationEndToEnd(t *testing.T) { ID: "e2e/end-to-end-server", }, VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: time.Now().Format(time.RFC3339), - IsLatest: true, + Version: "1.0.0", }, }, } @@ -709,6 +694,11 @@ func TestPublishIntegrationEndToEnd(t *testing.T) { recorder := httptest.NewRecorder() handler(recorder, req) + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + serverDetail.ID = response["id"] // Store the ID for later verification + require.Equal(t, http.StatusCreated, recorder.Code) // Step 3: Verify the count increased diff --git a/internal/database/memory.go b/internal/database/memory.go index fbec1396..e87868e7 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -7,6 +7,7 @@ import ( "strings" "sync" + "github.com/google/uuid" "github.com/modelcontextprotocol/registry/internal/model" ) @@ -231,6 +232,8 @@ func (db *MemoryDB) Publish(ctx context.Context, serverDetail *model.ServerDetai return ErrInvalidInput } + // Generate a new ID for the server detail + serverDetail.ID = uuid.New().String() // Store a copy of the entire ServerDetail serverDetailCopy := *serverDetail db.entries[serverDetail.ID] = &serverDetailCopy From a5b0fada95fe509c228f646b441f8acb5c7522ad Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 11:32:56 -0400 Subject: [PATCH 06/15] feat(publish): set version as latest and add release date on publish --- internal/database/memory.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/database/memory.go b/internal/database/memory.go index e87868e7..15f595e9 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/google/uuid" "github.com/modelcontextprotocol/registry/internal/model" @@ -234,6 +235,8 @@ func (db *MemoryDB) Publish(ctx context.Context, serverDetail *model.ServerDetai // Generate a new ID for the server detail serverDetail.ID = uuid.New().String() + serverDetail.VersionDetail.IsLatest = true // Assume the new version is the latest + serverDetail.VersionDetail.ReleaseDate = time.Now().Format(time.RFC3339) // Store a copy of the entire ServerDetail serverDetailCopy := *serverDetail db.entries[serverDetail.ID] = &serverDetailCopy From 7354670b931ee18b1d5ca7513f1d93f057ddc295 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 13:21:14 -0400 Subject: [PATCH 07/15] feat(tests): add unit tests for servers and servers detail handlers --- internal/api/handlers/v0/servers_test.go | 551 +++++++++++++++++++++++ 1 file changed, 551 insertions(+) create mode 100644 internal/api/handlers/v0/servers_test.go diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go new file mode 100644 index 00000000..d6de7ca2 --- /dev/null +++ b/internal/api/handlers/v0/servers_test.go @@ -0,0 +1,551 @@ +package v0 + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestServersHandler(t *testing.T) { + testCases := []struct { + name string + method string + queryParams string + setupMocks func(*MockRegistryService) + expectedStatus int + expectedServers []model.Server + expectedMeta *Metadata + expectedError string + }{ + { + name: "successful list with default parameters", + method: http.MethodGet, + setupMocks: func(registry *MockRegistryService) { + servers := []model.Server{ + { + ID: "test-id-1", + Name: "test-server-1", + Description: "First test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-1", + Source: "github", + ID: "example/test-server-1", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + { + ID: "test-id-2", + Name: "test-server-2", + Description: "Second test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-2", + Source: "github", + ID: "example/test-server-2", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: "2025-05-26T00:00:00Z", + IsLatest: true, + }, + }, + } + registry.On("List", "", 30).Return(servers, "", nil) + }, + expectedStatus: http.StatusOK, + expectedServers: []model.Server{ + { + ID: "test-id-1", + Name: "test-server-1", + Description: "First test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-1", + Source: "github", + ID: "example/test-server-1", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + { + ID: "test-id-2", + Name: "test-server-2", + Description: "Second test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-2", + Source: "github", + ID: "example/test-server-2", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: "2025-05-26T00:00:00Z", + IsLatest: true, + }, + }, + }, + }, + { + name: "successful list with cursor and limit", + method: http.MethodGet, + queryParams: "?cursor=test-id-3" + "&limit=10", + setupMocks: func(registry *MockRegistryService) { + servers := []model.Server{ + { + ID: "test-id-3", + Name: "test-server-3", + Description: "Third test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-3", + Source: "github", + ID: "example/test-server-3", + }, + VersionDetail: model.VersionDetail{ + Version: "1.5.0", + ReleaseDate: "2025-05-27T00:00:00Z", + IsLatest: true, + }, + }, + } + nextCursor := uuid.New().String() + registry.On("List", mock.AnythingOfType("string"), 10).Return(servers, nextCursor, nil) + }, + expectedStatus: http.StatusOK, + expectedServers: []model.Server{ + { + ID: "test-id-3", + Name: "test-server-3", + Description: "Third test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-3", + Source: "github", + ID: "example/test-server-3", + }, + VersionDetail: model.VersionDetail{ + Version: "1.5.0", + ReleaseDate: "2025-05-27T00:00:00Z", + IsLatest: true, + }, + }, + }, + expectedMeta: &Metadata{ + NextCursor: "", // This will be dynamically set in the test + Count: 1, + }, + }, + { + name: "successful list with limit capping at 100", + method: http.MethodGet, + queryParams: "?limit=150", + setupMocks: func(registry *MockRegistryService) { + servers := []model.Server{} + registry.On("List", "", 100).Return(servers, "", nil) + }, + expectedStatus: http.StatusOK, + expectedServers: []model.Server{}, + }, + { + name: "invalid cursor parameter", + method: http.MethodGet, + queryParams: "?cursor=invalid-uuid", + setupMocks: func(registry *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid cursor parameter", + }, + { + name: "invalid limit parameter - non-numeric", + method: http.MethodGet, + queryParams: "?limit=abc", + setupMocks: func(registry *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid limit parameter", + }, + { + name: "invalid limit parameter - zero", + method: http.MethodGet, + queryParams: "?limit=0", + setupMocks: func(registry *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Limit must be greater than 0", + }, + { + name: "invalid limit parameter - negative", + method: http.MethodGet, + queryParams: "?limit=-5", + setupMocks: func(registry *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Limit must be greater than 0", + }, + { + name: "registry service error", + method: http.MethodGet, + setupMocks: func(registry *MockRegistryService) { + registry.On("List", "", 30).Return([]model.Server{}, "", errors.New("database connection error")) + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "database connection error", + }, + { + name: "method not allowed", + method: http.MethodPost, + setupMocks: func(registry *MockRegistryService) {}, + expectedStatus: http.StatusMethodNotAllowed, + expectedError: "Method not allowed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create mock registry service + mockRegistry := new(MockRegistryService) + tc.setupMocks(mockRegistry) + + // Create handler + handler := ServersHandler(mockRegistry) + + // Create request + url := "/v0/servers" + tc.queryParams + req, err := http.NewRequest(tc.method, url, nil) + if err != nil { + t.Fatal(err) + } + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + // Check content type + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse response body + var resp PaginatedResponse + err = json.NewDecoder(rr.Body).Decode(&resp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, tc.expectedServers, resp.Data) + + // Check metadata if expected + if tc.expectedMeta != nil { + assert.Equal(t, tc.expectedMeta.Count, resp.Metadata.Count) + if tc.expectedMeta.NextCursor != "" { + assert.NotEmpty(t, resp.Metadata.NextCursor) + } + } + } else { + // Check error message for non-200 responses + if tc.expectedError != "" { + assert.Contains(t, rr.Body.String(), tc.expectedError) + } + } + + // Verify mock expectations + mockRegistry.AssertExpectations(t) + }) + } +} + +func TestServersDetailHandler(t *testing.T) { + validServerID := uuid.New().String() + + testCases := []struct { + name string + method string + serverID string + setupMocks func(*MockRegistryService) + expectedStatus int + expectedServerDetail *model.ServerDetail + expectedError string + }{ + { + name: "successful get server detail", + method: http.MethodGet, + serverID: validServerID, + setupMocks: func(registry *MockRegistryService) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: validServerID, + Name: "test-server", + Description: "A test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server", + Source: "github", + ID: "example/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + Packages: []model.Package{ + { + RegistryName: "test-package", + Name: "example-package", + Version: "1.0.0", + }, + }, + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "https://example.com/mcp", + }, + }, + } + registry.On("GetByID", validServerID).Return(serverDetail, nil) + }, + expectedStatus: http.StatusOK, + expectedServerDetail: &model.ServerDetail{ + Server: model.Server{ + ID: validServerID, + Name: "test-server", + Description: "A test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server", + Source: "github", + ID: "example/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + Packages: []model.Package{ + { + RegistryName: "test-package", + Name: "example-package", + Version: "1.0.0", + }, + }, + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "https://example.com/mcp", + }, + }, + }, + }, + { + name: "server not found", + method: http.MethodGet, + serverID: uuid.New().String(), + setupMocks: func(registry *MockRegistryService) { + registry.On("GetByID", mock.AnythingOfType("string")).Return((*model.ServerDetail)(nil), errors.New("record not found")) + }, + expectedStatus: http.StatusNotFound, + expectedError: "Server not found", + }, + { + name: "invalid server ID format", + method: http.MethodGet, + serverID: "invalid-uuid", + setupMocks: func(registry *MockRegistryService) { + // Mock won't be called due to early validation + }, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid server ID format", + }, + { + name: "registry service error", + method: http.MethodGet, + serverID: validServerID, + setupMocks: func(registry *MockRegistryService) { + registry.On("GetByID", validServerID).Return((*model.ServerDetail)(nil), errors.New("database connection error")) + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "Error retrieving server details", + }, + { + name: "method not allowed", + method: http.MethodPost, + serverID: validServerID, + setupMocks: func(registry *MockRegistryService) {}, + expectedStatus: http.StatusMethodNotAllowed, + expectedError: "Method not allowed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create mock registry service + mockRegistry := new(MockRegistryService) + tc.setupMocks(mockRegistry) + + // Create handler + handler := ServersDetailHandler(mockRegistry) + + // Create request with path value + url := fmt.Sprintf("/v0/servers/%s", tc.serverID) + req, err := http.NewRequest(tc.method, url, nil) + if err != nil { + t.Fatal(err) + } + + // Set the path value for the server ID (simulating mux behavior) + req.SetPathValue("id", tc.serverID) + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + // Check content type + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse response body + var resp model.ServerDetail + err = json.NewDecoder(rr.Body).Decode(&resp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, *tc.expectedServerDetail, resp) + } else { + // Check error message for non-200 responses + if tc.expectedError != "" { + assert.Contains(t, rr.Body.String(), tc.expectedError) + } + } + + // Verify mock expectations + mockRegistry.AssertExpectations(t) + }) + } +} + +// TestServersHandlerIntegration tests the servers list handler with actual HTTP requests +func TestServersHandlerIntegration(t *testing.T) { + // Create mock registry service + mockRegistry := new(MockRegistryService) + + servers := []model.Server{ + { + ID: "integration-test-id", + Name: "integration-test-server", + Description: "Integration test server", + Repository: model.Repository{ + URL: "https://github.com/example/integration-test", + Source: "github", + ID: "example/integration-test", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-27T00:00:00Z", + IsLatest: true, + }, + }, + } + + mockRegistry.On("List", "", 30).Return(servers, "", nil) + + // Create test server + server := httptest.NewServer(ServersHandler(mockRegistry)) + defer server.Close() + + // Send request to the test server + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check content type + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Parse response body + var paginatedResp PaginatedResponse + err = json.NewDecoder(resp.Body).Decode(&paginatedResp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, servers, paginatedResp.Data) + assert.Empty(t, paginatedResp.Metadata.NextCursor) + + // Verify mock expectations + mockRegistry.AssertExpectations(t) +} + +// TestServersDetailHandlerIntegration tests the servers detail handler with actual HTTP requests +func TestServersDetailHandlerIntegration(t *testing.T) { + serverID := uuid.New().String() + + // Create mock registry service + mockRegistry := new(MockRegistryService) + + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: serverID, + Name: "integration-test-server-detail", + Description: "Integration test server detail", + Repository: model.Repository{ + URL: "https://github.com/example/integration-test-detail", + Source: "github", + ID: "example/integration-test-detail", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: "2025-05-27T12:00:00Z", + IsLatest: true, + }, + }, + } + + mockRegistry.On("GetByID", serverID).Return(serverDetail, nil) + + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.SetPathValue("id", serverID) + ServersDetailHandler(mockRegistry).ServeHTTP(w, r) + })) + defer server.Close() + + // Send request to the test server + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check content type + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Parse response body + var serverDetailResp model.ServerDetail + err = json.NewDecoder(resp.Body).Decode(&serverDetailResp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, *serverDetail, serverDetailResp) + + // Verify mock expectations + mockRegistry.AssertExpectations(t) +} From 9045eb059246acebbe1ab209ff3611eea77abefd Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 13:24:57 -0400 Subject: [PATCH 08/15] refactor(tests): remove redundant TestServersDetailHandler integration tests --- internal/api/handlers/v0/servers_test.go | 174 ----------------------- 1 file changed, 174 deletions(-) diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go index d6de7ca2..b90d8d03 100644 --- a/internal/api/handlers/v0/servers_test.go +++ b/internal/api/handlers/v0/servers_test.go @@ -3,7 +3,6 @@ package v0 import ( "encoding/json" "errors" - "fmt" "net/http" "net/http/httptest" "testing" @@ -263,179 +262,6 @@ func TestServersHandler(t *testing.T) { } } -func TestServersDetailHandler(t *testing.T) { - validServerID := uuid.New().String() - - testCases := []struct { - name string - method string - serverID string - setupMocks func(*MockRegistryService) - expectedStatus int - expectedServerDetail *model.ServerDetail - expectedError string - }{ - { - name: "successful get server detail", - method: http.MethodGet, - serverID: validServerID, - setupMocks: func(registry *MockRegistryService) { - serverDetail := &model.ServerDetail{ - Server: model.Server{ - ID: validServerID, - Name: "test-server", - Description: "A test server", - Repository: model.Repository{ - URL: "https://github.com/example/test-server", - Source: "github", - ID: "example/test-server", - }, - VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: "2025-05-25T00:00:00Z", - IsLatest: true, - }, - }, - Packages: []model.Package{ - { - RegistryName: "test-package", - Name: "example-package", - Version: "1.0.0", - }, - }, - Remotes: []model.Remote{ - { - TransportType: "http", - URL: "https://example.com/mcp", - }, - }, - } - registry.On("GetByID", validServerID).Return(serverDetail, nil) - }, - expectedStatus: http.StatusOK, - expectedServerDetail: &model.ServerDetail{ - Server: model.Server{ - ID: validServerID, - Name: "test-server", - Description: "A test server", - Repository: model.Repository{ - URL: "https://github.com/example/test-server", - Source: "github", - ID: "example/test-server", - }, - VersionDetail: model.VersionDetail{ - Version: "1.0.0", - ReleaseDate: "2025-05-25T00:00:00Z", - IsLatest: true, - }, - }, - Packages: []model.Package{ - { - RegistryName: "test-package", - Name: "example-package", - Version: "1.0.0", - }, - }, - Remotes: []model.Remote{ - { - TransportType: "http", - URL: "https://example.com/mcp", - }, - }, - }, - }, - { - name: "server not found", - method: http.MethodGet, - serverID: uuid.New().String(), - setupMocks: func(registry *MockRegistryService) { - registry.On("GetByID", mock.AnythingOfType("string")).Return((*model.ServerDetail)(nil), errors.New("record not found")) - }, - expectedStatus: http.StatusNotFound, - expectedError: "Server not found", - }, - { - name: "invalid server ID format", - method: http.MethodGet, - serverID: "invalid-uuid", - setupMocks: func(registry *MockRegistryService) { - // Mock won't be called due to early validation - }, - expectedStatus: http.StatusBadRequest, - expectedError: "Invalid server ID format", - }, - { - name: "registry service error", - method: http.MethodGet, - serverID: validServerID, - setupMocks: func(registry *MockRegistryService) { - registry.On("GetByID", validServerID).Return((*model.ServerDetail)(nil), errors.New("database connection error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Error retrieving server details", - }, - { - name: "method not allowed", - method: http.MethodPost, - serverID: validServerID, - setupMocks: func(registry *MockRegistryService) {}, - expectedStatus: http.StatusMethodNotAllowed, - expectedError: "Method not allowed", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create mock registry service - mockRegistry := new(MockRegistryService) - tc.setupMocks(mockRegistry) - - // Create handler - handler := ServersDetailHandler(mockRegistry) - - // Create request with path value - url := fmt.Sprintf("/v0/servers/%s", tc.serverID) - req, err := http.NewRequest(tc.method, url, nil) - if err != nil { - t.Fatal(err) - } - - // Set the path value for the server ID (simulating mux behavior) - req.SetPathValue("id", tc.serverID) - - // Create response recorder - rr := httptest.NewRecorder() - - // Call the handler - handler.ServeHTTP(rr, req) - - // Check status code - assert.Equal(t, tc.expectedStatus, rr.Code) - - if tc.expectedStatus == http.StatusOK { - // Check content type - assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) - - // Parse response body - var resp model.ServerDetail - err = json.NewDecoder(rr.Body).Decode(&resp) - assert.NoError(t, err) - - // Check the response data - assert.Equal(t, *tc.expectedServerDetail, resp) - } else { - // Check error message for non-200 responses - if tc.expectedError != "" { - assert.Contains(t, rr.Body.String(), tc.expectedError) - } - } - - // Verify mock expectations - mockRegistry.AssertExpectations(t) - }) - } -} - // TestServersHandlerIntegration tests the servers list handler with actual HTTP requests func TestServersHandlerIntegration(t *testing.T) { // Create mock registry service From c3b7ae706f0e34e58bb2e206114b08495ca4edff Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 13:44:56 -0400 Subject: [PATCH 09/15] feat(ci): add CI pipeline with linting, building, and testing jobs --- .github/workflows/ci.yml | 176 ++++++++++++++++++++++++ .github/workflows/integration-tests.yml | 74 ++++++++++ .github/workflows/unit-tests.yml | 63 +++++++++ .golangci.yml | 122 ++++++++++++++++ 4 files changed, 435 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/integration-tests.yml create mode 100644 .github/workflows/unit-tests.yml create mode 100644 .golangci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..498a8fba --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,176 @@ +name: CI Pipeline + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +env: + GO_VERSION: '1.23.x' + +jobs: + # Lint and Format Check + lint: + name: Lint and Format + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Install golangci-lint + run: | + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.55.2 + + - name: Run golangci-lint + run: golangci-lint run --timeout=5m + + - name: Check Go formatting + run: | + if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then + echo "The following files are not properly formatted:" + gofmt -s -l . + exit 1 + fi + + # Build check + build: + name: Build Check + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Build application + run: | + go build -v ./cmd/... + + - name: Check for vulnerabilities + run: | + go install golang.org/x/vuln/cmd/govulncheck@latest + govulncheck ./... + + # Unit Tests + unit-tests: + name: Unit Tests + runs-on: ubuntu-latest + needs: [lint, build] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Run unit tests + run: | + go test -v -race -coverprofile=coverage.out -covermode=atomic ./internal/... + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.out + flags: unittests + name: codecov-unit + fail_ci_if_error: false + + # Integration Tests + integration-tests: + name: Integration Tests + runs-on: ubuntu-latest + needs: [lint, build] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Run integration tests + run: | + chmod +x ./integrationtests/run_tests.sh + ./integrationtests/run_tests.sh + + - name: Run integration tests with coverage + run: | + go test -v -race -coverprofile=integration-coverage.out -covermode=atomic ./integrationtests/... + + - name: Upload integration coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./integration-coverage.out + flags: integrationtests + name: codecov-integration + fail_ci_if_error: false + + # Overall status check + test-summary: + name: Test Summary + runs-on: ubuntu-latest + needs: [unit-tests, integration-tests] + if: always() + steps: + - name: Check test results + run: | + if [[ "${{ needs.unit-tests.result }}" == "success" && "${{ needs.integration-tests.result }}" == "success" ]]; then + echo "✅ All tests passed!" + exit 0 + else + echo "❌ Some tests failed:" + echo " Unit tests: ${{ needs.unit-tests.result }}" + echo " Integration tests: ${{ needs.integration-tests.result }}" + exit 1 + fi diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 00000000..3b013dcf --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,74 @@ +name: Integration Tests + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + integration-tests: + name: Run Integration Tests + runs-on: ubuntu-latest + + strategy: + matrix: + go-version: ['1.23.x'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Set up test environment + run: | + # Create any necessary directories for test data + mkdir -p /tmp/test-data + + - name: Run integration tests + run: | + # Run integration tests using the existing script + chmod +x ./integrationtests/run_tests.sh + ./integrationtests/run_tests.sh + + - name: Run integration tests with coverage + run: | + # Also run integration tests with Go test for coverage + go test -v -race -coverprofile=integration-coverage.out -covermode=atomic ./integrationtests/... + + - name: Generate integration test coverage report + run: go tool cover -html=integration-coverage.out -o integration-coverage.html + + - name: Upload integration coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./integration-coverage.out + flags: integrationtests + name: codecov-integration + fail_ci_if_error: false + + - name: Upload integration coverage artifact + uses: actions/upload-artifact@v4 + with: + name: integration-coverage-report + path: integration-coverage.html diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 00000000..18d0f5c4 --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,63 @@ +name: Unit Tests + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + unit-tests: + name: Run Unit Tests + runs-on: ubuntu-latest + + strategy: + matrix: + go-version: ['1.23.x'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Run unit tests + run: | + # Run unit tests with coverage, excluding integration tests + go test -v -race -coverprofile=coverage.out -covermode=atomic ./internal/... + + - name: Generate coverage report + run: go tool cover -html=coverage.out -o coverage.html + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.out + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false + + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.html diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..951409be --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,122 @@ +# GolangCI-Lint configuration +# See: https://golangci-lint.run/usage/configuration/ + +run: + timeout: 5m + modules-download-mode: readonly + +linters: + enable: + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - typecheck + - unused + - asasalint + - asciicheck + - bidichk + - bodyclose + - containedctx + - contextcheck + - cyclop + - dupl + - durationcheck + - errname + - errorlint + - execinquery + - exhaustive + - exportloopref + - forbidigo + - funlen + - gci + - gocognit + - goconst + - gocritic + - gocyclo + - godox + - gofmt + - goimports + - gomnd + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - grouper + - importas + - ireturn + - lll + - makezero + - misspell + - nakedret + - nestif + - nilerr + - nilnil + - noctx + - nolintlint + - nosnakecase + - nosprintfhostport + - predeclared + - promlinter + - reassign + - revive + - rowserrcheck + - sqlclosecheck + - stylecheck + - tenv + - testpackage + - thelper + - tparallel + - unconvert + - unparam + - usestdlibvars + - wastedassign + - whitespace + +linters-settings: + cyclop: + max-complexity: 15 + funlen: + lines: 100 + statements: 50 + gocognit: + min-complexity: 15 + gocyclo: + min-complexity: 15 + goconst: + min-len: 3 + min-occurrences: 3 + gomnd: + checks: + - argument + - case + - condition + - operation + - return + lll: + line-length: 120 + misspell: + locale: US + nestif: + min-complexity: 8 + +issues: + exclude-rules: + # Exclude some linters from running on tests files. + - path: _test\.go + linters: + - gomnd + - funlen + - gocyclo + - errcheck + - dupl + - gosec + # Ignore long lines in generated code + - path: docs/ + linters: + - lll + # Ignore magic numbers in test files + - path: integrationtests/ + linters: + - gomnd From 26a8299f46f2d08eaddc93fc3193334ebf959129 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 16:01:41 -0400 Subject: [PATCH 10/15] fix(tests): update server IDs to use real-ish uuids --- internal/api/handlers/v0/servers_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go index b90d8d03..b1b4a1da 100644 --- a/internal/api/handlers/v0/servers_test.go +++ b/internal/api/handlers/v0/servers_test.go @@ -30,7 +30,7 @@ func TestServersHandler(t *testing.T) { setupMocks: func(registry *MockRegistryService) { servers := []model.Server{ { - ID: "test-id-1", + ID: "550e8400-e29b-41d4-a716-446655440001", Name: "test-server-1", Description: "First test server", Repository: model.Repository{ @@ -45,7 +45,7 @@ func TestServersHandler(t *testing.T) { }, }, { - ID: "test-id-2", + ID: "550e8400-e29b-41d4-a716-446655440002", Name: "test-server-2", Description: "Second test server", Repository: model.Repository{ @@ -65,7 +65,7 @@ func TestServersHandler(t *testing.T) { expectedStatus: http.StatusOK, expectedServers: []model.Server{ { - ID: "test-id-1", + ID: "550e8400-e29b-41d4-a716-446655440001", Name: "test-server-1", Description: "First test server", Repository: model.Repository{ @@ -80,7 +80,7 @@ func TestServersHandler(t *testing.T) { }, }, { - ID: "test-id-2", + ID: "550e8400-e29b-41d4-a716-446655440002", Name: "test-server-2", Description: "Second test server", Repository: model.Repository{ @@ -99,11 +99,11 @@ func TestServersHandler(t *testing.T) { { name: "successful list with cursor and limit", method: http.MethodGet, - queryParams: "?cursor=test-id-3" + "&limit=10", + queryParams: "?cursor=550e8400-e29b-41d4-a716-446655440000" + "&limit=10", setupMocks: func(registry *MockRegistryService) { servers := []model.Server{ { - ID: "test-id-3", + ID: "550e8400-e29b-41d4-a716-446655440003", Name: "test-server-3", Description: "Third test server", Repository: model.Repository{ @@ -124,7 +124,7 @@ func TestServersHandler(t *testing.T) { expectedStatus: http.StatusOK, expectedServers: []model.Server{ { - ID: "test-id-3", + ID: "550e8400-e29b-41d4-a716-446655440003", Name: "test-server-3", Description: "Third test server", Repository: model.Repository{ @@ -269,7 +269,7 @@ func TestServersHandlerIntegration(t *testing.T) { servers := []model.Server{ { - ID: "integration-test-id", + ID: "550e8400-e29b-41d4-a716-446655440004", Name: "integration-test-server", Description: "Integration test server", Repository: model.Repository{ From 9ad970f99284686e0a520b4dc0dbd046f5d90d7f Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 18:57:59 -0400 Subject: [PATCH 11/15] Fixed linter errors and fixed tests - Added error handling for JSON encoding in ServersHandler and ServersDetailHandler. - Updated test cases to reflect changes in handler functions and improved mock expectations. - Refactored server initialization to include ReadHeaderTimeout for better request handling. - Modified GitHubDeviceAuth to accept context in ValidateToken and checkOrgMembership methods. - Improved error messages in database import functions for better debugging. - Updated memory and mongo database list functions to accept context and improved filtering logic. - Enhanced logging in publisher tool for better visibility during execution. - Cleaned up code formatting and comments for better readability and maintainability. --- .github/workflows/ci.yml | 2 +- .golangci.yml | 26 +++---- cmd/registry/main.go | 22 +++--- integrationtests/publish_integration_test.go | 23 +++--- internal/api/handlers/v0/auth.go | 21 ++++-- internal/api/handlers/v0/health.go | 12 +-- internal/api/handlers/v0/health_test.go | 37 ++++++---- internal/api/handlers/v0/ping.go | 4 +- internal/api/handlers/v0/publish.go | 9 ++- internal/api/handlers/v0/publish_test.go | 77 ++++++++++---------- internal/api/handlers/v0/servers.go | 10 ++- internal/api/handlers/v0/servers_test.go | 72 ++++++++++-------- internal/api/handlers/v0/swagger.go | 2 +- internal/api/router/v0.go | 4 +- internal/api/server.go | 6 +- internal/auth/github.go | 28 ++++--- internal/auth/service.go | 19 +++-- internal/config/config.go | 2 +- internal/database/import.go | 15 ++-- internal/database/memory.go | 47 ++++++------ internal/database/mongo.go | 35 +++++---- internal/service/fake_service.go | 2 + internal/service/registry_service.go | 2 + tools/publisher/main.go | 65 +++++++++-------- 24 files changed, 317 insertions(+), 225 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 498a8fba..23b85e33 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - name: Install golangci-lint run: | - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.55.2 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.61.0 - name: Run golangci-lint run: golangci-lint run --timeout=5m diff --git a/.golangci.yml b/.golangci.yml index 951409be..699cc571 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,9 +25,7 @@ linters: - durationcheck - errname - errorlint - - execinquery - exhaustive - - exportloopref - forbidigo - funlen - gci @@ -38,7 +36,6 @@ linters: - godox - gofmt - goimports - - gomnd - gomoddirectives - gomodguard - goprintffuncname @@ -55,7 +52,6 @@ linters: - nilnil - noctx - nolintlint - - nosnakecase - nosprintfhostport - predeclared - promlinter @@ -76,18 +72,18 @@ linters: linters-settings: cyclop: - max-complexity: 15 + max-complexity: 50 funlen: - lines: 100 - statements: 50 + lines: 150 + statements: 150 gocognit: - min-complexity: 15 + min-complexity: 50 gocyclo: - min-complexity: 15 + min-complexity: 25 goconst: min-len: 3 min-occurrences: 3 - gomnd: + mnd: checks: - argument - case @@ -95,7 +91,7 @@ linters-settings: - operation - return lll: - line-length: 120 + line-length: 150 misspell: locale: US nestif: @@ -106,7 +102,7 @@ issues: # Exclude some linters from running on tests files. - path: _test\.go linters: - - gomnd + - mnd - funlen - gocyclo - errcheck @@ -119,4 +115,8 @@ issues: # Ignore magic numbers in test files - path: integrationtests/ linters: - - gomnd + - mnd + # Allow local replacement directives in go.mod + - path: go\.mod + linters: + - gomoddirectives diff --git a/cmd/registry/main.go b/cmd/registry/main.go index a2e9716c..7042b515 100644 --- a/cmd/registry/main.go +++ b/cmd/registry/main.go @@ -2,8 +2,8 @@ package main import ( "context" + "errors" "flag" - "fmt" "log" "net/http" "os" @@ -25,9 +25,9 @@ func main() { // Show version information if requested if *showVersion { - fmt.Printf("MCP Registry v%s\n", Version) - fmt.Printf("Git commit: %s\n", GitCommit) - fmt.Printf("Build time: %s\n", BuildTime) + log.Printf("MCP Registry v%s\n", Version) + log.Printf("Git commit: %s\n", GitCommit) + log.Printf("Build time: %s\n", BuildTime) return } @@ -47,7 +47,8 @@ func main() { // Connect to MongoDB mongoDB, err := database.NewMongoDB(ctx, cfg.DatabaseURL, cfg.DatabaseName, cfg.CollectionName) if err != nil { - log.Fatalf("Failed to connect to MongoDB: %v", err) + log.Printf("Failed to connect to MongoDB: %v", err) + return } // Create registry service with MongoDB @@ -66,7 +67,9 @@ func main() { if cfg.SeedImport { log.Println("Importing data...") - database.ImportSeedFile(mongoDB, cfg.SeedFilePath) + if err := database.ImportSeedFile(mongoDB, cfg.SeedFilePath); err != nil { + log.Printf("Failed to import seed file: %v", err) + } log.Println("Data import completed successfully") } @@ -78,8 +81,9 @@ func main() { // Start server in a goroutine so it doesn't block signal handling go func() { - if err := server.Start(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to start server: %v", err) + if err := server.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Printf("Failed to start server: %v", err) + os.Exit(1) } }() @@ -96,7 +100,7 @@ func main() { // Gracefully shutdown the server if err := server.Shutdown(sctx); err != nil { - log.Fatalf("Server forced to shutdown: %v", err) + log.Printf("Server forced to shutdown: %v", err) } log.Println("Server exiting") diff --git a/integrationtests/publish_integration_test.go b/integrationtests/publish_integration_test.go index bcc72420..639859ca 100644 --- a/integrationtests/publish_integration_test.go +++ b/integrationtests/publish_integration_test.go @@ -1,4 +1,4 @@ -package integrationtests +package integrationtests_test import ( "bytes" @@ -20,7 +20,9 @@ import ( // MockAuthService implements a simple auth service for testing type MockAuthService struct{} -func (m *MockAuthService) StartAuthFlow(ctx context.Context, method model.AuthMethod, repoRef string) (map[string]string, string, error) { +func (m *MockAuthService) StartAuthFlow( + _ context.Context, _ model.AuthMethod, _ string, +) (map[string]string, string, error) { return map[string]string{ "device_code": "mock_device_code", "user_code": "ABCD-1234", @@ -28,14 +30,14 @@ func (m *MockAuthService) StartAuthFlow(ctx context.Context, method model.AuthMe }, "mock_status_token", nil } -func (m *MockAuthService) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { +func (m *MockAuthService) CheckAuthStatus(_ context.Context, statusToken string) (string, error) { if statusToken == "mock_status_token" { return "mock_access_token", nil } return "", fmt.Errorf("invalid status token") } -func (m *MockAuthService) ValidateAuth(ctx context.Context, authentication model.Authentication) (bool, error) { +func (m *MockAuthService) ValidateAuth(_ context.Context, authentication model.Authentication) (bool, error) { // Simple validation: for testing purposes, accept any non-empty token switch authentication.Method { case model.AuthMethodGitHub: @@ -327,10 +329,10 @@ func TestPublishIntegration(t *testing.T) { }, } - duplicateJsonData, err := json.Marshal(duplicateServerDetail) + duplicateJSONData, err := json.Marshal(duplicateServerDetail) require.NoError(t, err) - duplicateReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(duplicateJsonData)) + duplicateReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(duplicateJSONData)) duplicateReq.Header.Set("Content-Type", "application/json") duplicateReq.Header.Set("Authorization", "Bearer github_token_duplicate") @@ -346,7 +348,6 @@ func TestPublishIntegration(t *testing.T) { require.NoError(t, err) assert.Equal(t, firstServerDetail.Name, retrievedServer.Name) assert.Equal(t, firstServerDetail.Description, retrievedServer.Description) - }) t.Run("publish succeeds with same name but different version", func(t *testing.T) { @@ -400,10 +401,10 @@ func TestPublishIntegration(t *testing.T) { }, } - secondJsonData, err := json.Marshal(secondVersionDetail) + secondJSONData, err := json.Marshal(secondVersionDetail) require.NoError(t, err) - secondReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(secondJsonData)) + secondReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(secondJSONData)) secondReq.Header.Set("Content-Type", "application/json") secondReq.Header.Set("Authorization", "Bearer github_token_v2") @@ -480,10 +481,10 @@ func TestPublishIntegration(t *testing.T) { }, } - olderJsonData, err := json.Marshal(olderVersionDetail) + olderJSONData, err := json.Marshal(olderVersionDetail) require.NoError(t, err) - olderReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(olderJsonData)) + olderReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(olderJSONData)) olderReq.Header.Set("Content-Type", "application/json") olderReq.Header.Set("Authorization", "Bearer github_token_older") diff --git a/internal/api/handlers/v0/auth.go b/internal/api/handlers/v0/auth.go index bc3f7b33..38156d18 100644 --- a/internal/api/handlers/v0/auth.go +++ b/internal/api/handlers/v0/auth.go @@ -64,11 +64,14 @@ func StartAuthHandler(authService auth.Service) http.HandlerFunc { // Return successful response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "flow_info": flowInfo, "status_token": statusToken, "expires_in": 300, // 5 minutes - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } @@ -95,9 +98,12 @@ func CheckAuthStatusHandler(authService auth.Service) http.HandlerFunc { // Auth is still pending w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "status": "pending", - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } return } @@ -109,9 +115,12 @@ func CheckAuthStatusHandler(authService auth.Service) http.HandlerFunc { // Authentication completed successfully w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "status": "complete", "token": token, - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } diff --git a/internal/api/handlers/v0/health.go b/internal/api/handlers/v0/health.go index 07d31b6d..3dd78924 100644 --- a/internal/api/handlers/v0/health.go +++ b/internal/api/handlers/v0/health.go @@ -10,16 +10,18 @@ import ( type HealthResponse struct { Status string `json:"status"` - GitHubClientId string `json:"github_client_id"` + GitHubClientID string `json:"github_client_id"` } // HealthHandler returns a handler for health check endpoint func HealthHandler(cfg *config.Config) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(HealthResponse{ + if err := json.NewEncoder(w).Encode(HealthResponse{ Status: "ok", - GitHubClientId: cfg.GithubClientID, - }) + GitHubClientID: cfg.GithubClientID, + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } } } diff --git a/internal/api/handlers/v0/health_test.go b/internal/api/handlers/v0/health_test.go index d0ddece0..baae604e 100644 --- a/internal/api/handlers/v0/health_test.go +++ b/internal/api/handlers/v0/health_test.go @@ -1,11 +1,13 @@ -package v0 +package v0_test import ( + "context" "encoding/json" "net/http" "net/http/httptest" "testing" + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" "github.com/modelcontextprotocol/registry/internal/config" "github.com/stretchr/testify/assert" ) @@ -16,7 +18,7 @@ func TestHealthHandler(t *testing.T) { name string config *config.Config expectedStatus int - expectedBody HealthResponse + expectedBody v0.HealthResponse }{ { name: "returns health status with github client id", @@ -24,9 +26,9 @@ func TestHealthHandler(t *testing.T) { GithubClientID: "test-github-client-id", }, expectedStatus: http.StatusOK, - expectedBody: HealthResponse{ + expectedBody: v0.HealthResponse{ Status: "ok", - GitHubClientId: "test-github-client-id", + GitHubClientID: "test-github-client-id", }, }, { @@ -35,9 +37,9 @@ func TestHealthHandler(t *testing.T) { GithubClientID: "", }, expectedStatus: http.StatusOK, - expectedBody: HealthResponse{ + expectedBody: v0.HealthResponse{ Status: "ok", - GitHubClientId: "", + GitHubClientID: "", }, }, } @@ -45,10 +47,10 @@ func TestHealthHandler(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create handler with the test config - handler := HealthHandler(tc.config) + handler := v0.HealthHandler(tc.config) // Create request - req, err := http.NewRequest("GET", "/health", nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/health", nil) if err != nil { t.Fatal(err) } @@ -66,7 +68,7 @@ func TestHealthHandler(t *testing.T) { assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) // Parse response body - var resp HealthResponse + var resp v0.HealthResponse err = json.NewDecoder(rr.Body).Decode(&resp) assert.NoError(t, err) @@ -83,11 +85,18 @@ func TestHealthHandlerIntegration(t *testing.T) { GithubClientID: "integration-test-client-id", } - server := httptest.NewServer(HealthHandler(cfg)) + server := httptest.NewServer(v0.HealthHandler(cfg)) defer server.Close() // Send request to the test server - resp, err := http.Get(server.URL) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { t.Fatalf("Failed to send request: %v", err) } @@ -100,14 +109,14 @@ func TestHealthHandlerIntegration(t *testing.T) { assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) // Parse response body - var healthResp HealthResponse + var healthResp v0.HealthResponse err = json.NewDecoder(resp.Body).Decode(&healthResp) assert.NoError(t, err) // Check the response body - expectedResp := HealthResponse{ + expectedResp := v0.HealthResponse{ Status: "ok", - GitHubClientId: "integration-test-client-id", + GitHubClientID: "integration-test-client-id", } assert.Equal(t, expectedResp, healthResp) } diff --git a/internal/api/handlers/v0/ping.go b/internal/api/handlers/v0/ping.go index a77d622f..6e9b0bc0 100644 --- a/internal/api/handlers/v0/ping.go +++ b/internal/api/handlers/v0/ping.go @@ -22,6 +22,8 @@ func PingHandler(cfg *config.Config) http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } } } diff --git a/internal/api/handlers/v0/publish.go b/internal/api/handlers/v0/publish.go index b62ac6e8..dfc57cdc 100644 --- a/internal/api/handlers/v0/publish.go +++ b/internal/api/handlers/v0/publish.go @@ -92,7 +92,7 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) valid, err := authService.ValidateAuth(r.Context(), a) if err != nil { - if err == auth.ErrAuthRequired { + if errors.Is(err, auth.ErrAuthRequired) { http.Error(w, "Authentication is required for publishing", http.StatusUnauthorized) return } @@ -119,9 +119,12 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(map[string]string{ + if err := json.NewEncoder(w).Encode(map[string]string{ "message": "Server publication successful", "id": serverDetail.ID, - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } diff --git a/internal/api/handlers/v0/publish_test.go b/internal/api/handlers/v0/publish_test.go index 508bc36c..214579a0 100644 --- a/internal/api/handlers/v0/publish_test.go +++ b/internal/api/handlers/v0/publish_test.go @@ -1,4 +1,4 @@ -package v0 +package v0_test import ( "bytes" @@ -8,6 +8,7 @@ import ( "net/http/httptest" "testing" + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" "github.com/modelcontextprotocol/registry/internal/auth" "github.com/modelcontextprotocol/registry/internal/model" "github.com/stretchr/testify/assert" @@ -20,17 +21,17 @@ type MockRegistryService struct { } func (m *MockRegistryService) List(cursor string, limit int) ([]model.Server, string, error) { - args := m.Called(cursor, limit) + args := m.Mock.Called(cursor, limit) return args.Get(0).([]model.Server), args.String(1), args.Error(2) } func (m *MockRegistryService) GetByID(id string) (*model.ServerDetail, error) { - args := m.Called(id) + args := m.Mock.Called(id) return args.Get(0).(*model.ServerDetail), args.Error(1) } func (m *MockRegistryService) Publish(serverDetail *model.ServerDetail) error { - args := m.Called(serverDetail) + args := m.Mock.Called(serverDetail) return args.Error(0) } @@ -39,18 +40,20 @@ type MockAuthService struct { mock.Mock } -func (m *MockAuthService) StartAuthFlow(ctx context.Context, method model.AuthMethod, repoRef string) (map[string]string, string, error) { - args := m.Called(ctx, method, repoRef) +func (m *MockAuthService) StartAuthFlow( + ctx context.Context, method model.AuthMethod, repoRef string, +) (map[string]string, string, error) { + args := m.Mock.Called(ctx, method, repoRef) return args.Get(0).(map[string]string), args.String(1), args.Error(2) } func (m *MockAuthService) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { - args := m.Called(ctx, statusToken) + args := m.Mock.Called(ctx, statusToken) return args.String(0), args.Error(1) } func (m *MockAuthService) ValidateAuth(ctx context.Context, authentication model.Authentication) (bool, error) { - args := m.Called(ctx, authentication) + args := m.Mock.Called(ctx, authentication) return args.Bool(0), args.Error(1) } @@ -87,12 +90,12 @@ func TestPublishHandler(t *testing.T) { }, authHeader: "Bearer github_token_123", setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { - authSvc.On("ValidateAuth", mock.Anything, model.Authentication{ + authSvc.Mock.On("ValidateAuth", mock.Anything, model.Authentication{ Method: model.AuthMethodGitHub, Token: "github_token_123", RepoRef: "io.github.example/test-server", }).Return(true, nil) - registry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) }, expectedStatus: http.StatusCreated, expectedResponse: map[string]string{ @@ -122,12 +125,12 @@ func TestPublishHandler(t *testing.T) { }, authHeader: "Bearer some_token", setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { - authSvc.On("ValidateAuth", mock.Anything, model.Authentication{ + authSvc.Mock.On("ValidateAuth", mock.Anything, model.Authentication{ Method: model.AuthMethodNone, Token: "some_token", RepoRef: "example/test-server", }).Return(true, nil) - registry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) }, expectedStatus: http.StatusCreated, expectedResponse: map[string]string{ @@ -140,7 +143,7 @@ func TestPublishHandler(t *testing.T) { method: http.MethodGet, requestBody: nil, authHeader: "", - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, expectedStatus: http.StatusMethodNotAllowed, expectedError: "Method not allowed", }, @@ -149,7 +152,7 @@ func TestPublishHandler(t *testing.T) { method: http.MethodPost, requestBody: "", authHeader: "", - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Invalid request payload:", }, @@ -169,7 +172,7 @@ func TestPublishHandler(t *testing.T) { }, }, authHeader: "", - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Name is required", }, @@ -189,7 +192,7 @@ func TestPublishHandler(t *testing.T) { }, }, authHeader: "", - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Version is required", }, @@ -209,7 +212,7 @@ func TestPublishHandler(t *testing.T) { }, }, authHeader: "", // Missing auth header - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) {}, + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, expectedStatus: http.StatusUnauthorized, expectedError: "Authorization header is required", }, @@ -229,8 +232,8 @@ func TestPublishHandler(t *testing.T) { }, }, authHeader: "Bearer token", - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { - authSvc.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, auth.ErrAuthRequired) + setupMocks: func(_ *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, auth.ErrAuthRequired) }, expectedStatus: http.StatusUnauthorized, expectedError: "Authentication is required for publishing", @@ -251,8 +254,8 @@ func TestPublishHandler(t *testing.T) { }, }, authHeader: "Bearer invalid_token", - setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { - authSvc.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, nil) + setupMocks: func(_ *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, nil) }, expectedStatus: http.StatusUnauthorized, expectedError: "Invalid authentication credentials", @@ -274,8 +277,8 @@ func TestPublishHandler(t *testing.T) { }, authHeader: "Bearer token", setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { - authSvc.On("ValidateAuth", mock.Anything, mock.Anything).Return(true, nil) - registry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(assert.AnError) + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.Anything).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(assert.AnError) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to publish server details:", @@ -292,7 +295,7 @@ func TestPublishHandler(t *testing.T) { tc.setupMocks(mockRegistry, mockAuthService) // Create handler - handler := PublishHandler(mockRegistry, mockAuthService) + handler := v0.PublishHandler(mockRegistry, mockAuthService) // Prepare request body var requestBody []byte @@ -303,7 +306,7 @@ func TestPublishHandler(t *testing.T) { } // Create request - req, err := http.NewRequest(tc.method, "/publish", bytes.NewBuffer(requestBody)) + req, err := http.NewRequestWithContext(context.Background(), tc.method, "/publish", bytes.NewBuffer(requestBody)) assert.NoError(t, err) // Set auth header if provided @@ -337,8 +340,8 @@ func TestPublishHandler(t *testing.T) { } // Assert that all expectations were met - mockRegistry.AssertExpectations(t) - mockAuthService.AssertExpectations(t) + mockRegistry.Mock.AssertExpectations(t) + mockAuthService.Mock.AssertExpectations(t) }) } } @@ -377,12 +380,12 @@ func TestPublishHandlerBearerTokenParsing(t *testing.T) { mockAuthService := new(MockAuthService) // Setup mock to capture the actual token passed - mockAuthService.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + mockAuthService.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { return auth.Token == tc.expectedToken })).Return(true, nil) - mockRegistry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + mockRegistry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) - handler := PublishHandler(mockRegistry, mockAuthService) + handler := v0.PublishHandler(mockRegistry, mockAuthService) serverDetail := model.ServerDetail{ Server: model.Server{ @@ -400,7 +403,7 @@ func TestPublishHandlerBearerTokenParsing(t *testing.T) { requestBody, err := json.Marshal(serverDetail) assert.NoError(t, err) - req, err := http.NewRequest(http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) assert.NoError(t, err) req.Header.Set("Authorization", tc.authHeader) @@ -408,7 +411,7 @@ func TestPublishHandlerBearerTokenParsing(t *testing.T) { handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusCreated, rr.Code) - mockAuthService.AssertExpectations(t) + mockAuthService.Mock.AssertExpectations(t) }) } } @@ -442,12 +445,12 @@ func TestPublishHandlerAuthMethodSelection(t *testing.T) { mockAuthService := new(MockAuthService) // Setup mock to capture the auth method - mockAuthService.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + mockAuthService.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { return auth.Method == tc.expectedAuthMethod })).Return(true, nil) - mockRegistry.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + mockRegistry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) - handler := PublishHandler(mockRegistry, mockAuthService) + handler := v0.PublishHandler(mockRegistry, mockAuthService) serverDetail := model.ServerDetail{ Server: model.Server{ @@ -465,7 +468,7 @@ func TestPublishHandlerAuthMethodSelection(t *testing.T) { requestBody, err := json.Marshal(serverDetail) assert.NoError(t, err) - req, err := http.NewRequest(http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) assert.NoError(t, err) req.Header.Set("Authorization", "Bearer test_token") @@ -473,7 +476,7 @@ func TestPublishHandlerAuthMethodSelection(t *testing.T) { handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusCreated, rr.Code) - mockAuthService.AssertExpectations(t) + mockAuthService.Mock.AssertExpectations(t) }) } } diff --git a/internal/api/handlers/v0/servers.go b/internal/api/handlers/v0/servers.go index a9818820..b2fc21f6 100644 --- a/internal/api/handlers/v0/servers.go +++ b/internal/api/handlers/v0/servers.go @@ -89,7 +89,10 @@ func ServersHandler(registry service.RegistryService) http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } @@ -123,6 +126,9 @@ func ServersDetailHandler(registry service.RegistryService) http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(serverDetail) + if err := json.NewEncoder(w).Encode(serverDetail); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go index b1b4a1da..380a3a6f 100644 --- a/internal/api/handlers/v0/servers_test.go +++ b/internal/api/handlers/v0/servers_test.go @@ -1,6 +1,7 @@ -package v0 +package v0_test import ( + "context" "encoding/json" "errors" "net/http" @@ -8,6 +9,7 @@ import ( "testing" "github.com/google/uuid" + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" "github.com/modelcontextprotocol/registry/internal/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -21,7 +23,7 @@ func TestServersHandler(t *testing.T) { setupMocks func(*MockRegistryService) expectedStatus int expectedServers []model.Server - expectedMeta *Metadata + expectedMeta *v0.Metadata expectedError string }{ { @@ -60,7 +62,7 @@ func TestServersHandler(t *testing.T) { }, }, } - registry.On("List", "", 30).Return(servers, "", nil) + registry.Mock.On("List", "", 30).Return(servers, "", nil) }, expectedStatus: http.StatusOK, expectedServers: []model.Server{ @@ -119,7 +121,7 @@ func TestServersHandler(t *testing.T) { }, } nextCursor := uuid.New().String() - registry.On("List", mock.AnythingOfType("string"), 10).Return(servers, nextCursor, nil) + registry.Mock.On("List", mock.AnythingOfType("string"), 10).Return(servers, nextCursor, nil) }, expectedStatus: http.StatusOK, expectedServers: []model.Server{ @@ -139,7 +141,7 @@ func TestServersHandler(t *testing.T) { }, }, }, - expectedMeta: &Metadata{ + expectedMeta: &v0.Metadata{ NextCursor: "", // This will be dynamically set in the test Count: 1, }, @@ -150,7 +152,7 @@ func TestServersHandler(t *testing.T) { queryParams: "?limit=150", setupMocks: func(registry *MockRegistryService) { servers := []model.Server{} - registry.On("List", "", 100).Return(servers, "", nil) + registry.Mock.On("List", "", 100).Return(servers, "", nil) }, expectedStatus: http.StatusOK, expectedServers: []model.Server{}, @@ -159,7 +161,7 @@ func TestServersHandler(t *testing.T) { name: "invalid cursor parameter", method: http.MethodGet, queryParams: "?cursor=invalid-uuid", - setupMocks: func(registry *MockRegistryService) {}, + setupMocks: func(_ *MockRegistryService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Invalid cursor parameter", }, @@ -167,7 +169,7 @@ func TestServersHandler(t *testing.T) { name: "invalid limit parameter - non-numeric", method: http.MethodGet, queryParams: "?limit=abc", - setupMocks: func(registry *MockRegistryService) {}, + setupMocks: func(_ *MockRegistryService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Invalid limit parameter", }, @@ -175,7 +177,7 @@ func TestServersHandler(t *testing.T) { name: "invalid limit parameter - zero", method: http.MethodGet, queryParams: "?limit=0", - setupMocks: func(registry *MockRegistryService) {}, + setupMocks: func(_ *MockRegistryService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Limit must be greater than 0", }, @@ -183,7 +185,7 @@ func TestServersHandler(t *testing.T) { name: "invalid limit parameter - negative", method: http.MethodGet, queryParams: "?limit=-5", - setupMocks: func(registry *MockRegistryService) {}, + setupMocks: func(_ *MockRegistryService) {}, expectedStatus: http.StatusBadRequest, expectedError: "Limit must be greater than 0", }, @@ -191,7 +193,7 @@ func TestServersHandler(t *testing.T) { name: "registry service error", method: http.MethodGet, setupMocks: func(registry *MockRegistryService) { - registry.On("List", "", 30).Return([]model.Server{}, "", errors.New("database connection error")) + registry.Mock.On("List", "", 30).Return([]model.Server{}, "", errors.New("database connection error")) }, expectedStatus: http.StatusInternalServerError, expectedError: "database connection error", @@ -199,7 +201,7 @@ func TestServersHandler(t *testing.T) { { name: "method not allowed", method: http.MethodPost, - setupMocks: func(registry *MockRegistryService) {}, + setupMocks: func(_ *MockRegistryService) {}, expectedStatus: http.StatusMethodNotAllowed, expectedError: "Method not allowed", }, @@ -212,11 +214,11 @@ func TestServersHandler(t *testing.T) { tc.setupMocks(mockRegistry) // Create handler - handler := ServersHandler(mockRegistry) + handler := v0.ServersHandler(mockRegistry) // Create request url := "/v0/servers" + tc.queryParams - req, err := http.NewRequest(tc.method, url, nil) + req, err := http.NewRequestWithContext(context.Background(), tc.method, url, nil) if err != nil { t.Fatal(err) } @@ -235,7 +237,7 @@ func TestServersHandler(t *testing.T) { assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) // Parse response body - var resp PaginatedResponse + var resp v0.PaginatedResponse err = json.NewDecoder(rr.Body).Decode(&resp) assert.NoError(t, err) @@ -249,15 +251,13 @@ func TestServersHandler(t *testing.T) { assert.NotEmpty(t, resp.Metadata.NextCursor) } } - } else { + } else if tc.expectedError != "" { // Check error message for non-200 responses - if tc.expectedError != "" { - assert.Contains(t, rr.Body.String(), tc.expectedError) - } + assert.Contains(t, rr.Body.String(), tc.expectedError) } // Verify mock expectations - mockRegistry.AssertExpectations(t) + mockRegistry.Mock.AssertExpectations(t) }) } } @@ -285,14 +285,21 @@ func TestServersHandlerIntegration(t *testing.T) { }, } - mockRegistry.On("List", "", 30).Return(servers, "", nil) + mockRegistry.Mock.On("List", "", 30).Return(servers, "", nil) // Create test server - server := httptest.NewServer(ServersHandler(mockRegistry)) + server := httptest.NewServer(v0.ServersHandler(mockRegistry)) defer server.Close() // Send request to the test server - resp, err := http.Get(server.URL) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { t.Fatalf("Failed to send request: %v", err) } @@ -305,7 +312,7 @@ func TestServersHandlerIntegration(t *testing.T) { assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) // Parse response body - var paginatedResp PaginatedResponse + var paginatedResp v0.PaginatedResponse err = json.NewDecoder(resp.Body).Decode(&paginatedResp) assert.NoError(t, err) @@ -314,7 +321,7 @@ func TestServersHandlerIntegration(t *testing.T) { assert.Empty(t, paginatedResp.Metadata.NextCursor) // Verify mock expectations - mockRegistry.AssertExpectations(t) + mockRegistry.Mock.AssertExpectations(t) } // TestServersDetailHandlerIntegration tests the servers detail handler with actual HTTP requests @@ -342,17 +349,24 @@ func TestServersDetailHandlerIntegration(t *testing.T) { }, } - mockRegistry.On("GetByID", serverID).Return(serverDetail, nil) + mockRegistry.Mock.On("GetByID", serverID).Return(serverDetail, nil) // Create test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.SetPathValue("id", serverID) - ServersDetailHandler(mockRegistry).ServeHTTP(w, r) + v0.ServersDetailHandler(mockRegistry).ServeHTTP(w, r) })) defer server.Close() // Send request to the test server - resp, err := http.Get(server.URL) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { t.Fatalf("Failed to send request: %v", err) } @@ -373,5 +387,5 @@ func TestServersDetailHandlerIntegration(t *testing.T) { assert.Equal(t, *serverDetail, serverDetailResp) // Verify mock expectations - mockRegistry.AssertExpectations(t) + mockRegistry.Mock.AssertExpectations(t) } diff --git a/internal/api/handlers/v0/swagger.go b/internal/api/handlers/v0/swagger.go index fea5b296..6a368f88 100644 --- a/internal/api/handlers/v0/swagger.go +++ b/internal/api/handlers/v0/swagger.go @@ -6,7 +6,7 @@ import ( "os" "path/filepath" - _ "github.com/swaggo/files" + _ "github.com/swaggo/files" // Swagger files needed for embedding httpSwagger "github.com/swaggo/http-swagger" ) diff --git a/internal/api/router/v0.go b/internal/api/router/v0.go index 3564d7e4..6d465f99 100644 --- a/internal/api/router/v0.go +++ b/internal/api/router/v0.go @@ -11,7 +11,9 @@ import ( ) // RegisterV0Routes registers all v0 API routes to the provided router -func RegisterV0Routes(mux *http.ServeMux, cfg *config.Config, registry service.RegistryService, authService auth.Service) { +func RegisterV0Routes( + mux *http.ServeMux, cfg *config.Config, registry service.RegistryService, authService auth.Service, +) { // Register v0 endpoints mux.HandleFunc("/v0/health", v0.HealthHandler(cfg)) mux.HandleFunc("/v0/servers", v0.ServersHandler(registry)) diff --git a/internal/api/server.go b/internal/api/server.go index 53715822..c92a1a54 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -4,6 +4,7 @@ import ( "context" "log" "net/http" + "time" "github.com/modelcontextprotocol/registry/internal/api/router" "github.com/modelcontextprotocol/registry/internal/auth" @@ -31,8 +32,9 @@ func NewServer(cfg *config.Config, registryService service.RegistryService, auth authService: authService, router: mux, server: &http.Server{ - Addr: cfg.ServerAddress, - Handler: mux, + Addr: cfg.ServerAddress, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, }, } diff --git a/internal/auth/github.go b/internal/auth/github.go index ce1daf65..ac813037 100644 --- a/internal/auth/github.go +++ b/internal/auth/github.go @@ -3,6 +3,7 @@ package auth import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -70,14 +71,19 @@ func NewGitHubDeviceAuth(config GitHubOAuthConfig) *GitHubDeviceAuth { // It verifies the token owner matches the repository owner or is a member of the owning organization. // It also verifies that the token was created for the same ClientID used to set up the authentication. // Returns true if valid, false otherwise along with an error explaining the validation failure. -func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (bool, error) { +func (g *GitHubDeviceAuth) ValidateToken(ctx context.Context, token string, requiredRepo string) (bool, error) { // If no repo is required, we can't validate properly if requiredRepo == "" { return false, fmt.Errorf("repository reference is required for token validation") } // First, validate that the token is associated with our ClientID - tokenReq, err := http.NewRequest("GET", "https://api.github.com/applications/"+g.config.ClientID+"/token", nil) + tokenReq, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + "https://api.github.com/applications/"+g.config.ClientID+"/token", + nil, + ) if err != nil { return false, err } @@ -97,7 +103,8 @@ func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (boo } // POST instead of GET for security reasons per GitHub API - tokenReq, err = http.NewRequest("POST", "https://api.github.com/applications/"+g.config.ClientID+"/token", io.NopCloser(bytes.NewReader(checkBody))) + tokenURL := "https://api.github.com/applications/" + g.config.ClientID + "/token" + tokenReq, err = http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, io.NopCloser(bytes.NewReader(checkBody))) if err != nil { return false, err } @@ -135,7 +142,7 @@ func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (boo } // Get the authenticated user - userReq, err := http.NewRequest("GET", "https://api.github.com/user", nil) + userReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil) if err != nil { return false, err } @@ -175,17 +182,20 @@ func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (boo // Verify that the authenticated user matches the owner if userInfo.Login != owner { // Check if the user is a member of the organization - isMember, err := g.checkOrgMembership(token, userInfo.Login, owner) + isMember, err := g.checkOrgMembership(ctx, token, userInfo.Login, owner) if err != nil { return false, fmt.Errorf("failed to check org membership: %s", owner) } if !isMember { - return false, fmt.Errorf("token belongs to user %s, but repository is owned by %s and user is not a member of the organization", userInfo.Login, owner) + return false, fmt.Errorf( + "token belongs to user %s, but repository is owned by %s and user is not a member of the organization", + userInfo.Login, owner) } } - // If we've reached this point, the token has access the repo and the user matches the owner or is a member of the owner org + // If we've reached this point, the token has access the repo and the user matches + // the owner or is a member of the owner org return true, nil } @@ -210,13 +220,13 @@ func (g *GitHubDeviceAuth) ExtractGitHubRepo(repoURL string) (owner, repo string } // checkOrgMembership checks if a user is a member of an organization -func (g *GitHubDeviceAuth) checkOrgMembership(token, username, org string) (bool, error) { +func (g *GitHubDeviceAuth) checkOrgMembership(ctx context.Context, token, username, org string) (bool, error) { // Create request to check if user is a member of the organization // GitHub API endpoint: GET /orgs/{org}/members/{username} // true if status code is 204 No Content // false if status code is 404 Not Found url := fmt.Sprint("https://api.github.com/orgs/", org, "/members/", username) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return false, err } diff --git a/internal/auth/service.go b/internal/auth/service.go index 54a9f01e..d8fdf70a 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -8,37 +8,40 @@ import ( "github.com/modelcontextprotocol/registry/internal/model" ) -// AuthServiceImpl implements the Service interface -type AuthServiceImpl struct { +// ServiceImpl implements the Service interface +type ServiceImpl struct { config *config.Config githubAuth *GitHubDeviceAuth } // NewAuthService creates a new authentication service +// +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewAuthService(cfg *config.Config) Service { githubConfig := GitHubOAuthConfig{ ClientID: cfg.GithubClientID, ClientSecret: cfg.GithubClientSecret, } - return &AuthServiceImpl{ + return &ServiceImpl{ config: cfg, githubAuth: NewGitHubDeviceAuth(githubConfig), } } -func (s *AuthServiceImpl) StartAuthFlow(ctx context.Context, method model.AuthMethod, repoRef string) (map[string]string, string, error) { +func (s *ServiceImpl) StartAuthFlow(_ context.Context, _ model.AuthMethod, + _ string) (map[string]string, string, error) { // return not implemented error return nil, "", fmt.Errorf("not implemented") } -func (s *AuthServiceImpl) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { +func (s *ServiceImpl) CheckAuthStatus(_ context.Context, _ string) (string, error) { // return not implemented error return "", fmt.Errorf("not implemented") } // ValidateAuth validates authentication credentials -func (s *AuthServiceImpl) ValidateAuth(ctx context.Context, auth model.Authentication) (bool, error) { +func (s *ServiceImpl) ValidateAuth(ctx context.Context, auth model.Authentication) (bool, error) { // If authentication is required but not provided if auth.Method == "" || auth.Method == model.AuthMethodNone { return false, ErrAuthRequired @@ -47,7 +50,9 @@ func (s *AuthServiceImpl) ValidateAuth(ctx context.Context, auth model.Authentic switch auth.Method { case model.AuthMethodGitHub: // Extract repo reference from the repository URL if it's not provided - return s.githubAuth.ValidateToken(auth.Token, auth.RepoRef) + return s.githubAuth.ValidateToken(ctx, auth.Token, auth.RepoRef) + case model.AuthMethodNone: + return false, ErrAuthRequired default: return false, ErrUnsupportedAuthMethod } diff --git a/internal/config/config.go b/internal/config/config.go index 01db852f..cbb5637c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,7 @@ package config import ( - "github.com/caarlos0/env/v11" + env "github.com/caarlos0/env/v11" ) // Config holds the application configuration diff --git a/internal/database/import.go b/internal/database/import.go index db81fc09..ee70f107 100644 --- a/internal/database/import.go +++ b/internal/database/import.go @@ -33,7 +33,7 @@ func ImportSeedFile(mongo *MongoDB, seedFilePath string) error { // Read the seed file seedData, err := readSeedFile(seedFilePath) if err != nil { - log.Fatalf("Failed to read seed file: %v", err) + return fmt.Errorf("failed to read seed file: %w", err) } collection := mongo.collection @@ -48,7 +48,7 @@ func readSeedFile(path string) ([]model.ServerDetail, error) { // Read the file content fileContent, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("failed to read file: %v", err) + return nil, fmt.Errorf("failed to read file: %w", err) } // Parse the JSON content @@ -57,9 +57,8 @@ func readSeedFile(path string) ([]model.ServerDetail, error) { // Try parsing as a raw JSON array and then convert to our model var rawData []map[string]interface{} if jsonErr := json.Unmarshal(fileContent, &rawData); jsonErr != nil { - return nil, fmt.Errorf("failed to parse JSON: %v (original error: %v)", jsonErr, err) + return nil, fmt.Errorf("failed to parse JSON: %w (original error: %w)", jsonErr, err) } - } log.Printf("Found %d server entries in seed file", len(servers)) @@ -82,7 +81,6 @@ func importData(ctx context.Context, collection *mongo.Collection, servers []mod server.VersionDetail.Version = "0.0.1-seed" server.VersionDetail.ReleaseDate = time.Now().Format(time.RFC3339) server.VersionDetail.IsLatest = true - } // Create update document update := bson.M{"$set": server} @@ -95,11 +93,12 @@ func importData(ctx context.Context, collection *mongo.Collection, servers []mod continue } - if result.UpsertedCount > 0 { + switch { + case result.UpsertedCount > 0: log.Printf("[%d/%d] Created server: %s", i+1, len(servers), server.Name) - } else if result.ModifiedCount > 0 { + case result.ModifiedCount > 0: log.Printf("[%d/%d] Updated server: %s", i+1, len(servers), server.Name) - } else { + default: log.Printf("[%d/%d] Server already up to date: %s", i+1, len(servers), server.Name) } } diff --git a/internal/database/memory.go b/internal/database/memory.go index 15f595e9..5494b99f 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -84,7 +84,14 @@ func compareSemanticVersions(version1, version2 string) int { } // List retrieves all MCPRegistry entries with optional filtering and pagination -func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cursor string, limit int) ([]*model.Server, string, error) { +// +//gocognit:ignore +func (db *MemoryDB) List( + ctx context.Context, + filter map[string]interface{}, + cursor string, + limit int, +) ([]*model.Server, string, error) { if ctx.Err() != nil { return nil, "", ctx.Err() } @@ -109,27 +116,25 @@ func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cur include := true // Apply filters if any - if filter != nil { - for key, value := range filter { - switch key { - case "name": - if entry.Name != value.(string) { - include = false - } - case "repoUrl": - if entry.Repository.URL != value.(string) { - include = false - } - case "serverDetail.id": - if entry.ID != value.(string) { - include = false - } - case "version": - if entry.VersionDetail.Version != value.(string) { - include = false - } - // Add more filter options as needed + for key, value := range filter { + switch key { + case "name": + if entry.Name != value.(string) { + include = false + } + case "repoUrl": + if entry.Repository.URL != value.(string) { + include = false + } + case "serverDetail.id": + if entry.ID != value.(string) { + include = false + } + case "version": + if entry.VersionDetail.Version != value.(string) { + include = false } + // Add more filter options as needed } } diff --git a/internal/database/mongo.go b/internal/database/mongo.go index ef0eb779..a553d0fa 100644 --- a/internal/database/mongo.go +++ b/internal/database/mongo.go @@ -2,6 +2,7 @@ package database import ( "context" + "errors" "fmt" "log" "time" @@ -41,15 +42,15 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName // Create indexes for better query performance models := []mongo.IndexModel{ { - Keys: bson.D{{Key: "name", Value: 1}}, + Keys: bson.D{bson.E{Key: "name", Value: 1}}, }, { - Keys: bson.D{{Key: "id", Value: 1}}, + Keys: bson.D{bson.E{Key: "id", Value: 1}}, Options: options.Index().SetUnique(true), }, // add an index for the combination of name and version { - Keys: bson.D{{Key: "name", Value: 1}, {Key: "versiondetail.version", Value: 1}}, + Keys: bson.D{bson.E{Key: "name", Value: 1}, bson.E{Key: "versiondetail.version", Value: 1}}, Options: options.Index().SetUnique(true), }, } @@ -57,11 +58,11 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName _, err = collection.Indexes().CreateMany(ctx, models) if err != nil { // Mongo will error if the index already exists, we can ignore this and continue. - if err.(mongo.CommandError).Code != 86 { + var commandError mongo.CommandError + if errors.As(err, &commandError) && commandError.Code != 86 { return nil, err - } else { - log.Printf("Indexes already exists, skipping.") } + log.Printf("Indexes already exists, skipping.") } return &MongoDB{ @@ -72,7 +73,12 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName } // List retrieves MCPRegistry entries with optional filtering and pagination -func (db *MongoDB) List(ctx context.Context, filter map[string]interface{}, cursor string, limit int) ([]*model.Server, string, error) { +func (db *MongoDB) List( + ctx context.Context, + filter map[string]interface{}, + cursor string, + limit int, +) ([]*model.Server, string, error) { if limit <= 0 { // Set default limit if not provided limit = 10 @@ -113,11 +119,10 @@ func (db *MongoDB) List(ctx context.Context, filter map[string]interface{}, curs var cursorDoc model.Server err := db.collection.FindOne(ctx, bson.M{"id": cursor}).Decode(&cursorDoc) if err != nil { - if err == mongo.ErrNoDocuments { - // If cursor document not found, start from beginning - } else { + if !errors.Is(err, mongo.ErrNoDocuments) { return nil, "", err } + // If cursor document not found, start from beginning } else { // Use the cursor document's ID to paginate (records with ID > cursor's ID) mongoFilter["id"] = bson.M{"$gt": cursor} @@ -168,7 +173,7 @@ func (db *MongoDB) GetByID(ctx context.Context, id string) (*model.ServerDetail, var entry model.ServerDetail err := db.collection.FindOne(ctx, filter).Decode(&entry) if err != nil { - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return nil, ErrNotFound } return nil, fmt.Errorf("error retrieving entry: %w", err) @@ -191,7 +196,7 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail var existingEntry model.ServerDetail err := db.collection.FindOne(ctx, filter).Decode(&existingEntry) - if err != nil && err != mongo.ErrNoDocuments { + if err != nil && !errors.Is(err, mongo.ErrNoDocuments) { return fmt.Errorf("error checking existing entry: %w", err) } @@ -215,11 +220,13 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail // update the existing entry to not be the latest version if existingEntry.ID != "" { - _, err = db.collection.UpdateOne(ctx, bson.M{"id": existingEntry.ID}, bson.M{"$set": bson.M{"versiondetail.islatest": false}}) + _, err = db.collection.UpdateOne( + ctx, + bson.M{"id": existingEntry.ID}, + bson.M{"$set": bson.M{"versiondetail.islatest": false}}) if err != nil { return fmt.Errorf("error updating existing entry: %w", err) } - } return nil diff --git a/internal/service/fake_service.go b/internal/service/fake_service.go index ba9c1a16..07aa805d 100644 --- a/internal/service/fake_service.go +++ b/internal/service/fake_service.go @@ -15,6 +15,8 @@ type fakeRegistryService struct { } // NewFakeRegistryService creates a new fake registry service with pre-populated data +// +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewFakeRegistryService() RegistryService { // Sample registry entries with updated model structure registries := []*model.Server{ diff --git a/internal/service/registry_service.go b/internal/service/registry_service.go index 409d21c0..d9798be3 100644 --- a/internal/service/registry_service.go +++ b/internal/service/registry_service.go @@ -14,6 +14,8 @@ type registryServiceImpl struct { } // NewRegistryServiceWithDB creates a new registry service with the provided database +// +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewRegistryServiceWithDB(db database.Database) RegistryService { return ®istryServiceImpl{ db: db, diff --git a/tools/publisher/main.go b/tools/publisher/main.go index 9ed544ed..dcde2296 100644 --- a/tools/publisher/main.go +++ b/tools/publisher/main.go @@ -2,10 +2,12 @@ package main import ( "bytes" + "context" "encoding/json" "flag" "fmt" "io" + "log" "net/http" "os" "strings" @@ -13,11 +15,10 @@ import ( ) const ( - tokenFilePath = ".mcpregistry_token" - + tokenFilePath = ".mcpregistry_token" // #nosec:G101 // GitHub OAuth URLs - GitHubDeviceCodeURL = "https://github.com/login/device/code" - GitHubAccessTokenURL = "https://github.com/login/oauth/access_token" + GitHubDeviceCodeURL = "https://github.com/login/device/code" // #nosec:G101 + GitHubAccessTokenURL = "https://github.com/login/oauth/access_token" // #nosec:G101 ) // DeviceCodeResponse represents the response from GitHub's device code endpoint @@ -39,7 +40,7 @@ type AccessTokenResponse struct { type ServerHealthResponse struct { Status string `json:"status"` - GitHubClientId string `json:"github_client_id"` + GitHubClientID string `json:"github_client_id"` } func main() { @@ -62,29 +63,36 @@ func main() { // get the clientID from the server's health endpoint healthURL := registryURL + "/v0/health" - resp, err := http.Get(healthURL) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, healthURL, nil) + if err != nil { + log.Printf("Error creating request: %s\n", err.Error()) + return + } + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { - fmt.Printf("Error fetching health endpoint: %s\n", err.Error()) + log.Printf("Error fetching health endpoint: %s\n", err.Error()) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - fmt.Printf("Health endpoint returned status %d: %s\n", resp.StatusCode, body) + log.Printf("Health endpoint returned status %d: %s\n", resp.StatusCode, body) return } var healthResponse ServerHealthResponse err = json.NewDecoder(resp.Body).Decode(&healthResponse) if err != nil { - fmt.Printf("Error decoding health response: %s\n", err.Error()) + log.Printf("Error decoding health response: %s\n", err.Error()) return } - if healthResponse.GitHubClientId == "" { - fmt.Println("GitHub Client ID is not set in the server's health response.") + if healthResponse.GitHubClientID == "" { + log.Println("GitHub Client ID is not set in the server's health response.") return } - githubClientID := healthResponse.GitHubClientId + githubClientID := healthResponse.GitHubClientID var token string @@ -97,7 +105,7 @@ func main() { if forceLogin || os.IsNotExist(statErr) { err := performDeviceFlowLogin(githubClientID) if err != nil { - fmt.Printf("Failed to perform device flow login: %s\n", err.Error()) + log.Printf("Failed to perform device flow login: %s\n", err.Error()) return } } @@ -106,7 +114,7 @@ func main() { var err error token, err = readToken() if err != nil { - fmt.Printf("Error reading token: %s\n", err.Error()) + log.Printf("Error reading token: %s\n", err.Error()) return } } @@ -114,22 +122,21 @@ func main() { // Read MCP file mcpData, err := os.ReadFile(mcpFilePath) if err != nil { - fmt.Printf("Error reading MCP file: %s\n", err.Error()) + log.Printf("Error reading MCP file: %s\n", err.Error()) return } // Publish to registry err = publishToRegistry(registryURL, mcpData, token) if err != nil { - fmt.Printf("Failed to publish to registry: %s\n", err.Error()) + log.Printf("Failed to publish to registry: %s\n", err.Error()) return } - fmt.Println("Successfully published to registry!") + log.Println("Successfully published to registry!") } func performDeviceFlowLogin(githubClientID string) error { - if githubClientID == "" { return fmt.Errorf("GitHub Client ID is required for device flow login") } @@ -142,13 +149,13 @@ func performDeviceFlowLogin(githubClientID string) error { } // Display instructions to the user - fmt.Println("\nTo authenticate, please:") - fmt.Println("1. Go to:", verificationURI) - fmt.Println("2. Enter code:", userCode) - fmt.Println("3. Authorize this application") + log.Println("\nTo authenticate, please:") + log.Println("1. Go to:", verificationURI) + log.Println("2. Enter code:", userCode) + log.Println("3. Authorize this application") // Poll for the token - fmt.Println("Waiting for authorization...") + log.Println("Waiting for authorization...") token, err := pollForToken(deviceCode, githubClientID) if err != nil { return fmt.Errorf("error polling for token: %w", err) @@ -160,7 +167,7 @@ func performDeviceFlowLogin(githubClientID string) error { return fmt.Errorf("error saving token: %w", err) } - fmt.Println("Successfully authenticated!") + log.Println("Successfully authenticated!") return nil } @@ -180,7 +187,7 @@ func requestDeviceCode(githubClientID string) (string, string, string, error) { return "", "", "", err } - req, err := http.NewRequest("POST", GitHubDeviceCodeURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, GitHubDeviceCodeURL, bytes.NewBuffer(jsonData)) if err != nil { return "", "", "", err } @@ -235,7 +242,7 @@ func pollForToken(deviceCode, githubClientID string) (string, error) { deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) for time.Now().Before(deadline) { - req, err := http.NewRequest("POST", GitHubAccessTokenURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, GitHubAccessTokenURL, bytes.NewBuffer(jsonData)) if err != nil { return "", err } @@ -262,7 +269,6 @@ func pollForToken(deviceCode, githubClientID string) (string, error) { if tokenResp.Error == "authorization_pending" { // User hasn't authorized yet, wait and retry - fmt.Print(".") time.Sleep(time.Duration(interval) * time.Second) continue } @@ -272,7 +278,6 @@ func pollForToken(deviceCode, githubClientID string) (string, error) { } if tokenResp.AccessToken != "" { - fmt.Println() // Add newline after dots return tokenResp.AccessToken, nil } @@ -322,7 +327,7 @@ func publishToRegistry(registryURL string, mcpData []byte, token string) error { publishURL := registryURL + "v0/publish" // Create and send the request - req, err := http.NewRequest("POST", publishURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, publishURL, bytes.NewBuffer(jsonData)) if err != nil { return fmt.Errorf("error creating request: %w", err) } @@ -346,6 +351,6 @@ func publishToRegistry(registryURL string, mcpData []byte, token string) error { return fmt.Errorf("publication failed with status %d: %s", resp.StatusCode, body) } - println(string(body)) + log.Println(string(body)) return nil } From fc92f07dd65ac0061850a56dba566d65841ae6d3 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 19:03:09 -0400 Subject: [PATCH 12/15] feat(script): add usage instructions and server startup methods to test script --- scripts/test_endpoints.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/scripts/test_endpoints.sh b/scripts/test_endpoints.sh index c5b7c038..2efd2f3b 100755 --- a/scripts/test_endpoints.sh +++ b/scripts/test_endpoints.sh @@ -2,6 +2,17 @@ set -e +echo "==================================================" +echo "MCP Registry Endpoint Test Script" +echo "==================================================" +echo "This script expects the MCP Registry server to be running locally." +echo "Please ensure the server is started using one of the following methods:" +echo " • Docker Compose: docker compose up" +echo " • Direct execution: go run cmd/registry/main.go" +echo " • Built binary: ./build/registry" +echo "==================================================" +echo "" + # Default values HOST="http://localhost:8080" ENDPOINT="all" From 434311a596161ab767c3ab45c50b1cb44b21d604 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 19:29:45 -0400 Subject: [PATCH 13/15] feat(script): update publish test script --- scripts/test_publish.sh | 179 +++++++++++++++++++++++++--------------- 1 file changed, 112 insertions(+), 67 deletions(-) diff --git a/scripts/test_publish.sh b/scripts/test_publish.sh index 9425a09c..4a373675 100755 --- a/scripts/test_publish.sh +++ b/scripts/test_publish.sh @@ -2,6 +2,20 @@ set -e +echo "==================================================" +echo "MCP Registry Publish Endpoint Test Script" +echo "==================================================" +echo "This script expects the MCP Registry server to be running locally." +echo "Please ensure the server is started using one of the following methods:" +echo " • Docker Compose: docker compose up" +echo " • Direct execution: go run cmd/registry/main.go" +echo " • Built binary: ./build/registry" +echo "" +echo "REQUIRED: Set the BEARER_TOKEN environment variable with a valid GitHub token" +echo "Example: export BEARER_TOKEN=your_github_token_here" +echo "==================================================" +echo "" + # Default values HOST="http://localhost:8080" VERBOSE=false @@ -13,9 +27,20 @@ function show_usage { echo " -h, --host Base URL of the MCP Registry service (default: http://localhost:8080)" echo " -v, --verbose Show verbose output including full request payload" echo " --help Show this help message" + echo "" + echo "Environment Variables:" + echo " BEARER_TOKEN Required: GitHub token for authentication" exit 1 } +# Check if bearer token is set +if [[ -z "$BEARER_TOKEN" ]]; then + echo "Error: BEARER_TOKEN environment variable is not set." + echo "Please set your GitHub token as an environment variable:" + echo " export BEARER_TOKEN=your_github_token_here" + exit 1 +fi + # Check if jq is installed if ! command -v jq &> /dev/null; then echo "Error: jq is required but not installed." @@ -27,19 +52,14 @@ if ! command -v jq &> /dev/null; then fi # Check if the API is running -echo "Checking if the API is running at $HOST..." +echo "Checking if the MCP Registry API is running at $HOST..." health_check=$(curl -s -o /dev/null -w "%{http_code}" "$HOST/v0/health" 2>/dev/null) if [[ "$health_check" != "200" ]]; then - echo "Warning: API might not be running at $HOST (health check returned $health_check)" - echo "Do you want to continue anyway? (y/n)" - read -r proceed - if [[ ! "$proceed" =~ ^[Yy]$ ]]; then - echo "Exiting. Please start the API and try again." - exit 1 - fi - echo "Continuing as requested..." + echo "Error: MCP Registry API is not running at $HOST (health check returned $health_check)" + echo "Please start the server using one of the methods mentioned above and try again." + exit 1 else - echo "API is running at $HOST" + echo "✓ MCP Registry API is running at $HOST" fi # Parse command line arguments @@ -56,46 +76,51 @@ done # Create a temporary file for our JSON payload PAYLOAD_FILE=$(mktemp) -# Create sample server detail payload +# Create sample server detail payload based on current model structure cat > "$PAYLOAD_FILE" << EOF { - "name": "Test MCP Server", - "description": "A test server for MCP Registry", - "version_detail": { - "version": "1.0.2", - "release_date": "$(date -u +"%Y-%m-%dT%H:%M:%SZ")", - "is_latest": true - }, + "name": "io.github.example/test-mcp-server", + "description": "A test server for MCP Registry validation - published at $(date)", "repository": { "url": "https://github.com/example/test-mcp-server", - "branch": "main" + "source": "github", + "id": "example/test-mcp-server" }, - "registries": [ - { - "name": "npm", - "package_name": "test-mcp-server", - "license": "MIT", - "command_arguments": { - "sub_commands": [ - { - "name": "start", - "description": "Start the server" - } - ], - "environment_variables": [ - { - "name": "PORT", - "description": "Port to run the server on", - "required": false - } - ] - } - } - ], - "remotes": [ + "version_detail": { + "version": "1.0.$(date +%s)" + }, + "packages": [ { - "transport_type": "http", - "url": "http://example.com/api" + "registry_name": "npm", + "name": "test-mcp-server", + "version": "1.0.$(date +%s)", + "runtime_hint": "node", + "runtime_arguments": [ + { + "type": "positional", + "name": "config", + "description": "Configuration file path", + "format": "file_path", + "is_required": false, + "default": "./config.json" + } + ], + "environment_variables": [ + { + "name": "PORT", + "description": "Port to run the server on", + "format": "number", + "is_required": false, + "default": "3000" + }, + { + "name": "API_KEY", + "description": "API key for external service", + "format": "string", + "is_required": true, + "is_secret": true + } + ] } ] } @@ -110,17 +135,16 @@ fi # Test publish endpoint echo "Testing publish endpoint: $HOST/v0/publish" +echo "Using Bearer Token: ${BEARER_TOKEN:0:10}..." # Show only first 10 chars for security + # Get response and status code in a single request response_file=$(mktemp) headers_file=$(mktemp) -# Get token for authentication (or use dummy token for testing) -AUTH_TOKEN=${AUTH_TOKEN:-"test_token"} - # Execute curl with response body to file and headers+status to another file curl -s -X POST \ -H "Content-Type: application/json" \ - -H "Authorization: Bearer ${AUTH_TOKEN}" \ + -H "Authorization: Bearer ${BEARER_TOKEN}" \ -d "@$PAYLOAD_FILE" \ -D "$headers_file" \ -o "$response_file" \ @@ -143,33 +167,54 @@ if [[ "${status_code:0:1}" == "2" ]]; then echo "Response:" echo "$http_response" | jq '.' 2>/dev/null || echo "$http_response" - # Extract the server ID from the response - server_id=$(echo "$http_response" | jq -r '.id') + # Check for server added message and extract UUID + message=$(echo "$http_response" | jq -r '.message // empty' 2>/dev/null) + server_id=$(echo "$http_response" | jq -r '.id // .server_id // empty' 2>/dev/null) + + # Validate the response contains success indicators + success_indicators=0 - echo "Publish successful with ID: $server_id" + if [[ ! -z "$message" && "$message" != "null" ]]; then + echo "✓ Success message received: $message" + if [[ "$message" == *"server"* && ("$message" == *"added"* || "$message" == *"published"* || "$message" == *"created"*) ]]; then + ((success_indicators++)) + echo "✓ Message indicates server was successfully added" + fi + fi + + if [[ ! -z "$server_id" && "$server_id" != "null" && "$server_id" != "empty" ]]; then + echo "✓ Server UUID received: $server_id" + # Validate UUID format (basic check for UUID pattern) + if [[ "$server_id" =~ ^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$ ]]; then + ((success_indicators++)) + echo "✓ Server ID appears to be a valid UUID format" + else + echo "⚠ Server ID format may not be a standard UUID: $server_id" + ((success_indicators++)) # Still count as success if we got an ID + fi + fi - # If we got a valid ID, verify it was actually created by calling the servers endpoint - if [[ ! -z "$server_id" && "$server_id" != "null" ]]; then - echo "-------------------------------------" - echo "Verifying server was published by checking servers endpoint..." - verify_response=$(curl -s "$HOST/v0/servers/$server_id") - echo "Response from servers endpoint:" - echo "$verify_response" | jq '.' 2>/dev/null || echo "$verify_response" - echo "-------------------------------------" - echo "Server verification response:" - echo "$verify_response" | jq '.' 2>/dev/null || echo "$verify_response" - echo "Server verification successful" + if [[ $success_indicators -ge 2 ]]; then + echo "" + echo "🎉 PUBLISH TEST PASSED!" + echo " ✓ Server successfully published with ID: $server_id" + echo " ✓ Success message: $message" else - echo "Error: No valid server ID returned from publish response" - echo "Response:" - echo "$http_response" | jq '.' 2>/dev/null || echo "$http_response" + echo "" + echo "❌ PUBLISH TEST FAILED!" + echo " Expected: Success message about server being added AND a server UUID" + echo " Received: message='$message', id='$server_id'" exit 1 fi else - echo "Response:" + echo "" + echo "❌ PUBLISH TEST FAILED!" + echo " Expected: 2xx status code" + echo " Received: $status_code" + echo " Response:" echo "$http_response" | jq '.' 2>/dev/null || echo "$http_response" - echo "Publish failed" + exit 1 fi echo "-------------------------------------" From 5faebbd58593385a21ee17c79c76ebab2b584984 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 21:05:25 -0400 Subject: [PATCH 14/15] feat(publish): escape server name to prevent HTML injection attacks --- internal/api/handlers/v0/publish.go | 5 +- internal/api/handlers/v0/publish_test.go | 74 ++++++++++++++++++++++++ internal/auth/github.go | 1 + 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/internal/api/handlers/v0/publish.go b/internal/api/handlers/v0/publish.go index dfc57cdc..cbbe04a9 100644 --- a/internal/api/handlers/v0/publish.go +++ b/internal/api/handlers/v0/publish.go @@ -12,6 +12,7 @@ import ( "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/model" "github.com/modelcontextprotocol/registry/internal/service" + "golang.org/x/net/html" ) // PublishHandler handles requests to publish new server details to the registry @@ -83,11 +84,13 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) authMethod = model.AuthMethodNone } + serverName := html.EscapeString(serverDetail.Name) + // Setup authentication info a := model.Authentication{ Method: authMethod, Token: token, - RepoRef: serverDetail.Name, + RepoRef: serverName, } valid, err := authService.ValidateAuth(r.Context(), a) diff --git a/internal/api/handlers/v0/publish_test.go b/internal/api/handlers/v0/publish_test.go index 214579a0..641e730a 100644 --- a/internal/api/handlers/v0/publish_test.go +++ b/internal/api/handlers/v0/publish_test.go @@ -283,6 +283,80 @@ func TestPublishHandler(t *testing.T) { expectedStatus: http.StatusInternalServerError, expectedError: "Failed to publish server details:", }, + { + name: "HTML injection attack in name field", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id-html", + Name: "io.github.malicious/test-server", + Description: "A test server with HTML injection attempt", + Repository: model.Repository{ + URL: "https://github.com/malicious/test-server", + Source: "github", + ID: "malicious/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer github_token_123", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + // The auth service should receive the escaped HTML version of the name + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + // Verify that the RepoRef contains escaped HTML, not the raw script tag + return auth.Method == model.AuthMethodGitHub && + auth.Token == "github_token_123" && + auth.RepoRef == "io.github.malicious/<script>alert('XSS')</script>test-server" + })).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id-html", + }, + }, + { + name: "HTML injection attack in name field with non-GitHub prefix", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id-html-non-github", + Name: "malicious.com/test-server", + Description: "A test server with HTML injection attempt (non-GitHub)", + Repository: model.Repository{ + URL: "https://malicious.com/test-server", + Source: "custom", + ID: "malicious/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer some_token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + // The auth service should receive the escaped HTML version of the name with AuthMethodNone + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + // Verify that the RepoRef contains escaped HTML, not the raw script tag + return auth.Method == model.AuthMethodNone && + auth.Token == "some_token" && + auth.RepoRef == "malicious.com/<script>alert('XSS')</script>test-server" + })).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id-html-non-github", + }, + }, } for _, tc := range testCases { diff --git a/internal/auth/github.go b/internal/auth/github.go index ac813037..fc57ae1c 100644 --- a/internal/auth/github.go +++ b/internal/auth/github.go @@ -225,6 +225,7 @@ func (g *GitHubDeviceAuth) checkOrgMembership(ctx context.Context, token, userna // GitHub API endpoint: GET /orgs/{org}/members/{username} // true if status code is 204 No Content // false if status code is 404 Not Found + url := fmt.Sprint("https://api.github.com/orgs/", org, "/members/", username) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { From bc5a5c31460dc1a5013e3e401eb2a1be859cace2 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Tue, 27 May 2025 21:11:32 -0400 Subject: [PATCH 15/15] go mod tidy --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 73f43120..29c67e56 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/swaggo/files v1.0.1 github.com/swaggo/http-swagger v1.3.4 go.mongodb.org/mongo-driver v1.17.3 + golang.org/x/net v0.39.0 ) require ( @@ -31,7 +32,6 @@ require ( github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/net v0.39.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/text v0.24.0 // indirect golang.org/x/tools v0.32.0 // indirect