diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..23b85e33 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,176 @@ +name: CI Pipeline + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +env: + GO_VERSION: '1.23.x' + +jobs: + # Lint and Format Check + lint: + name: Lint and Format + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Install golangci-lint + run: | + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.61.0 + + - name: Run golangci-lint + run: golangci-lint run --timeout=5m + + - name: Check Go formatting + run: | + if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then + echo "The following files are not properly formatted:" + gofmt -s -l . + exit 1 + fi + + # Build check + build: + name: Build Check + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Build application + run: | + go build -v ./cmd/... + + - name: Check for vulnerabilities + run: | + go install golang.org/x/vuln/cmd/govulncheck@latest + govulncheck ./... + + # Unit Tests + unit-tests: + name: Unit Tests + runs-on: ubuntu-latest + needs: [lint, build] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Run unit tests + run: | + go test -v -race -coverprofile=coverage.out -covermode=atomic ./internal/... + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.out + flags: unittests + name: codecov-unit + fail_ci_if_error: false + + # Integration Tests + integration-tests: + name: Integration Tests + runs-on: ubuntu-latest + needs: [lint, build] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Run integration tests + run: | + chmod +x ./integrationtests/run_tests.sh + ./integrationtests/run_tests.sh + + - name: Run integration tests with coverage + run: | + go test -v -race -coverprofile=integration-coverage.out -covermode=atomic ./integrationtests/... + + - name: Upload integration coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./integration-coverage.out + flags: integrationtests + name: codecov-integration + fail_ci_if_error: false + + # Overall status check + test-summary: + name: Test Summary + runs-on: ubuntu-latest + needs: [unit-tests, integration-tests] + if: always() + steps: + - name: Check test results + run: | + if [[ "${{ needs.unit-tests.result }}" == "success" && "${{ needs.integration-tests.result }}" == "success" ]]; then + echo "✅ All tests passed!" + exit 0 + else + echo "❌ Some tests failed:" + echo " Unit tests: ${{ needs.unit-tests.result }}" + echo " Integration tests: ${{ needs.integration-tests.result }}" + exit 1 + fi diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 00000000..3b013dcf --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,74 @@ +name: Integration Tests + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + integration-tests: + name: Run Integration Tests + runs-on: ubuntu-latest + + strategy: + matrix: + go-version: ['1.23.x'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Set up test environment + run: | + # Create any necessary directories for test data + mkdir -p /tmp/test-data + + - name: Run integration tests + run: | + # Run integration tests using the existing script + chmod +x ./integrationtests/run_tests.sh + ./integrationtests/run_tests.sh + + - name: Run integration tests with coverage + run: | + # Also run integration tests with Go test for coverage + go test -v -race -coverprofile=integration-coverage.out -covermode=atomic ./integrationtests/... + + - name: Generate integration test coverage report + run: go tool cover -html=integration-coverage.out -o integration-coverage.html + + - name: Upload integration coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./integration-coverage.out + flags: integrationtests + name: codecov-integration + fail_ci_if_error: false + + - name: Upload integration coverage artifact + uses: actions/upload-artifact@v4 + with: + name: integration-coverage-report + path: integration-coverage.html diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 00000000..18d0f5c4 --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,63 @@ +name: Unit Tests + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + unit-tests: + name: Run Unit Tests + runs-on: ubuntu-latest + + strategy: + matrix: + go-version: ['1.23.x'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Run unit tests + run: | + # Run unit tests with coverage, excluding integration tests + go test -v -race -coverprofile=coverage.out -covermode=atomic ./internal/... + + - name: Generate coverage report + run: go tool cover -html=coverage.out -o coverage.html + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.out + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false + + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.html diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..699cc571 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,122 @@ +# GolangCI-Lint configuration +# See: https://golangci-lint.run/usage/configuration/ + +run: + timeout: 5m + modules-download-mode: readonly + +linters: + enable: + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - typecheck + - unused + - asasalint + - asciicheck + - bidichk + - bodyclose + - containedctx + - contextcheck + - cyclop + - dupl + - durationcheck + - errname + - errorlint + - exhaustive + - forbidigo + - funlen + - gci + - gocognit + - goconst + - gocritic + - gocyclo + - godox + - gofmt + - goimports + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - grouper + - importas + - ireturn + - lll + - makezero + - misspell + - nakedret + - nestif + - nilerr + - nilnil + - noctx + - nolintlint + - nosprintfhostport + - predeclared + - promlinter + - reassign + - revive + - rowserrcheck + - sqlclosecheck + - stylecheck + - tenv + - testpackage + - thelper + - tparallel + - unconvert + - unparam + - usestdlibvars + - wastedassign + - whitespace + +linters-settings: + cyclop: + max-complexity: 50 + funlen: + lines: 150 + statements: 150 + gocognit: + min-complexity: 50 + gocyclo: + min-complexity: 25 + goconst: + min-len: 3 + min-occurrences: 3 + mnd: + checks: + - argument + - case + - condition + - operation + - return + lll: + line-length: 150 + misspell: + locale: US + nestif: + min-complexity: 8 + +issues: + exclude-rules: + # Exclude some linters from running on tests files. + - path: _test\.go + linters: + - mnd + - funlen + - gocyclo + - errcheck + - dupl + - gosec + # Ignore long lines in generated code + - path: docs/ + linters: + - lll + # Ignore magic numbers in test files + - path: integrationtests/ + linters: + - mnd + # Allow local replacement directives in go.mod + - path: go\.mod + linters: + - gomoddirectives diff --git a/cmd/registry/main.go b/cmd/registry/main.go index a2e9716c..7042b515 100644 --- a/cmd/registry/main.go +++ b/cmd/registry/main.go @@ -2,8 +2,8 @@ package main import ( "context" + "errors" "flag" - "fmt" "log" "net/http" "os" @@ -25,9 +25,9 @@ func main() { // Show version information if requested if *showVersion { - fmt.Printf("MCP Registry v%s\n", Version) - fmt.Printf("Git commit: %s\n", GitCommit) - fmt.Printf("Build time: %s\n", BuildTime) + log.Printf("MCP Registry v%s\n", Version) + log.Printf("Git commit: %s\n", GitCommit) + log.Printf("Build time: %s\n", BuildTime) return } @@ -47,7 +47,8 @@ func main() { // Connect to MongoDB mongoDB, err := database.NewMongoDB(ctx, cfg.DatabaseURL, cfg.DatabaseName, cfg.CollectionName) if err != nil { - log.Fatalf("Failed to connect to MongoDB: %v", err) + log.Printf("Failed to connect to MongoDB: %v", err) + return } // Create registry service with MongoDB @@ -66,7 +67,9 @@ func main() { if cfg.SeedImport { log.Println("Importing data...") - database.ImportSeedFile(mongoDB, cfg.SeedFilePath) + if err := database.ImportSeedFile(mongoDB, cfg.SeedFilePath); err != nil { + log.Printf("Failed to import seed file: %v", err) + } log.Println("Data import completed successfully") } @@ -78,8 +81,9 @@ func main() { // Start server in a goroutine so it doesn't block signal handling go func() { - if err := server.Start(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to start server: %v", err) + if err := server.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Printf("Failed to start server: %v", err) + os.Exit(1) } }() @@ -96,7 +100,7 @@ func main() { // Gracefully shutdown the server if err := server.Shutdown(sctx); err != nil { - log.Fatalf("Server forced to shutdown: %v", err) + log.Printf("Server forced to shutdown: %v", err) } log.Println("Server exiting") diff --git a/go.mod b/go.mod index 8a59528d..29c67e56 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,16 @@ go 1.23.0 require ( github.com/caarlos0/env/v11 v11.3.1 github.com/google/uuid v1.6.0 + github.com/stretchr/testify v1.10.0 github.com/swaggo/files v1.0.1 github.com/swaggo/http-swagger v1.3.4 go.mongodb.org/mongo-driver v1.17.3 + golang.org/x/net v0.39.0 ) require ( github.com/KyleBanks/depth v1.2.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-openapi/jsonpointer v0.21.1 // indirect github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/spec v0.21.0 // indirect @@ -21,13 +24,14 @@ require ( github.com/klauspost/compress v1.16.7 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/montanaflynn/stats v0.7.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/swaggo/swag v1.16.4 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/net v0.39.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/text v0.24.0 // indirect golang.org/x/tools v0.32.0 // indirect diff --git a/go.sum b/go.sum index daa00519..5c7a5a65 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= diff --git a/integrationtests/README.md b/integrationtests/README.md new file mode 100644 index 00000000..3fa2fd5f --- /dev/null +++ b/integrationtests/README.md @@ -0,0 +1,98 @@ +# Integration Tests + +This directory contains integration tests for the MCP Registry API using the fake service implementation. + +## Overview + +The integration tests are designed to test the complete flow of the publish endpoint using real service implementations (fake service) rather than mocks. This provides confidence that the entire request/response cycle works correctly. + +## Test Structure + +### `publish_integration_test.go` + +Contains comprehensive integration tests for the publish endpoint: + +- **TestPublishIntegration**: Tests various scenarios for publishing servers + - Successful publish with GitHub authentication + - Successful publish without authentication (for non-GitHub servers) + - Error cases: missing name, missing version, missing auth header, invalid JSON, unsupported HTTP methods + - Duplicate package handling: fails when same name+version, succeeds with different versions + +- **TestPublishIntegrationWithComplexPackages**: Tests publishing servers with complex package configurations + - Multiple runtime arguments (named and positional) + - Package arguments + - Environment variables (including secrets) + - Multiple remotes with different transport types + - Headers for HTTP remotes + +- **TestPublishIntegrationEndToEnd**: Tests the complete end-to-end flow + - Publishes a server and verifies it can be retrieved + - Checks that the server appears in the registry list + - Verifies count consistency + +## Mock Services + +### MockAuthService + +A simple mock implementation of the `auth.Service` interface that: +- Accepts any non-empty token for GitHub authentication +- Always allows authentication for `AuthMethodNone` +- Provides realistic responses for auth flow methods + +## Running the Tests + +From the project root directory: + +```bash +# Run all integration tests +go test ./integrationtests/... + +# Run with verbose output +go test -v ./integrationtests/... + +# Run a specific test +go test -v ./integrationtests/ -run TestPublishIntegration + +# Run tests with race detection +go test -race ./integrationtests/... + +# Use the convenient test runner script +./integrationtests/run_tests.sh +``` + +## Test Data + +The tests use the fake service which comes pre-populated with sample data: +- 3 sample MCP servers with different configurations +- Uses in-memory database for isolation between tests +- Each test creates unique server instances with UUIDs + +## Benefits of Integration Tests + +1. **Real Flow Testing**: Tests the actual HTTP request/response cycle +2. **Service Integration**: Validates that handlers work correctly with service implementations +3. **Data Persistence**: Verifies that published data can be retrieved +4. **Error Handling**: Tests complete error scenarios end-to-end +5. **Complex Scenarios**: Tests realistic server configurations with packages and remotes + +## Dependencies + +These tests use: +- `testify/assert` and `testify/require` for assertions +- `httptest` for HTTP testing utilities +- The fake service implementation for realistic data operations +- Standard Go testing package + +## Test Coverage + +The integration tests cover: +- ✅ Successful publish scenarios +- ✅ Authentication validation +- ✅ Input validation +- ✅ Duplicate package handling +- ✅ Complex package configurations +- ✅ Multiple remotes +- ✅ Error handling +- ✅ End-to-end data flow +- ✅ HTTP method validation +- ✅ JSON parsing errors diff --git a/integrationtests/publish_integration_test.go b/integrationtests/publish_integration_test.go new file mode 100644 index 00000000..639859ca --- /dev/null +++ b/integrationtests/publish_integration_test.go @@ -0,0 +1,727 @@ +package integrationtests_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/modelcontextprotocol/registry/internal/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockAuthService implements a simple auth service for testing +type MockAuthService struct{} + +func (m *MockAuthService) StartAuthFlow( + _ context.Context, _ model.AuthMethod, _ string, +) (map[string]string, string, error) { + return map[string]string{ + "device_code": "mock_device_code", + "user_code": "ABCD-1234", + "verification_uri": "https://github.com/login/device", + }, "mock_status_token", nil +} + +func (m *MockAuthService) CheckAuthStatus(_ context.Context, statusToken string) (string, error) { + if statusToken == "mock_status_token" { + return "mock_access_token", nil + } + return "", fmt.Errorf("invalid status token") +} + +func (m *MockAuthService) ValidateAuth(_ context.Context, authentication model.Authentication) (bool, error) { + // Simple validation: for testing purposes, accept any non-empty token + switch authentication.Method { + case model.AuthMethodGitHub: + return authentication.Token != "", nil + case model.AuthMethodNone: + return true, nil + default: + return false, auth.ErrUnsupportedAuthMethod + } +} + +// TestPublishIntegration tests the complete flow of publishing a server using the fake service +func TestPublishIntegration(t *testing.T) { + // Setup fake service and auth service + registryService := service.NewFakeRegistryService() + authService := &MockAuthService{} + + // Create the publish handler + handler := v0.PublishHandler(registryService, authService) + + t.Run("successful publish with GitHub auth", func(t *testing.T) { + publishReq := model.PublishRequest{ + ServerDetail: model.ServerDetail{ + Server: model.Server{ + Name: "io.github.testuser/test-mcp-server", + Description: "A test MCP server for integration testing", + Repository: model.Repository{ + URL: "https://github.com/testuser/test-mcp-server", + Source: "github", + ID: "testuser/test-mcp-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, + }, + Packages: []model.Package{ + { + RegistryName: "npm", + Name: "test-mcp-server", + Version: "1.0.0", + RunTimeHint: "node", + RuntimeArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "config", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Configuration file path", + Format: model.FormatFilePath, + IsRequired: true, + }, + }, + }, + }, + }, + }, + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "http://localhost:3000/mcp", + }, + }, + }, + } + + // Marshal the server detail to JSON + jsonData, err := json.Marshal(publishReq) + require.NoError(t, err) + + // Create a request + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_test_token_123") + + // Create a response recorder + recorder := httptest.NewRecorder() + + // Call the handler + handler(recorder, req) + + // Check the response + assert.Equal(t, http.StatusCreated, recorder.Code) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "Server publication successful", response["message"]) + assert.NotEmpty(t, response["id"], "Server ID should be generated") + + // Verify the server was actually published by retrieving it + publishedServer, err := registryService.GetByID(response["id"]) + require.NoError(t, err) + assert.Equal(t, publishReq.ServerDetail.Name, publishedServer.Name) + assert.Equal(t, publishReq.ServerDetail.Description, publishedServer.Description) + assert.Equal(t, publishReq.ServerDetail.VersionDetail.Version, publishedServer.VersionDetail.Version) + assert.Len(t, publishedServer.Packages, 1) + assert.Len(t, publishedServer.Remotes, 1) + }) + + t.Run("successful publish without auth (no prefix)", func(t *testing.T) { + publishReq := &model.PublishRequest{ + ServerDetail: model.ServerDetail{ + Server: model.Server{ + Name: "custom-mcp-server", + Description: "A custom MCP server without auth", + Repository: model.Repository{ + URL: "https://example.com/custom-server", + Source: "custom", + ID: "custom/custom-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + }, + }, + }, + } + + jsonData, err := json.Marshal(publishReq) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "dummy_token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "Server publication successful", response["message"]) + assert.NotEmpty(t, response["id"], "Server ID should be generated") + }) + + t.Run("publish fails with missing name", func(t *testing.T) { + publishReq := &model.PublishRequest{ + ServerDetail: model.ServerDetail{ + Server: model.Server{ + Name: "", // Missing name + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, + }, + }, + } + + jsonData, err := json.Marshal(publishReq) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Name is required") + }) + + t.Run("publish fails with missing version", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "", // Missing version + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Version is required") + }) + + t.Run("publish fails with missing authorization header", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + // No Authorization header + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Authorization header is required") + }) + + t.Run("publish fails with invalid JSON", func(t *testing.T) { + invalidJSON := `{"name": "test", "version": ` + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBufferString(invalidJSON)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Invalid") + }) + + t.Run("publish fails with unsupported HTTP method", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/publish", nil) + req.Header.Set("Authorization", "Bearer token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Method not allowed") + }) + + t.Run("publish fails with duplicate name and version", func(t *testing.T) { + // First, publish a server successfully + firstServerDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.duplicate/test-server", + Description: "First server for duplicate test", + Repository: model.Repository{ + URL: "https://github.com/duplicate/test-server", + Source: "github", + ID: "duplicate/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, + }, + } + + jsonData, err := json.Marshal(firstServerDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_token_first") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, http.StatusCreated, recorder.Code, "First publish should succeed") + + firstServerDetail.ID = response["id"] // Store the ID for later verification + + // Now try to publish another server with the same name and version + duplicateServerDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.duplicate/test-server", // Same name + Description: "Duplicate server attempt", + Repository: model.Repository{ + URL: "https://github.com/duplicate/test-server-fork", + Source: "github", + ID: "duplicate/test-server-fork", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", // Same version + }, + }, + } + + duplicateJSONData, err := json.Marshal(duplicateServerDetail) + require.NoError(t, err) + + duplicateReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(duplicateJSONData)) + duplicateReq.Header.Set("Content-Type", "application/json") + duplicateReq.Header.Set("Authorization", "Bearer github_token_duplicate") + + duplicateRecorder := httptest.NewRecorder() + handler(duplicateRecorder, duplicateReq) + + // The duplicate should fail + assert.Equal(t, http.StatusBadRequest, duplicateRecorder.Code) + assert.Contains(t, duplicateRecorder.Body.String(), "Failed to publish server details") + + // Verify that only the first server was actually stored + retrievedServer, err := registryService.GetByID(firstServerDetail.ID) + require.NoError(t, err) + assert.Equal(t, firstServerDetail.Name, retrievedServer.Name) + assert.Equal(t, firstServerDetail.Description, retrievedServer.Description) + }) + + t.Run("publish succeeds with same name but different version", func(t *testing.T) { + // Publish first version + firstVersionDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.versioned/test-server", + Description: "First version of the server", + Repository: model.Repository{ + URL: "https://github.com/versioned/test-server", + Source: "github", + ID: "versioned/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, + }, + } + + jsonData, err := json.Marshal(firstVersionDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_token_v1") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + firstVersionDetail.ID = response["id"] // Store the ID for later verification + + assert.Equal(t, http.StatusCreated, recorder.Code, "First version should succeed") + require.NotEmpty(t, firstVersionDetail.ID, "Server ID should be generated") + + // Publish second version with same name but different version + secondVersionDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.versioned/test-server", // Same name + Description: "Second version of the server", + Repository: model.Repository{ + URL: "https://github.com/versioned/test-server", + Source: "github", + ID: "versioned/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", // Different version + }, + }, + } + + secondJSONData, err := json.Marshal(secondVersionDetail) + require.NoError(t, err) + + secondReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(secondJSONData)) + secondReq.Header.Set("Content-Type", "application/json") + secondReq.Header.Set("Authorization", "Bearer github_token_v2") + + secondRecorder := httptest.NewRecorder() + handler(secondRecorder, secondReq) + + var secondResponse map[string]string + err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondResponse) + require.NoError(t, err) + secondVersionDetail.ID = secondResponse["id"] // Store the ID for later verification + + // The second version should succeed + assert.Equal(t, http.StatusCreated, secondRecorder.Code) + require.NotEmpty(t, secondVersionDetail.ID, "Server ID for second version should be generated") + + // Verify both versions exist + firstRetrieved, err := registryService.GetByID(firstVersionDetail.ID) + require.NoError(t, err) + assert.Equal(t, "1.0.0", firstRetrieved.VersionDetail.Version) + + secondRetrieved, err := registryService.GetByID(secondVersionDetail.ID) + require.NoError(t, err) + assert.Equal(t, "2.0.0", secondRetrieved.VersionDetail.Version) + }) + + t.Run("publish fails when trying to publish older version after newer version", func(t *testing.T) { + // First, publish a newer version (2.0.0) + newerVersionDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.versioning/version-order-test", + Description: "Newer version published first", + Repository: model.Repository{ + URL: "https://github.com/versioning/version-order-test", + Source: "github", + ID: "versioning/version-order-test", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + }, + }, + } + + jsonData, err := json.Marshal(newerVersionDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_token_newer") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + newerVersionDetail.ID = response["id"] // Store the ID for later verification + + assert.Equal(t, http.StatusCreated, recorder.Code, "Newer version should be published successfully") + require.NotEmpty(t, newerVersionDetail.ID, "Server ID for newer version should be generated") + + // Now try to publish an older version (1.0.0) of the same package + olderVersionDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.versioning/version-order-test", // Same name + Description: "Older version published after newer", + Repository: model.Repository{ + URL: "https://github.com/versioning/version-order-test", + Source: "github", + ID: "versioning/version-order-test", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", // Older version + }, + }, + } + + olderJSONData, err := json.Marshal(olderVersionDetail) + require.NoError(t, err) + + olderReq := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(olderJSONData)) + olderReq.Header.Set("Content-Type", "application/json") + olderReq.Header.Set("Authorization", "Bearer github_token_older") + + olderRecorder := httptest.NewRecorder() + handler(olderRecorder, olderReq) + + // This should fail - we shouldn't allow publishing older versions after newer ones + assert.Equal(t, http.StatusBadRequest, olderRecorder.Code, "Publishing older version should fail") + assert.Contains(t, olderRecorder.Body.String(), "version", "Error message should mention version") + + // Verify that only the newer version exists + newerRetrieved, err := registryService.GetByID(newerVersionDetail.ID) + require.NoError(t, err) + assert.Equal(t, "2.0.0", newerRetrieved.VersionDetail.Version) + + // Verify the older version was not stored + _, err = registryService.GetByID(olderVersionDetail.ID) + assert.Error(t, err, "Older version should not have been stored") + }) +} + +// TestPublishIntegrationWithComplexPackages tests publishing with complex package configurations +func TestPublishIntegrationWithComplexPackages(t *testing.T) { + registryService := service.NewFakeRegistryService() + authService := &MockAuthService{} + handler := v0.PublishHandler(registryService, authService) + + t.Run("publish with complex package configuration", func(t *testing.T) { + serverDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.complex/advanced-mcp-server", + Description: "An advanced MCP server with complex configuration", + Repository: model.Repository{ + URL: "https://github.com/complex/advanced-mcp-server", + Source: "github", + ID: "complex/advanced-mcp-server", + }, + VersionDetail: model.VersionDetail{ + Version: "2.1.0", + }, + }, + Packages: []model.Package{ + { + RegistryName: "npm", + Name: "@example/advanced-mcp-server", + Version: "43.1.0", + RunTimeHint: "node", + RuntimeArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "experimental-modules", + }, + { + Type: model.ArgumentTypeNamed, + Name: "config", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Main configuration file", + Format: model.FormatFilePath, + IsRequired: true, + Default: "./config.json", + }, + }, + }, + { + Type: model.ArgumentTypePositional, + Name: "mode", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Operation mode", + Format: model.FormatString, + IsRequired: false, + Default: "production", + Choices: []string{"development", "staging", "production"}, + }, + }, + }, + }, + PackageArguments: []model.Argument{ + { + Type: model.ArgumentTypeNamed, + Name: "install-deps", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Install dependencies", + Format: model.FormatBoolean, + Default: "true", + }, + }, + }, + }, + EnvironmentVariables: []model.KeyValueInput{ + { + Name: "LOG_LEVEL", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "Logging level", + Format: model.FormatString, + Default: "info", + Choices: []string{"debug", "info", "warn", "error"}, + }, + }, + }, + { + Name: "API_KEY", + InputWithVariables: model.InputWithVariables{ + Input: model.Input{ + Description: "API key for external service", + Format: model.FormatString, + IsRequired: true, + IsSecret: true, + }, + }, + }, + }, + }, + }, + Remotes: []model.Remote{ + { + TransportType: "http", + URL: "http://localhost:8080/mcp", + Headers: []model.Input{ + { + Description: "API Version Header", + Format: model.FormatString, + Value: "v1", + }, + }, + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_complex_token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + assert.Equal(t, http.StatusCreated, recorder.Code) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + + serverDetail.ID = response["id"] // Store the ID for later verification + assert.Equal(t, "Server publication successful", response["message"]) + assert.NotEmpty(t, response["id"], "Server ID should be generated") + + // Verify the complex server was published correctly + publishedServer, err := registryService.GetByID(serverDetail.ID) + require.NoError(t, err) + + // Verify package details + require.Len(t, publishedServer.Packages, 1) + pkg := publishedServer.Packages[0] + assert.Equal(t, "npm", pkg.RegistryName) + assert.Equal(t, "@example/advanced-mcp-server", pkg.Name) + assert.Len(t, pkg.RuntimeArguments, 3) + assert.Len(t, pkg.PackageArguments, 1) + assert.Len(t, pkg.EnvironmentVariables, 2) + + // Verify remotes + require.Len(t, publishedServer.Remotes, 1) + assert.Equal(t, "http", publishedServer.Remotes[0].TransportType) + assert.Len(t, publishedServer.Remotes[0].Headers, 1) + }) +} + +// TestPublishIntegrationEndToEnd tests the complete end-to-end flow +func TestPublishIntegrationEndToEnd(t *testing.T) { + registryService := service.NewFakeRegistryService() + authService := &MockAuthService{} + handler := v0.PublishHandler(registryService, authService) + + t.Run("end-to-end publish and retrieve flow", func(t *testing.T) { + // Step 1: Get initial count of servers + initialServers, _, err := registryService.List("", 100) + require.NoError(t, err) + initialCount := len(initialServers) + + // Step 2: Publish a new server + serverDetail := &model.ServerDetail{ + Server: model.Server{ + Name: "io.github.e2e/end-to-end-server", + Description: "End-to-end test server", + Repository: model.Repository{ + URL: "https://github.com/e2e/end-to-end-server", + Source: "github", + ID: "e2e/end-to-end-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + }, + }, + } + + jsonData, err := json.Marshal(serverDetail) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v0/publish", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer github_e2e_token") + + recorder := httptest.NewRecorder() + handler(recorder, req) + + var response map[string]string + err = json.Unmarshal(recorder.Body.Bytes(), &response) + require.NoError(t, err) + serverDetail.ID = response["id"] // Store the ID for later verification + + require.Equal(t, http.StatusCreated, recorder.Code) + + // Step 3: Verify the count increased + updatedServers, _, err := registryService.List("", 100) + require.NoError(t, err) + assert.Equal(t, initialCount+1, len(updatedServers)) + + // Step 4: Verify the server can be retrieved by ID + retrievedServer, err := registryService.GetByID(serverDetail.ID) + require.NoError(t, err) + assert.Equal(t, serverDetail.Name, retrievedServer.Name) + assert.Equal(t, serverDetail.Description, retrievedServer.Description) + + // Step 5: Verify the server appears in the list + found := false + for _, server := range updatedServers { + if server.ID == serverDetail.ID { + found = true + assert.Equal(t, serverDetail.Name, server.Name) + break + } + } + assert.True(t, found, "Published server should appear in the list") + }) +} diff --git a/integrationtests/run_tests.sh b/integrationtests/run_tests.sh new file mode 100755 index 00000000..d035bb71 --- /dev/null +++ b/integrationtests/run_tests.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Integration Test Runner for MCP Registry +# This script runs the integration tests for the publish functionality + +echo "Running MCP Registry Integration Tests..." +echo "========================================" + +# Change to the project directory (parent of integrationtests) +cd "$(dirname "$0")/.." + +# Run integration tests with verbose output +echo "Running publish integration tests..." +go test -v ./integrationtests/... + +# Check exit code +if [ $? -eq 0 ]; then + echo "" + echo "✅ All integration tests passed!" +else + echo "" + echo "❌ Some integration tests failed!" + exit 1 +fi diff --git a/internal/api/handlers/v0/auth.go b/internal/api/handlers/v0/auth.go index bc3f7b33..38156d18 100644 --- a/internal/api/handlers/v0/auth.go +++ b/internal/api/handlers/v0/auth.go @@ -64,11 +64,14 @@ func StartAuthHandler(authService auth.Service) http.HandlerFunc { // Return successful response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "flow_info": flowInfo, "status_token": statusToken, "expires_in": 300, // 5 minutes - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } @@ -95,9 +98,12 @@ func CheckAuthStatusHandler(authService auth.Service) http.HandlerFunc { // Auth is still pending w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "status": "pending", - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } return } @@ -109,9 +115,12 @@ func CheckAuthStatusHandler(authService auth.Service) http.HandlerFunc { // Authentication completed successfully w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "status": "complete", "token": token, - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } diff --git a/internal/api/handlers/v0/health.go b/internal/api/handlers/v0/health.go index 07d31b6d..3dd78924 100644 --- a/internal/api/handlers/v0/health.go +++ b/internal/api/handlers/v0/health.go @@ -10,16 +10,18 @@ import ( type HealthResponse struct { Status string `json:"status"` - GitHubClientId string `json:"github_client_id"` + GitHubClientID string `json:"github_client_id"` } // HealthHandler returns a handler for health check endpoint func HealthHandler(cfg *config.Config) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(HealthResponse{ + if err := json.NewEncoder(w).Encode(HealthResponse{ Status: "ok", - GitHubClientId: cfg.GithubClientID, - }) + GitHubClientID: cfg.GithubClientID, + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } } } diff --git a/internal/api/handlers/v0/health_test.go b/internal/api/handlers/v0/health_test.go new file mode 100644 index 00000000..baae604e --- /dev/null +++ b/internal/api/handlers/v0/health_test.go @@ -0,0 +1,122 @@ +package v0_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestHealthHandler(t *testing.T) { + // Test cases + testCases := []struct { + name string + config *config.Config + expectedStatus int + expectedBody v0.HealthResponse + }{ + { + name: "returns health status with github client id", + config: &config.Config{ + GithubClientID: "test-github-client-id", + }, + expectedStatus: http.StatusOK, + expectedBody: v0.HealthResponse{ + Status: "ok", + GitHubClientID: "test-github-client-id", + }, + }, + { + name: "works with empty github client id", + config: &config.Config{ + GithubClientID: "", + }, + expectedStatus: http.StatusOK, + expectedBody: v0.HealthResponse{ + Status: "ok", + GitHubClientID: "", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create handler with the test config + handler := v0.HealthHandler(tc.config) + + // Create request + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/health", nil) + if err != nil { + t.Fatal(err) + } + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + // Check content type + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse response body + var resp v0.HealthResponse + err = json.NewDecoder(rr.Body).Decode(&resp) + assert.NoError(t, err) + + // Check the response body + assert.Equal(t, tc.expectedBody, resp) + }) + } +} + +// TestHealthHandlerIntegration tests the handler with actual HTTP requests +func TestHealthHandlerIntegration(t *testing.T) { + // Create test server + cfg := &config.Config{ + GithubClientID: "integration-test-client-id", + } + + server := httptest.NewServer(v0.HealthHandler(cfg)) + defer server.Close() + + // Send request to the test server + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check content type + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Parse response body + var healthResp v0.HealthResponse + err = json.NewDecoder(resp.Body).Decode(&healthResp) + assert.NoError(t, err) + + // Check the response body + expectedResp := v0.HealthResponse{ + Status: "ok", + GitHubClientID: "integration-test-client-id", + } + assert.Equal(t, expectedResp, healthResp) +} diff --git a/internal/api/handlers/v0/ping.go b/internal/api/handlers/v0/ping.go index a77d622f..6e9b0bc0 100644 --- a/internal/api/handlers/v0/ping.go +++ b/internal/api/handlers/v0/ping.go @@ -22,6 +22,8 @@ func PingHandler(cfg *config.Config) http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } } } diff --git a/internal/api/handlers/v0/publish.go b/internal/api/handlers/v0/publish.go index 9e999d5a..cbbe04a9 100644 --- a/internal/api/handlers/v0/publish.go +++ b/internal/api/handlers/v0/publish.go @@ -3,13 +3,16 @@ package v0 import ( "encoding/json" + "errors" "io" "net/http" "strings" "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/model" "github.com/modelcontextprotocol/registry/internal/service" + "golang.org/x/net/html" ) // PublishHandler handles requests to publish new server details to the registry @@ -81,16 +84,18 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) authMethod = model.AuthMethodNone } + serverName := html.EscapeString(serverDetail.Name) + // Setup authentication info a := model.Authentication{ Method: authMethod, Token: token, - RepoRef: serverDetail.Name, + RepoRef: serverName, } valid, err := authService.ValidateAuth(r.Context(), a) if err != nil { - if err == auth.ErrAuthRequired { + if errors.Is(err, auth.ErrAuthRequired) { http.Error(w, "Authentication is required for publishing", http.StatusUnauthorized) return } @@ -106,15 +111,23 @@ func PublishHandler(registry service.RegistryService, authService auth.Service) // Call the publish method on the registry service err = registry.Publish(&serverDetail) if err != nil { + // Check for specific error types and return appropriate HTTP status codes + if errors.Is(err, database.ErrInvalidVersion) || errors.Is(err, database.ErrAlreadyExists) { + http.Error(w, "Failed to publish server details: "+err.Error(), http.StatusBadRequest) + return + } http.Error(w, "Failed to publish server details: "+err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(map[string]string{ + if err := json.NewEncoder(w).Encode(map[string]string{ "message": "Server publication successful", "id": serverDetail.ID, - }) + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } diff --git a/internal/api/handlers/v0/publish_test.go b/internal/api/handlers/v0/publish_test.go new file mode 100644 index 00000000..641e730a --- /dev/null +++ b/internal/api/handlers/v0/publish_test.go @@ -0,0 +1,556 @@ +package v0_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockRegistryService is a mock implementation of the RegistryService interface +type MockRegistryService struct { + mock.Mock +} + +func (m *MockRegistryService) List(cursor string, limit int) ([]model.Server, string, error) { + args := m.Mock.Called(cursor, limit) + return args.Get(0).([]model.Server), args.String(1), args.Error(2) +} + +func (m *MockRegistryService) GetByID(id string) (*model.ServerDetail, error) { + args := m.Mock.Called(id) + return args.Get(0).(*model.ServerDetail), args.Error(1) +} + +func (m *MockRegistryService) Publish(serverDetail *model.ServerDetail) error { + args := m.Mock.Called(serverDetail) + return args.Error(0) +} + +// MockAuthService is a mock implementation of the auth.Service interface +type MockAuthService struct { + mock.Mock +} + +func (m *MockAuthService) StartAuthFlow( + ctx context.Context, method model.AuthMethod, repoRef string, +) (map[string]string, string, error) { + args := m.Mock.Called(ctx, method, repoRef) + return args.Get(0).(map[string]string), args.String(1), args.Error(2) +} + +func (m *MockAuthService) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { + args := m.Mock.Called(ctx, statusToken) + return args.String(0), args.Error(1) +} + +func (m *MockAuthService) ValidateAuth(ctx context.Context, authentication model.Authentication) (bool, error) { + args := m.Mock.Called(ctx, authentication) + return args.Bool(0), args.Error(1) +} + +func TestPublishHandler(t *testing.T) { + testCases := []struct { + name string + method string + requestBody interface{} + authHeader string + setupMocks func(*MockRegistryService, *MockAuthService) + expectedStatus int + expectedResponse map[string]string + expectedError string + }{ + { + name: "successful publish with GitHub auth", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "io.github.example/test-server", + Description: "A test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server", + Source: "github", + ID: "example/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer github_token_123", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, model.Authentication{ + Method: model.AuthMethodGitHub, + Token: "github_token_123", + RepoRef: "io.github.example/test-server", + }).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id", + }, + }, + { + name: "successful publish with no auth (AuthMethodNone)", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id-2", + Name: "example/test-server", + Description: "A test server without auth", + Repository: model.Repository{ + URL: "https://example.com/test-server", + Source: "example", + ID: "example/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer some_token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, model.Authentication{ + Method: model.AuthMethodNone, + Token: "some_token", + RepoRef: "example/test-server", + }).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id-2", + }, + }, + { + name: "method not allowed", + method: http.MethodGet, + requestBody: nil, + authHeader: "", + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, + expectedStatus: http.StatusMethodNotAllowed, + expectedError: "Method not allowed", + }, + { + name: "missing request body", + method: http.MethodPost, + requestBody: "", + authHeader: "", + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid request payload:", + }, + { + name: "missing server name", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "", // Missing name + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "", + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Name is required", + }, + { + name: "missing version", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "", // Missing version + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "", + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Version is required", + }, + { + name: "missing authorization header", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "", // Missing auth header + setupMocks: func(_ *MockRegistryService, _ *MockAuthService) {}, + expectedStatus: http.StatusUnauthorized, + expectedError: "Authorization header is required", + }, + { + name: "authentication required error", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer token", + setupMocks: func(_ *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, auth.ErrAuthRequired) + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Authentication is required for publishing", + }, + { + name: "authentication failed", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer invalid_token", + setupMocks: func(_ *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.Anything).Return(false, nil) + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Invalid authentication credentials", + }, + { + name: "registry service error", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.Anything).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(assert.AnError) + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "Failed to publish server details:", + }, + { + name: "HTML injection attack in name field", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id-html", + Name: "io.github.malicious/test-server", + Description: "A test server with HTML injection attempt", + Repository: model.Repository{ + URL: "https://github.com/malicious/test-server", + Source: "github", + ID: "malicious/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer github_token_123", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + // The auth service should receive the escaped HTML version of the name + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + // Verify that the RepoRef contains escaped HTML, not the raw script tag + return auth.Method == model.AuthMethodGitHub && + auth.Token == "github_token_123" && + auth.RepoRef == "io.github.malicious/<script>alert('XSS')</script>test-server" + })).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id-html", + }, + }, + { + name: "HTML injection attack in name field with non-GitHub prefix", + method: http.MethodPost, + requestBody: model.ServerDetail{ + Server: model.Server{ + ID: "test-id-html-non-github", + Name: "malicious.com/test-server", + Description: "A test server with HTML injection attempt (non-GitHub)", + Repository: model.Repository{ + URL: "https://malicious.com/test-server", + Source: "custom", + ID: "malicious/test-server", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + }, + authHeader: "Bearer some_token", + setupMocks: func(registry *MockRegistryService, authSvc *MockAuthService) { + // The auth service should receive the escaped HTML version of the name with AuthMethodNone + authSvc.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + // Verify that the RepoRef contains escaped HTML, not the raw script tag + return auth.Method == model.AuthMethodNone && + auth.Token == "some_token" && + auth.RepoRef == "malicious.com/<script>alert('XSS')</script>test-server" + })).Return(true, nil) + registry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + }, + expectedStatus: http.StatusCreated, + expectedResponse: map[string]string{ + "message": "Server publication successful", + "id": "test-id-html-non-github", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create mocks + mockRegistry := new(MockRegistryService) + mockAuthService := new(MockAuthService) + + // Setup mocks + tc.setupMocks(mockRegistry, mockAuthService) + + // Create handler + handler := v0.PublishHandler(mockRegistry, mockAuthService) + + // Prepare request body + var requestBody []byte + if tc.requestBody != nil { + var err error + requestBody, err = json.Marshal(tc.requestBody) + assert.NoError(t, err) + } + + // Create request + req, err := http.NewRequestWithContext(context.Background(), tc.method, "/publish", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + + // Set auth header if provided + if tc.authHeader != "" { + req.Header.Set("Authorization", tc.authHeader) + } + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedResponse != nil { + // Check content type for successful responses + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse and verify response body + var response map[string]string + err = json.NewDecoder(rr.Body).Decode(&response) + assert.NoError(t, err) + assert.Equal(t, tc.expectedResponse, response) + } + + if tc.expectedError != "" { + // Check that the error message is contained in the response + assert.Contains(t, rr.Body.String(), tc.expectedError) + } + + // Assert that all expectations were met + mockRegistry.Mock.AssertExpectations(t) + mockAuthService.Mock.AssertExpectations(t) + }) + } +} + +func TestPublishHandlerBearerTokenParsing(t *testing.T) { + testCases := []struct { + name string + authHeader string + expectedToken string + }{ + { + name: "bearer token with Bearer prefix", + authHeader: "Bearer github_token_123", + expectedToken: "github_token_123", + }, + { + name: "bearer token with bearer prefix (lowercase)", + authHeader: "bearer github_token_123", + expectedToken: "github_token_123", + }, + { + name: "token without Bearer prefix", + authHeader: "github_token_123", + expectedToken: "github_token_123", + }, + { + name: "mixed case Bearer prefix", + authHeader: "BeArEr github_token_123", + expectedToken: "github_token_123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockRegistry := new(MockRegistryService) + mockAuthService := new(MockAuthService) + + // Setup mock to capture the actual token passed + mockAuthService.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + return auth.Token == tc.expectedToken + })).Return(true, nil) + mockRegistry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + + handler := v0.PublishHandler(mockRegistry, mockAuthService) + + serverDetail := model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: "test-server", + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + } + + requestBody, err := json.Marshal(serverDetail) + assert.NoError(t, err) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + req.Header.Set("Authorization", tc.authHeader) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusCreated, rr.Code) + mockAuthService.Mock.AssertExpectations(t) + }) + } +} + +func TestPublishHandlerAuthMethodSelection(t *testing.T) { + testCases := []struct { + name string + serverName string + expectedAuthMethod model.AuthMethod + }{ + { + name: "GitHub prefix triggers GitHub auth", + serverName: "io.github.example/test-server", + expectedAuthMethod: model.AuthMethodGitHub, + }, + { + name: "non-GitHub prefix uses no auth", + serverName: "example.com/test-server", + expectedAuthMethod: model.AuthMethodNone, + }, + { + name: "empty prefix uses no auth", + serverName: "test-server", + expectedAuthMethod: model.AuthMethodNone, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockRegistry := new(MockRegistryService) + mockAuthService := new(MockAuthService) + + // Setup mock to capture the auth method + mockAuthService.Mock.On("ValidateAuth", mock.Anything, mock.MatchedBy(func(auth model.Authentication) bool { + return auth.Method == tc.expectedAuthMethod + })).Return(true, nil) + mockRegistry.Mock.On("Publish", mock.AnythingOfType("*model.ServerDetail")).Return(nil) + + handler := v0.PublishHandler(mockRegistry, mockAuthService) + + serverDetail := model.ServerDetail{ + Server: model.Server{ + ID: "test-id", + Name: tc.serverName, + Description: "A test server", + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + } + + requestBody, err := json.Marshal(serverDetail) + assert.NoError(t, err) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/publish", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + req.Header.Set("Authorization", "Bearer test_token") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusCreated, rr.Code) + mockAuthService.Mock.AssertExpectations(t) + }) + } +} diff --git a/internal/api/handlers/v0/servers.go b/internal/api/handlers/v0/servers.go index a9818820..b2fc21f6 100644 --- a/internal/api/handlers/v0/servers.go +++ b/internal/api/handlers/v0/servers.go @@ -89,7 +89,10 @@ func ServersHandler(registry service.RegistryService) http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } @@ -123,6 +126,9 @@ func ServersDetailHandler(registry service.RegistryService) http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(serverDetail) + if err := json.NewEncoder(w).Encode(serverDetail); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } } diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go new file mode 100644 index 00000000..380a3a6f --- /dev/null +++ b/internal/api/handlers/v0/servers_test.go @@ -0,0 +1,391 @@ +package v0_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestServersHandler(t *testing.T) { + testCases := []struct { + name string + method string + queryParams string + setupMocks func(*MockRegistryService) + expectedStatus int + expectedServers []model.Server + expectedMeta *v0.Metadata + expectedError string + }{ + { + name: "successful list with default parameters", + method: http.MethodGet, + setupMocks: func(registry *MockRegistryService) { + servers := []model.Server{ + { + ID: "550e8400-e29b-41d4-a716-446655440001", + Name: "test-server-1", + Description: "First test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-1", + Source: "github", + ID: "example/test-server-1", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + { + ID: "550e8400-e29b-41d4-a716-446655440002", + Name: "test-server-2", + Description: "Second test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-2", + Source: "github", + ID: "example/test-server-2", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: "2025-05-26T00:00:00Z", + IsLatest: true, + }, + }, + } + registry.Mock.On("List", "", 30).Return(servers, "", nil) + }, + expectedStatus: http.StatusOK, + expectedServers: []model.Server{ + { + ID: "550e8400-e29b-41d4-a716-446655440001", + Name: "test-server-1", + Description: "First test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-1", + Source: "github", + ID: "example/test-server-1", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-25T00:00:00Z", + IsLatest: true, + }, + }, + { + ID: "550e8400-e29b-41d4-a716-446655440002", + Name: "test-server-2", + Description: "Second test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-2", + Source: "github", + ID: "example/test-server-2", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: "2025-05-26T00:00:00Z", + IsLatest: true, + }, + }, + }, + }, + { + name: "successful list with cursor and limit", + method: http.MethodGet, + queryParams: "?cursor=550e8400-e29b-41d4-a716-446655440000" + "&limit=10", + setupMocks: func(registry *MockRegistryService) { + servers := []model.Server{ + { + ID: "550e8400-e29b-41d4-a716-446655440003", + Name: "test-server-3", + Description: "Third test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-3", + Source: "github", + ID: "example/test-server-3", + }, + VersionDetail: model.VersionDetail{ + Version: "1.5.0", + ReleaseDate: "2025-05-27T00:00:00Z", + IsLatest: true, + }, + }, + } + nextCursor := uuid.New().String() + registry.Mock.On("List", mock.AnythingOfType("string"), 10).Return(servers, nextCursor, nil) + }, + expectedStatus: http.StatusOK, + expectedServers: []model.Server{ + { + ID: "550e8400-e29b-41d4-a716-446655440003", + Name: "test-server-3", + Description: "Third test server", + Repository: model.Repository{ + URL: "https://github.com/example/test-server-3", + Source: "github", + ID: "example/test-server-3", + }, + VersionDetail: model.VersionDetail{ + Version: "1.5.0", + ReleaseDate: "2025-05-27T00:00:00Z", + IsLatest: true, + }, + }, + }, + expectedMeta: &v0.Metadata{ + NextCursor: "", // This will be dynamically set in the test + Count: 1, + }, + }, + { + name: "successful list with limit capping at 100", + method: http.MethodGet, + queryParams: "?limit=150", + setupMocks: func(registry *MockRegistryService) { + servers := []model.Server{} + registry.Mock.On("List", "", 100).Return(servers, "", nil) + }, + expectedStatus: http.StatusOK, + expectedServers: []model.Server{}, + }, + { + name: "invalid cursor parameter", + method: http.MethodGet, + queryParams: "?cursor=invalid-uuid", + setupMocks: func(_ *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid cursor parameter", + }, + { + name: "invalid limit parameter - non-numeric", + method: http.MethodGet, + queryParams: "?limit=abc", + setupMocks: func(_ *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid limit parameter", + }, + { + name: "invalid limit parameter - zero", + method: http.MethodGet, + queryParams: "?limit=0", + setupMocks: func(_ *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Limit must be greater than 0", + }, + { + name: "invalid limit parameter - negative", + method: http.MethodGet, + queryParams: "?limit=-5", + setupMocks: func(_ *MockRegistryService) {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Limit must be greater than 0", + }, + { + name: "registry service error", + method: http.MethodGet, + setupMocks: func(registry *MockRegistryService) { + registry.Mock.On("List", "", 30).Return([]model.Server{}, "", errors.New("database connection error")) + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "database connection error", + }, + { + name: "method not allowed", + method: http.MethodPost, + setupMocks: func(_ *MockRegistryService) {}, + expectedStatus: http.StatusMethodNotAllowed, + expectedError: "Method not allowed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create mock registry service + mockRegistry := new(MockRegistryService) + tc.setupMocks(mockRegistry) + + // Create handler + handler := v0.ServersHandler(mockRegistry) + + // Create request + url := "/v0/servers" + tc.queryParams + req, err := http.NewRequestWithContext(context.Background(), tc.method, url, nil) + if err != nil { + t.Fatal(err) + } + + // Create response recorder + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check status code + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + // Check content type + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + // Parse response body + var resp v0.PaginatedResponse + err = json.NewDecoder(rr.Body).Decode(&resp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, tc.expectedServers, resp.Data) + + // Check metadata if expected + if tc.expectedMeta != nil { + assert.Equal(t, tc.expectedMeta.Count, resp.Metadata.Count) + if tc.expectedMeta.NextCursor != "" { + assert.NotEmpty(t, resp.Metadata.NextCursor) + } + } + } else if tc.expectedError != "" { + // Check error message for non-200 responses + assert.Contains(t, rr.Body.String(), tc.expectedError) + } + + // Verify mock expectations + mockRegistry.Mock.AssertExpectations(t) + }) + } +} + +// TestServersHandlerIntegration tests the servers list handler with actual HTTP requests +func TestServersHandlerIntegration(t *testing.T) { + // Create mock registry service + mockRegistry := new(MockRegistryService) + + servers := []model.Server{ + { + ID: "550e8400-e29b-41d4-a716-446655440004", + Name: "integration-test-server", + Description: "Integration test server", + Repository: model.Repository{ + URL: "https://github.com/example/integration-test", + Source: "github", + ID: "example/integration-test", + }, + VersionDetail: model.VersionDetail{ + Version: "1.0.0", + ReleaseDate: "2025-05-27T00:00:00Z", + IsLatest: true, + }, + }, + } + + mockRegistry.Mock.On("List", "", 30).Return(servers, "", nil) + + // Create test server + server := httptest.NewServer(v0.ServersHandler(mockRegistry)) + defer server.Close() + + // Send request to the test server + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check content type + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Parse response body + var paginatedResp v0.PaginatedResponse + err = json.NewDecoder(resp.Body).Decode(&paginatedResp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, servers, paginatedResp.Data) + assert.Empty(t, paginatedResp.Metadata.NextCursor) + + // Verify mock expectations + mockRegistry.Mock.AssertExpectations(t) +} + +// TestServersDetailHandlerIntegration tests the servers detail handler with actual HTTP requests +func TestServersDetailHandlerIntegration(t *testing.T) { + serverID := uuid.New().String() + + // Create mock registry service + mockRegistry := new(MockRegistryService) + + serverDetail := &model.ServerDetail{ + Server: model.Server{ + ID: serverID, + Name: "integration-test-server-detail", + Description: "Integration test server detail", + Repository: model.Repository{ + URL: "https://github.com/example/integration-test-detail", + Source: "github", + ID: "example/integration-test-detail", + }, + VersionDetail: model.VersionDetail{ + Version: "2.0.0", + ReleaseDate: "2025-05-27T12:00:00Z", + IsLatest: true, + }, + }, + } + + mockRegistry.Mock.On("GetByID", serverID).Return(serverDetail, nil) + + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.SetPathValue("id", serverID) + v0.ServersDetailHandler(mockRegistry).ServeHTTP(w, r) + })) + defer server.Close() + + // Send request to the test server + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check content type + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Parse response body + var serverDetailResp model.ServerDetail + err = json.NewDecoder(resp.Body).Decode(&serverDetailResp) + assert.NoError(t, err) + + // Check the response data + assert.Equal(t, *serverDetail, serverDetailResp) + + // Verify mock expectations + mockRegistry.Mock.AssertExpectations(t) +} diff --git a/internal/api/handlers/v0/swagger.go b/internal/api/handlers/v0/swagger.go index fea5b296..6a368f88 100644 --- a/internal/api/handlers/v0/swagger.go +++ b/internal/api/handlers/v0/swagger.go @@ -6,7 +6,7 @@ import ( "os" "path/filepath" - _ "github.com/swaggo/files" + _ "github.com/swaggo/files" // Swagger files needed for embedding httpSwagger "github.com/swaggo/http-swagger" ) diff --git a/internal/api/router/v0.go b/internal/api/router/v0.go index 3564d7e4..6d465f99 100644 --- a/internal/api/router/v0.go +++ b/internal/api/router/v0.go @@ -11,7 +11,9 @@ import ( ) // RegisterV0Routes registers all v0 API routes to the provided router -func RegisterV0Routes(mux *http.ServeMux, cfg *config.Config, registry service.RegistryService, authService auth.Service) { +func RegisterV0Routes( + mux *http.ServeMux, cfg *config.Config, registry service.RegistryService, authService auth.Service, +) { // Register v0 endpoints mux.HandleFunc("/v0/health", v0.HealthHandler(cfg)) mux.HandleFunc("/v0/servers", v0.ServersHandler(registry)) diff --git a/internal/api/server.go b/internal/api/server.go index 53715822..c92a1a54 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -4,6 +4,7 @@ import ( "context" "log" "net/http" + "time" "github.com/modelcontextprotocol/registry/internal/api/router" "github.com/modelcontextprotocol/registry/internal/auth" @@ -31,8 +32,9 @@ func NewServer(cfg *config.Config, registryService service.RegistryService, auth authService: authService, router: mux, server: &http.Server{ - Addr: cfg.ServerAddress, - Handler: mux, + Addr: cfg.ServerAddress, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, }, } diff --git a/internal/auth/github.go b/internal/auth/github.go index ce1daf65..fc57ae1c 100644 --- a/internal/auth/github.go +++ b/internal/auth/github.go @@ -3,6 +3,7 @@ package auth import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -70,14 +71,19 @@ func NewGitHubDeviceAuth(config GitHubOAuthConfig) *GitHubDeviceAuth { // It verifies the token owner matches the repository owner or is a member of the owning organization. // It also verifies that the token was created for the same ClientID used to set up the authentication. // Returns true if valid, false otherwise along with an error explaining the validation failure. -func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (bool, error) { +func (g *GitHubDeviceAuth) ValidateToken(ctx context.Context, token string, requiredRepo string) (bool, error) { // If no repo is required, we can't validate properly if requiredRepo == "" { return false, fmt.Errorf("repository reference is required for token validation") } // First, validate that the token is associated with our ClientID - tokenReq, err := http.NewRequest("GET", "https://api.github.com/applications/"+g.config.ClientID+"/token", nil) + tokenReq, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + "https://api.github.com/applications/"+g.config.ClientID+"/token", + nil, + ) if err != nil { return false, err } @@ -97,7 +103,8 @@ func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (boo } // POST instead of GET for security reasons per GitHub API - tokenReq, err = http.NewRequest("POST", "https://api.github.com/applications/"+g.config.ClientID+"/token", io.NopCloser(bytes.NewReader(checkBody))) + tokenURL := "https://api.github.com/applications/" + g.config.ClientID + "/token" + tokenReq, err = http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, io.NopCloser(bytes.NewReader(checkBody))) if err != nil { return false, err } @@ -135,7 +142,7 @@ func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (boo } // Get the authenticated user - userReq, err := http.NewRequest("GET", "https://api.github.com/user", nil) + userReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil) if err != nil { return false, err } @@ -175,17 +182,20 @@ func (g *GitHubDeviceAuth) ValidateToken(token string, requiredRepo string) (boo // Verify that the authenticated user matches the owner if userInfo.Login != owner { // Check if the user is a member of the organization - isMember, err := g.checkOrgMembership(token, userInfo.Login, owner) + isMember, err := g.checkOrgMembership(ctx, token, userInfo.Login, owner) if err != nil { return false, fmt.Errorf("failed to check org membership: %s", owner) } if !isMember { - return false, fmt.Errorf("token belongs to user %s, but repository is owned by %s and user is not a member of the organization", userInfo.Login, owner) + return false, fmt.Errorf( + "token belongs to user %s, but repository is owned by %s and user is not a member of the organization", + userInfo.Login, owner) } } - // If we've reached this point, the token has access the repo and the user matches the owner or is a member of the owner org + // If we've reached this point, the token has access the repo and the user matches + // the owner or is a member of the owner org return true, nil } @@ -210,13 +220,14 @@ func (g *GitHubDeviceAuth) ExtractGitHubRepo(repoURL string) (owner, repo string } // checkOrgMembership checks if a user is a member of an organization -func (g *GitHubDeviceAuth) checkOrgMembership(token, username, org string) (bool, error) { +func (g *GitHubDeviceAuth) checkOrgMembership(ctx context.Context, token, username, org string) (bool, error) { // Create request to check if user is a member of the organization // GitHub API endpoint: GET /orgs/{org}/members/{username} // true if status code is 204 No Content // false if status code is 404 Not Found + url := fmt.Sprint("https://api.github.com/orgs/", org, "/members/", username) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return false, err } diff --git a/internal/auth/service.go b/internal/auth/service.go index 54a9f01e..d8fdf70a 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -8,37 +8,40 @@ import ( "github.com/modelcontextprotocol/registry/internal/model" ) -// AuthServiceImpl implements the Service interface -type AuthServiceImpl struct { +// ServiceImpl implements the Service interface +type ServiceImpl struct { config *config.Config githubAuth *GitHubDeviceAuth } // NewAuthService creates a new authentication service +// +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewAuthService(cfg *config.Config) Service { githubConfig := GitHubOAuthConfig{ ClientID: cfg.GithubClientID, ClientSecret: cfg.GithubClientSecret, } - return &AuthServiceImpl{ + return &ServiceImpl{ config: cfg, githubAuth: NewGitHubDeviceAuth(githubConfig), } } -func (s *AuthServiceImpl) StartAuthFlow(ctx context.Context, method model.AuthMethod, repoRef string) (map[string]string, string, error) { +func (s *ServiceImpl) StartAuthFlow(_ context.Context, _ model.AuthMethod, + _ string) (map[string]string, string, error) { // return not implemented error return nil, "", fmt.Errorf("not implemented") } -func (s *AuthServiceImpl) CheckAuthStatus(ctx context.Context, statusToken string) (string, error) { +func (s *ServiceImpl) CheckAuthStatus(_ context.Context, _ string) (string, error) { // return not implemented error return "", fmt.Errorf("not implemented") } // ValidateAuth validates authentication credentials -func (s *AuthServiceImpl) ValidateAuth(ctx context.Context, auth model.Authentication) (bool, error) { +func (s *ServiceImpl) ValidateAuth(ctx context.Context, auth model.Authentication) (bool, error) { // If authentication is required but not provided if auth.Method == "" || auth.Method == model.AuthMethodNone { return false, ErrAuthRequired @@ -47,7 +50,9 @@ func (s *AuthServiceImpl) ValidateAuth(ctx context.Context, auth model.Authentic switch auth.Method { case model.AuthMethodGitHub: // Extract repo reference from the repository URL if it's not provided - return s.githubAuth.ValidateToken(auth.Token, auth.RepoRef) + return s.githubAuth.ValidateToken(ctx, auth.Token, auth.RepoRef) + case model.AuthMethodNone: + return false, ErrAuthRequired default: return false, ErrUnsupportedAuthMethod } diff --git a/internal/config/config.go b/internal/config/config.go index 01db852f..cbb5637c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,7 @@ package config import ( - "github.com/caarlos0/env/v11" + env "github.com/caarlos0/env/v11" ) // Config holds the application configuration diff --git a/internal/database/database.go b/internal/database/database.go index 9c457bce..1d5fc4f8 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -9,10 +9,11 @@ import ( // Common database errors var ( - ErrNotFound = errors.New("record not found") - ErrAlreadyExists = errors.New("record already exists") - ErrInvalidInput = errors.New("invalid input") - ErrDatabase = errors.New("database error") + ErrNotFound = errors.New("record not found") + ErrAlreadyExists = errors.New("record already exists") + ErrInvalidInput = errors.New("invalid input") + ErrDatabase = errors.New("database error") + ErrInvalidVersion = errors.New("invalid version: cannot publish older version after newer version") ) // Database defines the interface for database operations on MCPRegistry entries diff --git a/internal/database/import.go b/internal/database/import.go index db81fc09..ee70f107 100644 --- a/internal/database/import.go +++ b/internal/database/import.go @@ -33,7 +33,7 @@ func ImportSeedFile(mongo *MongoDB, seedFilePath string) error { // Read the seed file seedData, err := readSeedFile(seedFilePath) if err != nil { - log.Fatalf("Failed to read seed file: %v", err) + return fmt.Errorf("failed to read seed file: %w", err) } collection := mongo.collection @@ -48,7 +48,7 @@ func readSeedFile(path string) ([]model.ServerDetail, error) { // Read the file content fileContent, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("failed to read file: %v", err) + return nil, fmt.Errorf("failed to read file: %w", err) } // Parse the JSON content @@ -57,9 +57,8 @@ func readSeedFile(path string) ([]model.ServerDetail, error) { // Try parsing as a raw JSON array and then convert to our model var rawData []map[string]interface{} if jsonErr := json.Unmarshal(fileContent, &rawData); jsonErr != nil { - return nil, fmt.Errorf("failed to parse JSON: %v (original error: %v)", jsonErr, err) + return nil, fmt.Errorf("failed to parse JSON: %w (original error: %w)", jsonErr, err) } - } log.Printf("Found %d server entries in seed file", len(servers)) @@ -82,7 +81,6 @@ func importData(ctx context.Context, collection *mongo.Collection, servers []mod server.VersionDetail.Version = "0.0.1-seed" server.VersionDetail.ReleaseDate = time.Now().Format(time.RFC3339) server.VersionDetail.IsLatest = true - } // Create update document update := bson.M{"$set": server} @@ -95,11 +93,12 @@ func importData(ctx context.Context, collection *mongo.Collection, servers []mod continue } - if result.UpsertedCount > 0 { + switch { + case result.UpsertedCount > 0: log.Printf("[%d/%d] Created server: %s", i+1, len(servers), server.Name) - } else if result.ModifiedCount > 0 { + case result.ModifiedCount > 0: log.Printf("[%d/%d] Updated server: %s", i+1, len(servers), server.Name) - } else { + default: log.Printf("[%d/%d] Server already up to date: %s", i+1, len(servers), server.Name) } } diff --git a/internal/database/memory.go b/internal/database/memory.go index 6df02a28..5494b99f 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -3,26 +3,95 @@ package database import ( "context" "sort" + "strconv" + "strings" "sync" + "time" + "github.com/google/uuid" "github.com/modelcontextprotocol/registry/internal/model" ) // MemoryDB is an in-memory implementation of the Database interface type MemoryDB struct { - entries map[string]*model.Server + entries map[string]*model.ServerDetail mu sync.RWMutex } // NewMemoryDB creates a new instance of the in-memory database func NewMemoryDB(e map[string]*model.Server) *MemoryDB { + // Convert Server entries to ServerDetail entries + serverDetails := make(map[string]*model.ServerDetail) + for k, v := range e { + serverDetails[k] = &model.ServerDetail{ + Server: *v, + } + } return &MemoryDB{ - entries: e, + entries: serverDetails, + } +} + +// compareSemanticVersions compares two semantic version strings +// Returns: +// +// -1 if version1 < version2 +// 0 if version1 == version2 +// +1 if version1 > version2 +func compareSemanticVersions(version1, version2 string) int { + // Simple semantic version comparison + // Assumes format: major.minor.patch + + parts1 := strings.Split(version1, ".") + parts2 := strings.Split(version2, ".") + + // Pad with zeros if needed + maxLen := len(parts1) + if len(parts2) > maxLen { + maxLen = len(parts2) + } + + for len(parts1) < maxLen { + parts1 = append(parts1, "0") + } + for len(parts2) < maxLen { + parts2 = append(parts2, "0") + } + + // Compare each part + for i := 0; i < maxLen; i++ { + num1, err1 := strconv.Atoi(parts1[i]) + num2, err2 := strconv.Atoi(parts2[i]) + + // If parsing fails, fall back to string comparison + if err1 != nil || err2 != nil { + if parts1[i] < parts2[i] { + return -1 + } else if parts1[i] > parts2[i] { + return 1 + } + continue + } + + if num1 < num2 { + return -1 + } else if num1 > num2 { + return 1 + } } + + return 0 } // List retrieves all MCPRegistry entries with optional filtering and pagination -func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cursor string, limit int) ([]*model.Server, string, error) { +// +//gocognit:ignore +func (db *MemoryDB) List( + ctx context.Context, + filter map[string]interface{}, + cursor string, + limit int, +) ([]*model.Server, string, error) { if ctx.Err() != nil { return nil, "", ctx.Err() } @@ -37,8 +106,8 @@ func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cur // Convert all entries to a slice for pagination var allEntries []*model.Server for _, entry := range db.entries { - entryCopy := *entry - allEntries = append(allEntries, &entryCopy) + serverCopy := entry.Server + allEntries = append(allEntries, &serverCopy) } // Simple filtering implementation @@ -47,27 +116,25 @@ func (db *MemoryDB) List(ctx context.Context, filter map[string]interface{}, cur include := true // Apply filters if any - if filter != nil { - for key, value := range filter { - switch key { - case "name": - if entry.Name != value.(string) { - include = false - } - case "repoUrl": - if entry.Repository.URL != value.(string) { - include = false - } - case "serverDetail.id": - if entry.ID != value.(string) { - include = false - } - case "version": - if entry.VersionDetail.Version != value.(string) { - include = false - } - // Add more filter options as needed + for key, value := range filter { + switch key { + case "name": + if entry.Name != value.(string) { + include = false + } + case "repoUrl": + if entry.Repository.URL != value.(string) { + include = false + } + case "serverDetail.id": + if entry.ID != value.(string) { + include = false + } + case "version": + if entry.VersionDetail.Version != value.(string) { + include = false } + // Add more filter options as needed } } @@ -124,15 +191,9 @@ func (db *MemoryDB) GetByID(ctx context.Context, id string) (*model.ServerDetail defer db.mu.RUnlock() if entry, exists := db.entries[id]; exists { - return &model.ServerDetail{ - Server: model.Server{ - ID: entry.ID, - Name: entry.Name, - Description: entry.Description, - VersionDetail: entry.VersionDetail, - Repository: entry.Repository, - }, - }, nil + // Return a copy of the ServerDetail + serverDetailCopy := *entry + return &serverDetailCopy, nil } return nil, ErrNotFound @@ -153,24 +214,37 @@ func (db *MemoryDB) Publish(ctx context.Context, serverDetail *model.ServerDetai } // check that the name and the version are unique - + // Also check version ordering - don't allow publishing older versions after newer ones + var latestVersion string for _, entry := range db.entries { - if entry.Name == serverDetail.Name && entry.VersionDetail.Version == serverDetail.VersionDetail.Version { - return ErrAlreadyExists + if entry.Name == serverDetail.Name { + if entry.VersionDetail.Version == serverDetail.VersionDetail.Version { + return ErrAlreadyExists + } + + // Track the latest version for this package name + if latestVersion == "" || compareSemanticVersions(entry.VersionDetail.Version, latestVersion) > 0 { + latestVersion = entry.VersionDetail.Version + } } } + // If we found existing versions, check if the new version is older than the latest + if latestVersion != "" && compareSemanticVersions(serverDetail.VersionDetail.Version, latestVersion) < 0 { + return ErrInvalidVersion + } + if serverDetail.Repository.URL == "" { return ErrInvalidInput } - db.entries[serverDetail.ID] = &model.Server{ - ID: serverDetail.ID, - Name: serverDetail.Name, - Description: serverDetail.Description, - VersionDetail: serverDetail.VersionDetail, - Repository: serverDetail.Repository, - } + // Generate a new ID for the server detail + serverDetail.ID = uuid.New().String() + serverDetail.VersionDetail.IsLatest = true // Assume the new version is the latest + serverDetail.VersionDetail.ReleaseDate = time.Now().Format(time.RFC3339) + // Store a copy of the entire ServerDetail + serverDetailCopy := *serverDetail + db.entries[serverDetail.ID] = &serverDetailCopy return nil } diff --git a/internal/database/mongo.go b/internal/database/mongo.go index ef0eb779..a553d0fa 100644 --- a/internal/database/mongo.go +++ b/internal/database/mongo.go @@ -2,6 +2,7 @@ package database import ( "context" + "errors" "fmt" "log" "time" @@ -41,15 +42,15 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName // Create indexes for better query performance models := []mongo.IndexModel{ { - Keys: bson.D{{Key: "name", Value: 1}}, + Keys: bson.D{bson.E{Key: "name", Value: 1}}, }, { - Keys: bson.D{{Key: "id", Value: 1}}, + Keys: bson.D{bson.E{Key: "id", Value: 1}}, Options: options.Index().SetUnique(true), }, // add an index for the combination of name and version { - Keys: bson.D{{Key: "name", Value: 1}, {Key: "versiondetail.version", Value: 1}}, + Keys: bson.D{bson.E{Key: "name", Value: 1}, bson.E{Key: "versiondetail.version", Value: 1}}, Options: options.Index().SetUnique(true), }, } @@ -57,11 +58,11 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName _, err = collection.Indexes().CreateMany(ctx, models) if err != nil { // Mongo will error if the index already exists, we can ignore this and continue. - if err.(mongo.CommandError).Code != 86 { + var commandError mongo.CommandError + if errors.As(err, &commandError) && commandError.Code != 86 { return nil, err - } else { - log.Printf("Indexes already exists, skipping.") } + log.Printf("Indexes already exists, skipping.") } return &MongoDB{ @@ -72,7 +73,12 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName } // List retrieves MCPRegistry entries with optional filtering and pagination -func (db *MongoDB) List(ctx context.Context, filter map[string]interface{}, cursor string, limit int) ([]*model.Server, string, error) { +func (db *MongoDB) List( + ctx context.Context, + filter map[string]interface{}, + cursor string, + limit int, +) ([]*model.Server, string, error) { if limit <= 0 { // Set default limit if not provided limit = 10 @@ -113,11 +119,10 @@ func (db *MongoDB) List(ctx context.Context, filter map[string]interface{}, curs var cursorDoc model.Server err := db.collection.FindOne(ctx, bson.M{"id": cursor}).Decode(&cursorDoc) if err != nil { - if err == mongo.ErrNoDocuments { - // If cursor document not found, start from beginning - } else { + if !errors.Is(err, mongo.ErrNoDocuments) { return nil, "", err } + // If cursor document not found, start from beginning } else { // Use the cursor document's ID to paginate (records with ID > cursor's ID) mongoFilter["id"] = bson.M{"$gt": cursor} @@ -168,7 +173,7 @@ func (db *MongoDB) GetByID(ctx context.Context, id string) (*model.ServerDetail, var entry model.ServerDetail err := db.collection.FindOne(ctx, filter).Decode(&entry) if err != nil { - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return nil, ErrNotFound } return nil, fmt.Errorf("error retrieving entry: %w", err) @@ -191,7 +196,7 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail var existingEntry model.ServerDetail err := db.collection.FindOne(ctx, filter).Decode(&existingEntry) - if err != nil && err != mongo.ErrNoDocuments { + if err != nil && !errors.Is(err, mongo.ErrNoDocuments) { return fmt.Errorf("error checking existing entry: %w", err) } @@ -215,11 +220,13 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail // update the existing entry to not be the latest version if existingEntry.ID != "" { - _, err = db.collection.UpdateOne(ctx, bson.M{"id": existingEntry.ID}, bson.M{"$set": bson.M{"versiondetail.islatest": false}}) + _, err = db.collection.UpdateOne( + ctx, + bson.M{"id": existingEntry.ID}, + bson.M{"$set": bson.M{"versiondetail.islatest": false}}) if err != nil { return fmt.Errorf("error updating existing entry: %w", err) } - } return nil diff --git a/internal/service/fake_service.go b/internal/service/fake_service.go index ba9c1a16..07aa805d 100644 --- a/internal/service/fake_service.go +++ b/internal/service/fake_service.go @@ -15,6 +15,8 @@ type fakeRegistryService struct { } // NewFakeRegistryService creates a new fake registry service with pre-populated data +// +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewFakeRegistryService() RegistryService { // Sample registry entries with updated model structure registries := []*model.Server{ diff --git a/internal/service/registry_service.go b/internal/service/registry_service.go index 409d21c0..d9798be3 100644 --- a/internal/service/registry_service.go +++ b/internal/service/registry_service.go @@ -14,6 +14,8 @@ type registryServiceImpl struct { } // NewRegistryServiceWithDB creates a new registry service with the provided database +// +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewRegistryServiceWithDB(db database.Database) RegistryService { return ®istryServiceImpl{ db: db, diff --git a/scripts/test_endpoints.sh b/scripts/test_endpoints.sh index c5b7c038..2efd2f3b 100755 --- a/scripts/test_endpoints.sh +++ b/scripts/test_endpoints.sh @@ -2,6 +2,17 @@ set -e +echo "==================================================" +echo "MCP Registry Endpoint Test Script" +echo "==================================================" +echo "This script expects the MCP Registry server to be running locally." +echo "Please ensure the server is started using one of the following methods:" +echo " • Docker Compose: docker compose up" +echo " • Direct execution: go run cmd/registry/main.go" +echo " • Built binary: ./build/registry" +echo "==================================================" +echo "" + # Default values HOST="http://localhost:8080" ENDPOINT="all" diff --git a/scripts/test_publish.sh b/scripts/test_publish.sh index 9425a09c..4a373675 100755 --- a/scripts/test_publish.sh +++ b/scripts/test_publish.sh @@ -2,6 +2,20 @@ set -e +echo "==================================================" +echo "MCP Registry Publish Endpoint Test Script" +echo "==================================================" +echo "This script expects the MCP Registry server to be running locally." +echo "Please ensure the server is started using one of the following methods:" +echo " • Docker Compose: docker compose up" +echo " • Direct execution: go run cmd/registry/main.go" +echo " • Built binary: ./build/registry" +echo "" +echo "REQUIRED: Set the BEARER_TOKEN environment variable with a valid GitHub token" +echo "Example: export BEARER_TOKEN=your_github_token_here" +echo "==================================================" +echo "" + # Default values HOST="http://localhost:8080" VERBOSE=false @@ -13,9 +27,20 @@ function show_usage { echo " -h, --host Base URL of the MCP Registry service (default: http://localhost:8080)" echo " -v, --verbose Show verbose output including full request payload" echo " --help Show this help message" + echo "" + echo "Environment Variables:" + echo " BEARER_TOKEN Required: GitHub token for authentication" exit 1 } +# Check if bearer token is set +if [[ -z "$BEARER_TOKEN" ]]; then + echo "Error: BEARER_TOKEN environment variable is not set." + echo "Please set your GitHub token as an environment variable:" + echo " export BEARER_TOKEN=your_github_token_here" + exit 1 +fi + # Check if jq is installed if ! command -v jq &> /dev/null; then echo "Error: jq is required but not installed." @@ -27,19 +52,14 @@ if ! command -v jq &> /dev/null; then fi # Check if the API is running -echo "Checking if the API is running at $HOST..." +echo "Checking if the MCP Registry API is running at $HOST..." health_check=$(curl -s -o /dev/null -w "%{http_code}" "$HOST/v0/health" 2>/dev/null) if [[ "$health_check" != "200" ]]; then - echo "Warning: API might not be running at $HOST (health check returned $health_check)" - echo "Do you want to continue anyway? (y/n)" - read -r proceed - if [[ ! "$proceed" =~ ^[Yy]$ ]]; then - echo "Exiting. Please start the API and try again." - exit 1 - fi - echo "Continuing as requested..." + echo "Error: MCP Registry API is not running at $HOST (health check returned $health_check)" + echo "Please start the server using one of the methods mentioned above and try again." + exit 1 else - echo "API is running at $HOST" + echo "✓ MCP Registry API is running at $HOST" fi # Parse command line arguments @@ -56,46 +76,51 @@ done # Create a temporary file for our JSON payload PAYLOAD_FILE=$(mktemp) -# Create sample server detail payload +# Create sample server detail payload based on current model structure cat > "$PAYLOAD_FILE" << EOF { - "name": "Test MCP Server", - "description": "A test server for MCP Registry", - "version_detail": { - "version": "1.0.2", - "release_date": "$(date -u +"%Y-%m-%dT%H:%M:%SZ")", - "is_latest": true - }, + "name": "io.github.example/test-mcp-server", + "description": "A test server for MCP Registry validation - published at $(date)", "repository": { "url": "https://github.com/example/test-mcp-server", - "branch": "main" + "source": "github", + "id": "example/test-mcp-server" }, - "registries": [ - { - "name": "npm", - "package_name": "test-mcp-server", - "license": "MIT", - "command_arguments": { - "sub_commands": [ - { - "name": "start", - "description": "Start the server" - } - ], - "environment_variables": [ - { - "name": "PORT", - "description": "Port to run the server on", - "required": false - } - ] - } - } - ], - "remotes": [ + "version_detail": { + "version": "1.0.$(date +%s)" + }, + "packages": [ { - "transport_type": "http", - "url": "http://example.com/api" + "registry_name": "npm", + "name": "test-mcp-server", + "version": "1.0.$(date +%s)", + "runtime_hint": "node", + "runtime_arguments": [ + { + "type": "positional", + "name": "config", + "description": "Configuration file path", + "format": "file_path", + "is_required": false, + "default": "./config.json" + } + ], + "environment_variables": [ + { + "name": "PORT", + "description": "Port to run the server on", + "format": "number", + "is_required": false, + "default": "3000" + }, + { + "name": "API_KEY", + "description": "API key for external service", + "format": "string", + "is_required": true, + "is_secret": true + } + ] } ] } @@ -110,17 +135,16 @@ fi # Test publish endpoint echo "Testing publish endpoint: $HOST/v0/publish" +echo "Using Bearer Token: ${BEARER_TOKEN:0:10}..." # Show only first 10 chars for security + # Get response and status code in a single request response_file=$(mktemp) headers_file=$(mktemp) -# Get token for authentication (or use dummy token for testing) -AUTH_TOKEN=${AUTH_TOKEN:-"test_token"} - # Execute curl with response body to file and headers+status to another file curl -s -X POST \ -H "Content-Type: application/json" \ - -H "Authorization: Bearer ${AUTH_TOKEN}" \ + -H "Authorization: Bearer ${BEARER_TOKEN}" \ -d "@$PAYLOAD_FILE" \ -D "$headers_file" \ -o "$response_file" \ @@ -143,33 +167,54 @@ if [[ "${status_code:0:1}" == "2" ]]; then echo "Response:" echo "$http_response" | jq '.' 2>/dev/null || echo "$http_response" - # Extract the server ID from the response - server_id=$(echo "$http_response" | jq -r '.id') + # Check for server added message and extract UUID + message=$(echo "$http_response" | jq -r '.message // empty' 2>/dev/null) + server_id=$(echo "$http_response" | jq -r '.id // .server_id // empty' 2>/dev/null) + + # Validate the response contains success indicators + success_indicators=0 - echo "Publish successful with ID: $server_id" + if [[ ! -z "$message" && "$message" != "null" ]]; then + echo "✓ Success message received: $message" + if [[ "$message" == *"server"* && ("$message" == *"added"* || "$message" == *"published"* || "$message" == *"created"*) ]]; then + ((success_indicators++)) + echo "✓ Message indicates server was successfully added" + fi + fi + + if [[ ! -z "$server_id" && "$server_id" != "null" && "$server_id" != "empty" ]]; then + echo "✓ Server UUID received: $server_id" + # Validate UUID format (basic check for UUID pattern) + if [[ "$server_id" =~ ^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$ ]]; then + ((success_indicators++)) + echo "✓ Server ID appears to be a valid UUID format" + else + echo "⚠ Server ID format may not be a standard UUID: $server_id" + ((success_indicators++)) # Still count as success if we got an ID + fi + fi - # If we got a valid ID, verify it was actually created by calling the servers endpoint - if [[ ! -z "$server_id" && "$server_id" != "null" ]]; then - echo "-------------------------------------" - echo "Verifying server was published by checking servers endpoint..." - verify_response=$(curl -s "$HOST/v0/servers/$server_id") - echo "Response from servers endpoint:" - echo "$verify_response" | jq '.' 2>/dev/null || echo "$verify_response" - echo "-------------------------------------" - echo "Server verification response:" - echo "$verify_response" | jq '.' 2>/dev/null || echo "$verify_response" - echo "Server verification successful" + if [[ $success_indicators -ge 2 ]]; then + echo "" + echo "🎉 PUBLISH TEST PASSED!" + echo " ✓ Server successfully published with ID: $server_id" + echo " ✓ Success message: $message" else - echo "Error: No valid server ID returned from publish response" - echo "Response:" - echo "$http_response" | jq '.' 2>/dev/null || echo "$http_response" + echo "" + echo "❌ PUBLISH TEST FAILED!" + echo " Expected: Success message about server being added AND a server UUID" + echo " Received: message='$message', id='$server_id'" exit 1 fi else - echo "Response:" + echo "" + echo "❌ PUBLISH TEST FAILED!" + echo " Expected: 2xx status code" + echo " Received: $status_code" + echo " Response:" echo "$http_response" | jq '.' 2>/dev/null || echo "$http_response" - echo "Publish failed" + exit 1 fi echo "-------------------------------------" diff --git a/tools/publisher/main.go b/tools/publisher/main.go index 9ed544ed..dcde2296 100644 --- a/tools/publisher/main.go +++ b/tools/publisher/main.go @@ -2,10 +2,12 @@ package main import ( "bytes" + "context" "encoding/json" "flag" "fmt" "io" + "log" "net/http" "os" "strings" @@ -13,11 +15,10 @@ import ( ) const ( - tokenFilePath = ".mcpregistry_token" - + tokenFilePath = ".mcpregistry_token" // #nosec:G101 // GitHub OAuth URLs - GitHubDeviceCodeURL = "https://github.com/login/device/code" - GitHubAccessTokenURL = "https://github.com/login/oauth/access_token" + GitHubDeviceCodeURL = "https://github.com/login/device/code" // #nosec:G101 + GitHubAccessTokenURL = "https://github.com/login/oauth/access_token" // #nosec:G101 ) // DeviceCodeResponse represents the response from GitHub's device code endpoint @@ -39,7 +40,7 @@ type AccessTokenResponse struct { type ServerHealthResponse struct { Status string `json:"status"` - GitHubClientId string `json:"github_client_id"` + GitHubClientID string `json:"github_client_id"` } func main() { @@ -62,29 +63,36 @@ func main() { // get the clientID from the server's health endpoint healthURL := registryURL + "/v0/health" - resp, err := http.Get(healthURL) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, healthURL, nil) + if err != nil { + log.Printf("Error creating request: %s\n", err.Error()) + return + } + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { - fmt.Printf("Error fetching health endpoint: %s\n", err.Error()) + log.Printf("Error fetching health endpoint: %s\n", err.Error()) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - fmt.Printf("Health endpoint returned status %d: %s\n", resp.StatusCode, body) + log.Printf("Health endpoint returned status %d: %s\n", resp.StatusCode, body) return } var healthResponse ServerHealthResponse err = json.NewDecoder(resp.Body).Decode(&healthResponse) if err != nil { - fmt.Printf("Error decoding health response: %s\n", err.Error()) + log.Printf("Error decoding health response: %s\n", err.Error()) return } - if healthResponse.GitHubClientId == "" { - fmt.Println("GitHub Client ID is not set in the server's health response.") + if healthResponse.GitHubClientID == "" { + log.Println("GitHub Client ID is not set in the server's health response.") return } - githubClientID := healthResponse.GitHubClientId + githubClientID := healthResponse.GitHubClientID var token string @@ -97,7 +105,7 @@ func main() { if forceLogin || os.IsNotExist(statErr) { err := performDeviceFlowLogin(githubClientID) if err != nil { - fmt.Printf("Failed to perform device flow login: %s\n", err.Error()) + log.Printf("Failed to perform device flow login: %s\n", err.Error()) return } } @@ -106,7 +114,7 @@ func main() { var err error token, err = readToken() if err != nil { - fmt.Printf("Error reading token: %s\n", err.Error()) + log.Printf("Error reading token: %s\n", err.Error()) return } } @@ -114,22 +122,21 @@ func main() { // Read MCP file mcpData, err := os.ReadFile(mcpFilePath) if err != nil { - fmt.Printf("Error reading MCP file: %s\n", err.Error()) + log.Printf("Error reading MCP file: %s\n", err.Error()) return } // Publish to registry err = publishToRegistry(registryURL, mcpData, token) if err != nil { - fmt.Printf("Failed to publish to registry: %s\n", err.Error()) + log.Printf("Failed to publish to registry: %s\n", err.Error()) return } - fmt.Println("Successfully published to registry!") + log.Println("Successfully published to registry!") } func performDeviceFlowLogin(githubClientID string) error { - if githubClientID == "" { return fmt.Errorf("GitHub Client ID is required for device flow login") } @@ -142,13 +149,13 @@ func performDeviceFlowLogin(githubClientID string) error { } // Display instructions to the user - fmt.Println("\nTo authenticate, please:") - fmt.Println("1. Go to:", verificationURI) - fmt.Println("2. Enter code:", userCode) - fmt.Println("3. Authorize this application") + log.Println("\nTo authenticate, please:") + log.Println("1. Go to:", verificationURI) + log.Println("2. Enter code:", userCode) + log.Println("3. Authorize this application") // Poll for the token - fmt.Println("Waiting for authorization...") + log.Println("Waiting for authorization...") token, err := pollForToken(deviceCode, githubClientID) if err != nil { return fmt.Errorf("error polling for token: %w", err) @@ -160,7 +167,7 @@ func performDeviceFlowLogin(githubClientID string) error { return fmt.Errorf("error saving token: %w", err) } - fmt.Println("Successfully authenticated!") + log.Println("Successfully authenticated!") return nil } @@ -180,7 +187,7 @@ func requestDeviceCode(githubClientID string) (string, string, string, error) { return "", "", "", err } - req, err := http.NewRequest("POST", GitHubDeviceCodeURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, GitHubDeviceCodeURL, bytes.NewBuffer(jsonData)) if err != nil { return "", "", "", err } @@ -235,7 +242,7 @@ func pollForToken(deviceCode, githubClientID string) (string, error) { deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) for time.Now().Before(deadline) { - req, err := http.NewRequest("POST", GitHubAccessTokenURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, GitHubAccessTokenURL, bytes.NewBuffer(jsonData)) if err != nil { return "", err } @@ -262,7 +269,6 @@ func pollForToken(deviceCode, githubClientID string) (string, error) { if tokenResp.Error == "authorization_pending" { // User hasn't authorized yet, wait and retry - fmt.Print(".") time.Sleep(time.Duration(interval) * time.Second) continue } @@ -272,7 +278,6 @@ func pollForToken(deviceCode, githubClientID string) (string, error) { } if tokenResp.AccessToken != "" { - fmt.Println() // Add newline after dots return tokenResp.AccessToken, nil } @@ -322,7 +327,7 @@ func publishToRegistry(registryURL string, mcpData []byte, token string) error { publishURL := registryURL + "v0/publish" // Create and send the request - req, err := http.NewRequest("POST", publishURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, publishURL, bytes.NewBuffer(jsonData)) if err != nil { return fmt.Errorf("error creating request: %w", err) } @@ -346,6 +351,6 @@ func publishToRegistry(registryURL string, mcpData []byte, token string) error { return fmt.Errorf("publication failed with status %d: %s", resp.StatusCode, body) } - println(string(body)) + log.Println(string(body)) return nil }