diff --git a/cmd/registry/main.go b/cmd/registry/main.go index 387462da..01d2b0a2 100644 --- a/cmd/registry/main.go +++ b/cmd/registry/main.go @@ -55,7 +55,7 @@ func main() { defer cancel() // Connect to MongoDB - db, err = database.NewMongoDB(ctx, cfg.DatabaseURL, cfg.DatabaseName, cfg.CollectionName) + db, err = database.NewMongoDB(ctx, cfg.DatabaseURL, cfg.DatabaseName, cfg.CollectionName, cfg.VerificationCollectionName) if err != nil { log.Printf("Failed to connect to MongoDB: %v", err) return @@ -65,6 +65,7 @@ func main() { registryService = service.NewRegistryServiceWithDB(db) log.Printf("MongoDB database name: %s", cfg.DatabaseName) log.Printf("MongoDB collection name: %s", cfg.CollectionName) + log.Printf("MongoDB verification collection name: %s", cfg.VerificationCollectionName) // Store the MongoDB instance for later cleanup defer func() { diff --git a/examples/dns-verify/main.go b/examples/dns-verify/main.go new file mode 100644 index 00000000..98a76a4b --- /dev/null +++ b/examples/dns-verify/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "log" + "os" + + "github.com/modelcontextprotocol/registry/internal/verification" +) + +func main() { + if len(os.Args) != 3 { + fmt.Println("Usage: dns-verify ") + fmt.Println("Example: dns-verify example.com TBeVXe_X4npM6p8vpzStnA") + os.Exit(1) + } + + domain := os.Args[1] + token := os.Args[2] + + fmt.Printf("๐Ÿ” Verifying DNS record for domain: %s\n", domain) + fmt.Printf("๐ŸŽฏ Expected token: %s\n", token) + fmt.Printf("๐Ÿ“‹ Expected DNS record: mcp-verify=%s\n\n", token) + + // Perform DNS verification + result, err := verification.VerifyDNSRecord(domain, token) + if err != nil { + log.Printf("โŒ DNS verification error: %v", err) + os.Exit(1) + } + + // Display results + fmt.Printf("๐Ÿ“Š Verification Results:\n") + fmt.Printf(" Success: %t\n", result.Success) + fmt.Printf(" Domain: %s\n", result.Domain) + fmt.Printf(" Token: %s\n", result.Token) + fmt.Printf(" Duration: %s\n", result.Duration) + fmt.Printf(" Message: %s\n", result.Message) + + if len(result.TXTRecords) > 0 { + fmt.Printf("\n๐Ÿ“ Found TXT Records:\n") + for i, record := range result.TXTRecords { + fmt.Printf(" %d. %s\n", i+1, record) + } + } + + if result.Success { + fmt.Println("\nโœ… Domain verification successful!") + os.Exit(0) + } else { + fmt.Println("\nโŒ Domain verification failed!") + os.Exit(1) + } +} diff --git a/internal/api/handlers/v0/domain_normalization_test.go b/internal/api/handlers/v0/domain_normalization_test.go new file mode 100644 index 00000000..46f0e97f --- /dev/null +++ b/internal/api/handlers/v0/domain_normalization_test.go @@ -0,0 +1,111 @@ +package v0 + +import ( + "testing" +) + +func TestNormalizeDomain(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple domain", + input: "example.com", + expected: "example.com", + }, + { + name: "domain with subdomain", + input: "api.example.com", + expected: "api.example.com", + }, + { + name: "domain with https protocol", + input: "https://example.com", + expected: "example.com", + }, + { + name: "domain with http protocol", + input: "http://example.com", + expected: "example.com", + }, + { + name: "domain with path", + input: "https://example.com/path/to/resource", + expected: "example.com", + }, + { + name: "domain with query parameters", + input: "https://example.com?param=value", + expected: "example.com", + }, + { + name: "domain with port", + input: "https://example.com:8080", + expected: "example.com:8080", + }, + { + name: "mixed case domain", + input: "EXAMPLE.COM", + expected: "example.com", + }, + { + name: "domain with mixed case and protocol", + input: "https://API.EXAMPLE.COM/path", + expected: "api.example.com", + }, + { + name: "github.io domain", + input: "username.github.io", + expected: "username.github.io", + }, + { + name: "github.io with protocol and path", + input: "https://username.github.io/project", + expected: "username.github.io", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := normalizeDomain(tt.input) + if err != nil { + t.Errorf("normalizeDomain(%q) returned unexpected error: %v", tt.input, err) + return + } + if result != tt.expected { + t.Errorf("normalizeDomain(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestNormalizeDomainErrors(t *testing.T) { + errorTests := []struct { + name string + input string + }{ + { + name: "empty string", + input: "", + }, + { + name: "whitespace only", + input: " ", + }, + { + name: "malformed URL", + input: "http://", + }, + } + + for _, tt := range errorTests { + t.Run(tt.name, func(t *testing.T) { + result, err := normalizeDomain(tt.input) + if err == nil { + t.Errorf("normalizeDomain(%q) = %q, expected error", tt.input, result) + } + }) + } +} diff --git a/internal/api/handlers/v0/domain_verification_test.go b/internal/api/handlers/v0/domain_verification_test.go new file mode 100644 index 00000000..e004de2c --- /dev/null +++ b/internal/api/handlers/v0/domain_verification_test.go @@ -0,0 +1,230 @@ +package v0_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockRegistryServiceForDomainVerification is a mock implementation of the RegistryService interface for domain verification tests +type MockRegistryServiceForDomainVerification struct { + mock.Mock +} + +func (m *MockRegistryServiceForDomainVerification) List(cursor string, limit int) ([]model.Server, string, error) { + args := m.Mock.Called(cursor, limit) + return args.Get(0).([]model.Server), args.String(1), args.Error(2) +} + +func (m *MockRegistryServiceForDomainVerification) GetByID(id string) (*model.ServerDetail, error) { + args := m.Mock.Called(id) + return args.Get(0).(*model.ServerDetail), args.Error(1) +} + +func (m *MockRegistryServiceForDomainVerification) Publish(serverDetail *model.ServerDetail) error { + args := m.Mock.Called(serverDetail) + return args.Error(0) +} + +func (m *MockRegistryServiceForDomainVerification) ClaimDomain(domain string) (*model.VerificationToken, error) { + args := m.Mock.Called(domain) + return args.Get(0).(*model.VerificationToken), args.Error(1) +} + +func (m *MockRegistryServiceForDomainVerification) GetDomainVerificationStatus(domain string) (*model.VerificationTokens, error) { + args := m.Mock.Called(domain) + return args.Get(0).(*model.VerificationTokens), args.Error(1) +} + +func TestClaimDomainHandler(t *testing.T) { + tests := []struct { + name string + method string + requestBody interface{} + setupMocks func(*MockRegistryServiceForDomainVerification) + expectedStatus int + checkResponse func(t *testing.T, response *v0.DomainClaimResponse) + }{ + { + name: "successful domain claim", + method: http.MethodPost, + requestBody: v0.DomainClaimRequest{ + Domain: "example.com", + }, + setupMocks: func(registry *MockRegistryServiceForDomainVerification) { + registry.On("ClaimDomain", "example.com").Return(&model.VerificationToken{ + Token: "test-token-123", + CreatedAt: time.Now(), + }, nil) + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, response *v0.DomainClaimResponse) { + assert.Equal(t, "example.com", response.Domain) + assert.Equal(t, "example.com", response.NormalizedDomain) + assert.Equal(t, "test-token-123", response.Token) + assert.NotEmpty(t, response.CreatedAt) + }, + }, + { + name: "method not allowed", + method: http.MethodGet, + requestBody: nil, + expectedStatus: http.StatusMethodNotAllowed, + }, + { + name: "missing domain", + method: http.MethodPost, + requestBody: v0.DomainClaimRequest{ + Domain: "", + }, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRegistry := new(MockRegistryServiceForDomainVerification) + + if tt.setupMocks != nil { + tt.setupMocks(mockRegistry) + } + + handler := v0.ClaimDomainHandler(mockRegistry) + + var reqBody []byte + if tt.requestBody != nil { + var err error + reqBody, err = json.Marshal(tt.requestBody) + require.NoError(t, err) + } + + req := httptest.NewRequest(tt.method, "/v0/domains/claim", bytes.NewReader(reqBody)) + + w := httptest.NewRecorder() + handler(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + if tt.checkResponse != nil && w.Code == http.StatusCreated { + var response v0.DomainClaimResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + tt.checkResponse(t, &response) + } + + mockRegistry.AssertExpectations(t) + }) + } +} + +func TestGetDomainStatusHandler(t *testing.T) { + tests := []struct { + name string + method string + queryParam string + setupMocks func(*MockRegistryServiceForDomainVerification) + expectedStatus int + checkResponse func(t *testing.T, response *v0.DomainStatusResponse) + }{ + { + name: "domain with verified token", + method: http.MethodGet, + queryParam: "domain=verified.com", + setupMocks: func(registry *MockRegistryServiceForDomainVerification) { + verifiedAt := time.Now() + registry.On("GetDomainVerificationStatus", "verified.com").Return(&model.VerificationTokens{ + VerifiedToken: &model.VerificationToken{ + Token: "verified-token", + CreatedAt: time.Now(), + LastVerifiedAt: &verifiedAt, + }, + PendingTokens: []model.VerificationToken{}, + }, nil) + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, response *v0.DomainStatusResponse) { + assert.Equal(t, "verified.com", response.Domain) + assert.Equal(t, "verified", response.Status) + }, + }, + { + name: "domain with pending tokens only", + method: http.MethodGet, + queryParam: "domain=pending.com", + setupMocks: func(registry *MockRegistryServiceForDomainVerification) { + registry.On("GetDomainVerificationStatus", "pending.com").Return(&model.VerificationTokens{ + VerifiedToken: nil, + PendingTokens: []model.VerificationToken{ + { + Token: "pending-token-1", + CreatedAt: time.Now(), + }, + { + Token: "pending-token-2", + CreatedAt: time.Now(), + }, + }, + }, nil) + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, response *v0.DomainStatusResponse) { + assert.Equal(t, "pending.com", response.Domain) + assert.Equal(t, "unverified", response.Status) + }, + }, + { + name: "method not allowed", + method: http.MethodPost, + queryParam: "", + expectedStatus: http.StatusMethodNotAllowed, + }, + { + name: "missing domain parameter", + method: http.MethodGet, + queryParam: "", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRegistry := new(MockRegistryServiceForDomainVerification) + + if tt.setupMocks != nil { + tt.setupMocks(mockRegistry) + } + + handler := v0.GetDomainStatusHandler(mockRegistry) + + url := "/v0/domains/status" + if tt.queryParam != "" { + url = url + "?" + tt.queryParam + } + + req := httptest.NewRequest(tt.method, url, nil) + + w := httptest.NewRecorder() + handler(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + if tt.checkResponse != nil && w.Code == http.StatusOK { + var response v0.DomainStatusResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + tt.checkResponse(t, &response) + } + + mockRegistry.AssertExpectations(t) + }) + } +} diff --git a/internal/api/handlers/v0/publish_test.go b/internal/api/handlers/v0/publish_test.go index 414c6be1..5d00c60c 100644 --- a/internal/api/handlers/v0/publish_test.go +++ b/internal/api/handlers/v0/publish_test.go @@ -37,6 +37,16 @@ func (m *MockRegistryService) Publish(serverDetail *model.ServerDetail) error { return args.Error(0) } +func (m *MockRegistryService) ClaimDomain(domain string) (*model.VerificationToken, error) { + args := m.Mock.Called(domain) + return args.Get(0).(*model.VerificationToken), args.Error(1) +} + +func (m *MockRegistryService) GetDomainVerificationStatus(domain string) (*model.VerificationTokens, error) { + args := m.Mock.Called(domain) + return args.Get(0).(*model.VerificationTokens), args.Error(1) +} + // MockAuthService is a mock implementation of the auth.Service interface type MockAuthService struct { mock.Mock diff --git a/internal/api/handlers/v0/verification.go b/internal/api/handlers/v0/verification.go new file mode 100644 index 00000000..3afabc4d --- /dev/null +++ b/internal/api/handlers/v0/verification.go @@ -0,0 +1,170 @@ +package v0 + +import ( + "encoding/json" + "errors" + "net/http" + "net/url" + "strings" + + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/service" +) + +// normalizeDomain extracts and cleans the domain from a URL or domain string +// It removes protocols, paths, and query parameters, returning just the hostname +func normalizeDomain(domain string) (string, error) { + domain = strings.TrimSpace(domain) + if domain == "" { + return "", errors.New("domain cannot be empty") + } + + // Try parsing as URL first (handles cases with protocol) + if u, err := url.Parse(domain); err == nil && u.Host != "" { + return strings.ToLower(u.Host), nil + } + + // If no protocol, try adding one and parsing again + if !strings.Contains(domain, "://") { + if u, err := url.Parse("https://" + domain); err == nil && u.Host != "" { + return strings.ToLower(u.Host), nil + } + } + + // If we get here, the input is not a valid domain/URL + return "", errors.New("invalid domain format") +} + +// DomainClaimRequest represents the request body for domain claiming +type DomainClaimRequest struct { + Domain string `json:"domain"` +} + +// DomainStatusRequest represents the request body for domain status checking +type DomainStatusRequest struct { + Domain string `json:"domain"` +} + +// DomainClaimResponse represents the response for domain claim operations +type DomainClaimResponse struct { + Domain string `json:"domain"` // Original domain from request + NormalizedDomain string `json:"normalized_domain"` // Cleaned domain (TLD + subdomains) + Token string `json:"token"` + CreatedAt string `json:"created_at"` +} + +// DomainStatusResponse represents the response for domain verification status +type DomainStatusResponse struct { + Domain string `json:"domain"` + Status string `json:"status"` // "verified" or "unverified" +} + +// ClaimDomainHandler handles requests to claim a domain for verification +func ClaimDomainHandler(registry service.RegistryService) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Only allow POST method + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse request body + var req DomainClaimRequest + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, "Invalid request payload: "+err.Error(), http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Validate required fields + if req.Domain == "" { + http.Error(w, "domain is required", http.StatusBadRequest) + return + } + + // Normalize the domain (remove protocol, path, etc.) + normalizedDomain, err := normalizeDomain(req.Domain) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Generate and store the verification token for the normalized domain + verificationToken, err := registry.ClaimDomain(normalizedDomain) + if err != nil { + http.Error(w, "Failed to claim domain: "+err.Error(), http.StatusInternalServerError) + return + } + + // Prepare response + response := DomainClaimResponse{ + Domain: req.Domain, + NormalizedDomain: normalizedDomain, + Token: verificationToken.Token, + CreatedAt: verificationToken.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// GetDomainStatusHandler handles requests to get domain verification status +func GetDomainStatusHandler(registry service.RegistryService) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Only allow GET method + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Get domain from query parameter + domain := r.URL.Query().Get("domain") + if domain == "" { + http.Error(w, "domain query parameter is required", http.StatusBadRequest) + return + } + + // Normalize the domain (remove protocol, path, etc.) + normalizedDomain, err := normalizeDomain(domain) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Get the domain verification status using normalized domain + verificationTokens, err := registry.GetDomainVerificationStatus(normalizedDomain) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + http.Error(w, "Domain not found", http.StatusNotFound) + return + } + http.Error(w, "Failed to retrieve domain status: "+err.Error(), http.StatusInternalServerError) + return + } + + // Determine status + status := "unverified" + if verificationTokens.VerifiedToken != nil { + status = "verified" + } + + // Prepare response with normalized domain + response := DomainStatusResponse{ + Domain: normalizedDomain, + Status: status, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} diff --git a/internal/api/router/v0.go b/internal/api/router/v0.go index 6d465f99..82f147ca 100644 --- a/internal/api/router/v0.go +++ b/internal/api/router/v0.go @@ -20,6 +20,8 @@ func RegisterV0Routes( mux.HandleFunc("/v0/servers/{id}", v0.ServersDetailHandler(registry)) mux.HandleFunc("/v0/ping", v0.PingHandler(cfg)) mux.HandleFunc("/v0/publish", v0.PublishHandler(registry, authService)) + mux.HandleFunc("/v0/domains/claim", v0.ClaimDomainHandler(registry)) + mux.HandleFunc("/v0/domains/status", v0.GetDomainStatusHandler(registry)) // Register Swagger UI routes mux.HandleFunc("/v0/swagger/", v0.SwaggerHandler()) diff --git a/internal/config/config.go b/internal/config/config.go index 950c3f58..445178b5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,17 +13,18 @@ const ( // Config holds the application configuration type Config struct { - ServerAddress string `env:"SERVER_ADDRESS" envDefault:":8080"` - DatabaseType DatabaseType `env:"DATABASE_TYPE" envDefault:"mongodb"` - DatabaseURL string `env:"DATABASE_URL" envDefault:"mongodb://localhost:27017"` - DatabaseName string `env:"DATABASE_NAME" envDefault:"mcp-registry"` - CollectionName string `env:"COLLECTION_NAME" envDefault:"servers_v2"` - LogLevel string `env:"LOG_LEVEL" envDefault:"info"` - SeedFilePath string `env:"SEED_FILE_PATH" envDefault:"data/seed.json"` - SeedImport bool `env:"SEED_IMPORT" envDefault:"true"` - Version string `env:"VERSION" envDefault:"dev"` - GithubClientID string `env:"GITHUB_CLIENT_ID" envDefault:""` - GithubClientSecret string `env:"GITHUB_CLIENT_SECRET" envDefault:""` + ServerAddress string `env:"SERVER_ADDRESS" envDefault:":8080"` + DatabaseType DatabaseType `env:"DATABASE_TYPE" envDefault:"mongodb"` + DatabaseURL string `env:"DATABASE_URL" envDefault:"mongodb://localhost:27017"` + DatabaseName string `env:"DATABASE_NAME" envDefault:"mcp-registry"` + CollectionName string `env:"COLLECTION_NAME" envDefault:"servers_v2"` + VerificationCollectionName string `env:"VERIFICATION_COLLECTION_NAME" envDefault:"verification"` + LogLevel string `env:"LOG_LEVEL" envDefault:"info"` + SeedFilePath string `env:"SEED_FILE_PATH" envDefault:"data/seed.json"` + SeedImport bool `env:"SEED_IMPORT" envDefault:"true"` + Version string `env:"VERSION" envDefault:"dev"` + GithubClientID string `env:"GITHUB_CLIENT_ID" envDefault:""` + GithubClientSecret string `env:"GITHUB_CLIENT_SECRET" envDefault:""` } // NewConfig creates a new configuration with default values diff --git a/internal/database/database.go b/internal/database/database.go index d145f0c6..aba506c0 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -9,11 +9,13 @@ import ( // Common database errors var ( - ErrNotFound = errors.New("record not found") - ErrAlreadyExists = errors.New("record already exists") - ErrInvalidInput = errors.New("invalid input") - ErrDatabase = errors.New("database error") - ErrInvalidVersion = errors.New("invalid version: cannot publish older version after newer version") + ErrNotFound = errors.New("record not found") + ErrAlreadyExists = errors.New("record already exists") + ErrInvalidInput = errors.New("invalid input") + ErrDatabase = errors.New("database error") + ErrInvalidVersion = errors.New("invalid version: cannot publish older version after newer version") + ErrMaxAttemptsExceeded = errors.New("maximum attempts exceeded for token generation") + ErrTokenAlreadyExists = errors.New("verification token already exists") ) // Database defines the interface for database operations on MCPRegistry entries @@ -24,6 +26,10 @@ type Database interface { GetByID(ctx context.Context, id string) (*model.ServerDetail, error) // Publish adds a new ServerDetail to the database Publish(ctx context.Context, serverDetail *model.ServerDetail) error + // StoreVerificationToken atomically stores a verification token for a domain if the token is unique + StoreVerificationToken(ctx context.Context, domain string, token *model.VerificationToken) error + // GetVerificationTokens retrieves all verification tokens by domain + GetVerificationTokens(ctx context.Context, domain string) (*model.VerificationTokens, error) // ImportSeed imports initial data from a seed file ImportSeed(ctx context.Context, seedFilePath string) error // Close closes the database connection diff --git a/internal/database/memory.go b/internal/database/memory.go index 6cd6cb01..b729a3e3 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -16,8 +16,9 @@ import ( // MemoryDB is an in-memory implementation of the Database interface type MemoryDB struct { - entries map[string]*model.ServerDetail - mu sync.RWMutex + entries map[string]*model.ServerDetail + domainVerifications map[string]*model.DomainVerification // key: domain + mu sync.RWMutex } // NewMemoryDB creates a new instance of the in-memory database @@ -30,7 +31,8 @@ func NewMemoryDB(e map[string]*model.Server) *MemoryDB { } } return &MemoryDB{ - entries: serverDetails, + entries: serverDetails, + domainVerifications: make(map[string]*model.DomainVerification), } } @@ -306,3 +308,71 @@ func (db *MemoryDB) Connection() *ConnectionInfo { Raw: db.entries, } } + +// StoreVerificationToken atomically stores a verification token for a domain if the token is unique +func (db *MemoryDB) StoreVerificationToken(ctx context.Context, domain string, token *model.VerificationToken) error { + db.mu.Lock() + defer db.mu.Unlock() + + // Check if the token is unique across all domains + for _, domainVerification := range db.domainVerifications { + if domainVerification.VerificationTokens == nil { + continue + } + + // Check verified token + if domainVerification.VerificationTokens.VerifiedToken != nil && + domainVerification.VerificationTokens.VerifiedToken.Token == token.Token { + return ErrTokenAlreadyExists + } + + // Check pending tokens + for _, pendingToken := range domainVerification.VerificationTokens.PendingTokens { + if pendingToken.Token == token.Token { + return ErrTokenAlreadyExists + } + } + } + + // Token is unique, store it + existingVerification, exists := db.domainVerifications[domain] + + var verificationTokens *model.VerificationTokens + + if exists && existingVerification.VerificationTokens != nil { + // Add to existing pending tokens + verificationTokens = existingVerification.VerificationTokens + verificationTokens.PendingTokens = append(verificationTokens.PendingTokens, *token) + } else { + // No existing record or no verification tokens - create new structure + verificationTokens = &model.VerificationTokens{ + PendingTokens: []model.VerificationToken{*token}, + } + } + + // Create or update domain verification + domainVerification := &model.DomainVerification{ + Domain: domain, + VerificationTokens: verificationTokens, + } + + db.domainVerifications[domain] = domainVerification + return nil +} + +// GetVerificationTokens retrieves verification tokens by domain +func (db *MemoryDB) GetVerificationTokens(ctx context.Context, domain string) (*model.VerificationTokens, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + domainVerification, exists := db.domainVerifications[domain] + if !exists { + return nil, ErrNotFound + } + + if domainVerification.VerificationTokens == nil { + return nil, fmt.Errorf("verification tokens data is missing from domain verification") + } + + return domainVerification.VerificationTokens, nil +} diff --git a/internal/database/mongo.go b/internal/database/mongo.go index 21538493..edc2a45d 100644 --- a/internal/database/mongo.go +++ b/internal/database/mongo.go @@ -16,13 +16,14 @@ import ( // MongoDB is an implementation of the Database interface using MongoDB type MongoDB struct { - client *mongo.Client - database *mongo.Database - collection *mongo.Collection + client *mongo.Client + database *mongo.Database + serverCollection *mongo.Collection + verificationCollection *mongo.Collection } // NewMongoDB creates a new instance of the MongoDB database -func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName string) (*MongoDB, error) { +func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName, verificationCollectionName string) (*MongoDB, error) { // Set client options and connect to MongoDB clientOptions := options.Client().ApplyURI(connectionURI) client, err := mongo.Connect(ctx, clientOptions) @@ -37,7 +38,8 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName // Get database and collection database := client.Database(databaseName) - collection := database.Collection(collectionName) + serverCollection := database.Collection(collectionName) + verificationCollection := database.Collection(verificationCollectionName) // Create indexes for better query performance models := []mongo.IndexModel{ @@ -55,7 +57,7 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName }, } - _, err = collection.Indexes().CreateMany(ctx, models) + _, err = serverCollection.Indexes().CreateMany(ctx, models) if err != nil { // Mongo will error if the index already exists, we can ignore this and continue. var commandError mongo.CommandError @@ -65,10 +67,38 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName log.Printf("Indexes already exists, skipping.") } + // Create indexes for verification collection + verificationIndexes := []mongo.IndexModel{ + { + Keys: bson.D{bson.E{Key: "domain", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + // Create a sparse unique index on verification tokens to ensure global token uniqueness + { + Keys: bson.D{bson.E{Key: "verification_tokens.verified_token.token", Value: 1}}, + Options: options.Index().SetUnique(true).SetSparse(true), + }, + { + Keys: bson.D{bson.E{Key: "verification_tokens.pending_tokens.token", Value: 1}}, + Options: options.Index().SetUnique(true).SetSparse(true), + }, + } + + _, err = verificationCollection.Indexes().CreateMany(ctx, verificationIndexes) + if err != nil { + // Mongo will error if the index already exists, we can ignore this and continue. + var commandError mongo.CommandError + if errors.As(err, &commandError) && commandError.Code != 86 { + return nil, err + } + log.Printf("Verification collection indexes already exist, skipping.") + } + return &MongoDB{ - client: client, - database: database, - collection: collection, + client: client, + database: database, + serverCollection: serverCollection, + verificationCollection: verificationCollection, }, nil } @@ -117,7 +147,7 @@ func (db *MongoDB) List( // Fetch the document at the cursor to get its sort values var cursorDoc model.Server - err := db.collection.FindOne(ctx, bson.M{"id": cursor}).Decode(&cursorDoc) + err := db.serverCollection.FindOne(ctx, bson.M{"id": cursor}).Decode(&cursorDoc) if err != nil { if !errors.Is(err, mongo.ErrNoDocuments) { return nil, "", err @@ -138,7 +168,7 @@ func (db *MongoDB) List( } // Execute find operation with options - mongoCursor, err := db.collection.Find(ctx, mongoFilter, findOptions) + mongoCursor, err := db.serverCollection.Find(ctx, mongoFilter, findOptions) if err != nil { return nil, "", err } @@ -171,7 +201,7 @@ func (db *MongoDB) GetByID(ctx context.Context, id string) (*model.ServerDetail, // Find the entry in the database var entry model.ServerDetail - err := db.collection.FindOne(ctx, filter).Decode(&entry) + err := db.serverCollection.FindOne(ctx, filter).Decode(&entry) if err != nil { if errors.Is(err, mongo.ErrNoDocuments) { return nil, ErrNotFound @@ -195,7 +225,7 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail } var existingEntry model.ServerDetail - err := db.collection.FindOne(ctx, filter).Decode(&existingEntry) + err := db.serverCollection.FindOne(ctx, filter).Decode(&existingEntry) if err != nil && !errors.Is(err, mongo.ErrNoDocuments) { return fmt.Errorf("error checking existing entry: %w", err) } @@ -210,7 +240,7 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail serverDetail.VersionDetail.ReleaseDate = time.Now().Format(time.RFC3339) // Insert the entry into the database - _, err = db.collection.InsertOne(ctx, serverDetail) + _, err = db.serverCollection.InsertOne(ctx, serverDetail) if err != nil { if mongo.IsDuplicateKeyError(err) { return ErrAlreadyExists @@ -220,7 +250,7 @@ func (db *MongoDB) Publish(ctx context.Context, serverDetail *model.ServerDetail // update the existing entry to not be the latest version if existingEntry.ID != "" { - _, err = db.collection.UpdateOne( + _, err = db.serverCollection.UpdateOne( ctx, bson.M{"id": existingEntry.ID}, bson.M{"$set": bson.M{"version_detail.islatest": false}}) @@ -240,7 +270,7 @@ func (db *MongoDB) ImportSeed(ctx context.Context, seedFilePath string) error { return fmt.Errorf("failed to read seed file: %w", err) } - collection := db.collection + collection := db.serverCollection log.Printf("Importing %d servers into collection %s", len(servers), collection.Name()) @@ -307,3 +337,51 @@ func (db *MongoDB) Connection() *ConnectionInfo { Raw: db.client, } } + +// StoreVerificationToken atomically stores a verification token for a domain if the token is unique +func (db *MongoDB) StoreVerificationToken(ctx context.Context, domain string, token *model.VerificationToken) error { + domainFilter := bson.M{"domain": domain} + + update := bson.M{ + "$setOnInsert": bson.M{ + "domain": domain, + }, + "$addToSet": bson.M{ + "verification_tokens.pending_tokens": token, + }, + } + + opts := options.Update().SetUpsert(true) + _, err := db.verificationCollection.UpdateOne(ctx, domainFilter, update, opts) + if err != nil { + // Check if this is a duplicate key error due to token uniqueness constraint + if mongo.IsDuplicateKeyError(err) { + return ErrTokenAlreadyExists + } + return fmt.Errorf("failed to store verification token: %w", err) + } + + return nil +} + +// GetVerificationTokens retrieves verification tokens by domain +func (db *MongoDB) GetVerificationTokens(ctx context.Context, domain string) (*model.VerificationTokens, error) { + filter := bson.M{ + "domain": domain, + } + + var domainVerification model.DomainVerification + err := db.verificationCollection.FindOne(ctx, filter).Decode(&domainVerification) + if err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to get verification tokens: %w", err) + } + + if domainVerification.VerificationTokens == nil { + return nil, fmt.Errorf("verification tokens data is missing from domain verification") + } + + return domainVerification.VerificationTokens, nil +} diff --git a/internal/docs/swagger.yaml b/internal/docs/swagger.yaml index 17aea9fe..046fdf3b 100644 --- a/internal/docs/swagger.yaml +++ b/internal/docs/swagger.yaml @@ -7,12 +7,16 @@ info: servers: - url: / description: Default server + - url: http://localhost:8080 + description: Local development server tags: - name: health description: Health checking operations - name: servers description: Server registry operations + - name: domains + description: Domain verification operations paths: /v0/health: @@ -200,6 +204,96 @@ paths: '405': description: Method not allowed + /v0/domains/claim: + post: + tags: + - domains + summary: Claim a domain for verification + description: Generate a verification token for a domain to enable domain-based server publishing + operationId: claimDomain + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DomainClaimRequest' + responses: + '201': + description: Domain claimed successfully + content: + application/json: + schema: + $ref: '#/components/schemas/DomainClaimResponse' + '400': + description: Invalid request + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: domain is required + '405': + description: Method not allowed + '500': + description: Server error + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: Failed to claim domain + + /v0/domains/status: + get: + tags: + - domains + summary: Get domain verification status + description: Check the verification status of a domain + operationId: getDomainStatus + parameters: + - name: domain + in: query + required: true + description: The domain to check verification status for + schema: + type: string + example: example.com + responses: + '200': + description: Successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/DomainStatusResponse' + '400': + description: Invalid request + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: domain query parameter is required + '404': + description: Domain not found + '405': + description: Method not allowed + '500': + description: Server error + content: + application/json: + schema: + type: object + properties: + error: + type: string + example: Failed to retrieve domain status + components: securitySchemes: bearerAuth: @@ -243,3 +337,57 @@ components: count: type: integer example: 30 + + DomainClaimRequest: + type: object + required: + - domain + properties: + domain: + type: string + description: Domain to claim for verification (can include protocol and path, will be normalized) + example: "https://example.com" + + DomainClaimResponse: + type: object + properties: + domain: + type: string + description: Original domain from request + example: "https://example.com" + normalized_domain: + type: string + description: Cleaned domain (hostname only) + example: "example.com" + token: + type: string + description: Verification token to be used for domain verification + example: "mcp-verification-abc123def456" + created_at: + type: string + format: date-time + description: Timestamp when the token was created + example: "2025-08-04T12:34:56Z" + + DomainStatusRequest: + type: object + required: + - domain + properties: + domain: + type: string + description: Domain to check verification status for + example: "example.com" + + DomainStatusResponse: + type: object + properties: + domain: + type: string + description: Normalized domain name + example: "example.com" + status: + type: string + enum: ["verified", "unverified"] + description: Verification status of the domain + example: "verified" diff --git a/internal/model/model.go b/internal/model/model.go index 293deab1..10dc9bfb 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -1,5 +1,7 @@ package model +import "time" + // AuthMethod represents the authentication method used type AuthMethod string @@ -135,3 +137,36 @@ type ServerDetail struct { Packages []Package `json:"packages,omitempty" bson:"packages,omitempty"` Remotes []Remote `json:"remotes,omitempty" bson:"remotes,omitempty"` } + +// VerificationToken represents a domain verification token for a server +type VerificationToken struct { + Token string `json:"token" bson:"token"` + CreatedAt time.Time `json:"created_at" bson:"created_at"` + DisabledAt *time.Time `json:"disabled_at,omitempty" bson:"disabled_at,omitempty"` + LastVerifiedAt *time.Time `json:"last_verified_at,omitempty" bson:"last_verified_at,omitempty"` +} + +// VerificationTokens represents the collection of verification tokens for a domain +type VerificationTokens struct { + VerifiedToken *VerificationToken `json:"verified_token,omitempty" bson:"verified_token,omitempty"` + PendingTokens []VerificationToken `json:"pending_tokens,omitempty" bson:"pending_tokens,omitempty"` +} + +// DomainVerification represents verification data for a specific domain +type DomainVerification struct { + Domain string `json:"domain" bson:"domain"` + VerificationTokens *VerificationTokens `json:"verification_tokens,omitempty" bson:"verification_tokens,omitempty"` +} + +// DomainVerificationRequest represents a request to generate a verification token for a domain +type DomainVerificationRequest struct { + Domain string `json:"domain"` +} + +// DomainVerificationResponse represents the response for domain verification operations +type DomainVerificationResponse struct { + Domain string `json:"domain"` + Token string `json:"token"` + CreatedAt string `json:"created_at"` + DNSRecord string `json:"dns_record"` +} diff --git a/internal/service/fake_service.go b/internal/service/fake_service.go index 07aa805d..146d9798 100644 --- a/internal/service/fake_service.go +++ b/internal/service/fake_service.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/model" + "github.com/modelcontextprotocol/registry/internal/verification" ) // fakeRegistryService implements RegistryService interface with an in-memory database @@ -123,6 +124,62 @@ func (s *fakeRegistryService) Publish(serverDetail *model.ServerDetail) error { return s.db.Publish(ctx, serverDetail) } +// ClaimDomain generates a verification token for a domain and stores it as pending +func (s *fakeRegistryService) ClaimDomain(domain string) (*model.VerificationToken, error) { + // Create a timeout context for the database operation + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + const maxAttempts = 10 + for attempt := 1; attempt <= maxAttempts; attempt++ { + // Generate a verification token + token, err := verification.GenerateVerificationToken() + if err != nil { + return nil, err + } + + // Create the verification token + verificationToken := &model.VerificationToken{ + Token: token, + CreatedAt: time.Now(), + } + + // Try to store the token in the database as pending for the domain + // The unique index will ensure token uniqueness + err = s.db.StoreVerificationToken(ctx, domain, verificationToken) + if err == nil { + // Successfully stored, return the token + return verificationToken, nil + } + + // If it's a token already exists error, retry with a new token + if err == database.ErrTokenAlreadyExists { + continue + } + + // For any other error, return it + return nil, err + } + + // If we've exhausted all attempts, return an error + return nil, database.ErrMaxAttemptsExceeded +} + +// GetDomainVerificationStatus retrieves the verification status for a domain +func (s *fakeRegistryService) GetDomainVerificationStatus(domain string) (*model.VerificationTokens, error) { + // Create a timeout context for the database operation + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Get the verification tokens from the database + tokens, err := s.db.GetVerificationTokens(ctx, domain) + if err != nil { + return nil, err + } + + return tokens, nil +} + // Close closes the in-memory database connection func (s *fakeRegistryService) Close() error { return s.db.Close() diff --git a/internal/service/registry_service.go b/internal/service/registry_service.go index d9798be3..e02acead 100644 --- a/internal/service/registry_service.go +++ b/internal/service/registry_service.go @@ -2,10 +2,12 @@ package service import ( "context" + "errors" "time" "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/model" + "github.com/modelcontextprotocol/registry/internal/verification" ) // registryServiceImpl implements the RegistryService interface using our Database @@ -101,3 +103,57 @@ func (s *registryServiceImpl) Publish(serverDetail *model.ServerDetail) error { return nil } + +// ClaimDomain generates a verification token for a domain and stores it as pending +func (s *registryServiceImpl) ClaimDomain(domain string) (*model.VerificationToken, error) { + // Create a timeout context for the database operation + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + const maxAttempts = 10 + for attempt := 1; attempt <= maxAttempts; attempt++ { + // Generate a verification token + token, err := verification.GenerateVerificationToken() + if err != nil { + return nil, err + } + + // Create the verification token object + verificationToken := &model.VerificationToken{ + Token: token, + CreatedAt: time.Now(), + } + + // Try to store the token atomically + err = s.db.StoreVerificationToken(ctx, domain, verificationToken) + if err != nil { + if errors.Is(err, database.ErrTokenAlreadyExists) { + // Token collision, try again with a new token + continue + } + // Other error, return it + return nil, err + } + + // Success! Token was stored atomically + return verificationToken, nil + } + + // If we've exhausted all attempts, return an error + return nil, database.ErrMaxAttemptsExceeded +} + +// GetDomainVerificationStatus retrieves the verification status for a domain +func (s *registryServiceImpl) GetDomainVerificationStatus(domain string) (*model.VerificationTokens, error) { + // Create a timeout context for the database operation + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Get the verification tokens from the database + tokens, err := s.db.GetVerificationTokens(ctx, domain) + if err != nil { + return nil, err + } + + return tokens, nil +} diff --git a/internal/service/registry_service_test.go b/internal/service/registry_service_test.go new file mode 100644 index 00000000..0a01f7bd --- /dev/null +++ b/internal/service/registry_service_test.go @@ -0,0 +1,149 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClaimDomain_Uniqueness(t *testing.T) { + // Create an in-memory database for testing + memDB := database.NewMemoryDB(make(map[string]*model.Server)) + service := NewRegistryServiceWithDB(memDB) + + domain1 := "example.com" + domain2 := "test.org" + + // Generate first token + token1, err := service.ClaimDomain(domain1) + require.NoError(t, err) + require.NotNil(t, token1) + assert.NotEmpty(t, token1.Token) + + // Generate second token + token2, err := service.ClaimDomain(domain2) + require.NoError(t, err) + require.NotNil(t, token2) + assert.NotEmpty(t, token2.Token) + + // Tokens should be different + assert.NotEqual(t, token1.Token, token2.Token, "Generated tokens should be unique") + + // Verify tokens are stored correctly + retrievedTokens1, err := service.GetDomainVerificationStatus(domain1) + require.NoError(t, err) + require.Len(t, retrievedTokens1.PendingTokens, 1) + assert.Equal(t, token1.Token, retrievedTokens1.PendingTokens[0].Token) + + retrievedTokens2, err := service.GetDomainVerificationStatus(domain2) + require.NoError(t, err) + require.Len(t, retrievedTokens2.PendingTokens, 1) + assert.Equal(t, token2.Token, retrievedTokens2.PendingTokens[0].Token) +} + +func TestClaimDomain_TokenUniqueness(t *testing.T) { + // Create an in-memory database for testing + memDB := database.NewMemoryDB(make(map[string]*model.Server)) + ctx := context.Background() + + // Store a token directly in the database to simulate an existing token + existingToken := &model.VerificationToken{ + Token: "existing-token-123", + CreatedAt: time.Now(), + } + err := memDB.StoreVerificationToken(ctx, "example.com", existingToken) + require.NoError(t, err) + + // Try to store the same token again - should fail with ErrTokenAlreadyExists + duplicateToken := &model.VerificationToken{ + Token: "existing-token-123", + CreatedAt: time.Now(), + } + err = memDB.StoreVerificationToken(ctx, "another-domain.com", duplicateToken) + require.Error(t, err) + assert.Equal(t, database.ErrTokenAlreadyExists, err) + + // Store a different token - should succeed + uniqueToken := &model.VerificationToken{ + Token: "unique-token-456", + CreatedAt: time.Now(), + } + err = memDB.StoreVerificationToken(ctx, "another-domain.com", uniqueToken) + require.NoError(t, err) +} + +func TestGetDomainVerificationStatus(t *testing.T) { + // Create an in-memory database for testing + memDB := database.NewMemoryDB(make(map[string]*model.Server)) + service := NewRegistryServiceWithDB(memDB) + + domain := "example.com" + + // Test when domain does not exist + _, err := service.GetDomainVerificationStatus(domain) + require.Error(t, err) + assert.Equal(t, database.ErrNotFound, err) + + // Claim the domain (adds a pending token) + token, err := service.ClaimDomain(domain) + require.NoError(t, err) + + // Now status should be unverified with a pending token + status, err := service.GetDomainVerificationStatus(domain) + require.NoError(t, err) + assert.Nil(t, status.VerifiedToken) + assert.Len(t, status.PendingTokens, 1) + assert.Equal(t, token.Token, status.PendingTokens[0].Token) +} + +func TestClaimDomain_MaxAttempts(t *testing.T) { + // Create a mock database that always returns ErrTokenAlreadyExists for StoreVerificationToken + // to simulate the scenario where we can't find a unique token + memDB := &mockDBAlwaysNonUnique{} + service := NewRegistryServiceWithDB(memDB) + + domain := "example.com" + + // Attempt to claim domain should fail after max attempts + token, err := service.ClaimDomain(domain) + require.Error(t, err) + assert.Nil(t, token) + assert.Equal(t, database.ErrMaxAttemptsExceeded, err) +} + +// mockDBAlwaysNonUnique is a mock database that always returns ErrTokenAlreadyExists for StoreVerificationToken +type mockDBAlwaysNonUnique struct{} + +func (m *mockDBAlwaysNonUnique) List(ctx context.Context, filter map[string]any, cursor string, limit int) ([]*model.Server, string, error) { + return nil, "", nil +} + +func (m *mockDBAlwaysNonUnique) GetByID(ctx context.Context, id string) (*model.ServerDetail, error) { + return nil, database.ErrNotFound +} + +func (m *mockDBAlwaysNonUnique) Publish(ctx context.Context, serverDetail *model.ServerDetail) error { + return nil +} + +func (m *mockDBAlwaysNonUnique) StoreVerificationToken(ctx context.Context, domain string, token *model.VerificationToken) error { + // Always return ErrTokenAlreadyExists to simulate that tokens are never unique + return database.ErrTokenAlreadyExists +} + +func (m *mockDBAlwaysNonUnique) GetVerificationTokens(ctx context.Context, domain string) (*model.VerificationTokens, error) { + return nil, database.ErrNotFound +} + +func (m *mockDBAlwaysNonUnique) ImportSeed(ctx context.Context, seedFilePath string) error { + return nil +} + +func (m *mockDBAlwaysNonUnique) Close() error { + return nil +} diff --git a/internal/service/service.go b/internal/service/service.go index a3e14019..3bf19d4a 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -7,4 +7,6 @@ type RegistryService interface { List(cursor string, limit int) ([]model.Server, string, error) GetByID(id string) (*model.ServerDetail, error) Publish(serverDetail *model.ServerDetail) error + ClaimDomain(domain string) (*model.VerificationToken, error) + GetDomainVerificationStatus(domain string) (*model.VerificationTokens, error) } diff --git a/internal/verification/README.md b/internal/verification/README.md new file mode 100644 index 00000000..987d05d6 --- /dev/null +++ b/internal/verification/README.md @@ -0,0 +1,414 @@ +# Domain Verification Package + +This package provides cryptographically secure token generation and DNS verification for domain ownership verification in the MCP Registry. It implements the requirements specified in the Server Name Verification system. + +## Overview + +The verification package generates 128-bit cryptographically secure random tokens used for proving domain ownership through two verification methods: + +1. **DNS TXT Record Verification**: Add `mcp-verify=` to your domain's DNS +2. **HTTP-01 Web Challenge**: Serve the token at `https://domain/.well-known/mcp-challenge/` + +## Functions + +### Token Generation + +#### GenerateVerificationToken() + +Generates a cryptographically secure 128-bit random token encoded in base64url format. + +```go +token, err := verification.GenerateVerificationToken() +if err != nil { + return fmt.Errorf("failed to generate token: %w", err) +} +// token: "TBeVXe_X4npM6p8vpzStnA" (22 characters) +``` + +**Features:** +- Uses `crypto/rand` for cryptographically secure randomness +- 128 bits (16 bytes) of entropy +- Base64url encoding (URL-safe and DNS-safe) +- No padding characters +- 22-character output length + +#### GenerateTokenWithInfo() + +Generates a token with additional metadata about how to use it. + +```go +tokenInfo, err := verification.GenerateTokenWithInfo() +if err != nil { + return fmt.Errorf("failed to generate token info: %w", err) +} + +fmt.Printf("Token: %s\n", tokenInfo.Token) +fmt.Printf("DNS Record: %s\n", tokenInfo.DNSRecord) +fmt.Printf("HTTP Path: %s\n", tokenInfo.HTTPPath) +``` + +**Output:** +``` +Token: TBeVXe_X4npM6p8vpzStnA +DNS Record: mcp-verify=TBeVXe_X4npM6p8vpzStnA +HTTP Path: /.well-known/mcp-challenge/TBeVXe_X4npM6p8vpzStnA +``` + +### DNS Verification + +#### VerifyDNSRecord(domain, expectedToken string) + +Verifies domain ownership by checking for a specific TXT record containing the expected verification token. + +```go +result, err := verification.VerifyDNSRecord("example.com", "TBeVXe_X4npM6p8vpzStnA") +if err != nil { + log.Printf("DNS verification error: %v", err) + return err +} + +if result.Success { + log.Printf("Domain %s verified successfully", result.Domain) +} else { + log.Printf("Domain %s verification failed: %s", result.Domain, result.Message) +} +``` + +**Features:** +- Queries DNS TXT records for verification tokens +- Uses secure DNS resolvers (8.8.8.8, 1.1.1.1) by default +- Implements retry logic with exponential backoff +- Supports custom DNS resolver configuration +- Validates token format before verification +- Comprehensive error handling and logging + +#### VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerificationConfig) + +Performs DNS verification with custom configuration. + +```go +config := &verification.DNSVerificationConfig{ + Timeout: 5 * time.Second, + MaxRetries: 2, + RetryDelay: 1 * time.Second, + UseSecureResolvers: true, + CustomResolvers: []string{"8.8.8.8:53", "1.1.1.1:53"}, + RecordPrefix: "my-custom-prefix", // Custom prefix instead of "mcp-verify" +} + +result, err := verification.VerifyDNSRecordWithConfig("example.com", token, config) +``` + +#### DefaultDNSConfig() + +Returns the default configuration for DNS verification. + +```go +config := verification.DefaultDNSConfig() +// Returns: &DNSVerificationConfig{ +// Timeout: 10 * time.Second, +// MaxRetries: 3, +// RetryDelay: 1 * time.Second, +// UseSecureResolvers: true, +// CustomResolvers: []string{"8.8.8.8:53", "1.1.1.1:53"}, +// RecordPrefix: "mcp-verify", +// } +``` + +## Types and Structures + +### DNSVerificationResult + +```go +type DNSVerificationResult struct { + Success bool `json:"success"` + Domain string `json:"domain"` + Token string `json:"token"` + Message string `json:"message"` + TXTRecords []string `json:"txt_records,omitempty"` + Duration string `json:"duration"` +} +``` + +### DNSVerificationConfig + +```go +type DNSVerificationConfig struct { + Timeout time.Duration // Default: 10 seconds + MaxRetries int // Default: 3 + RetryDelay time.Duration // Default: 1 second + UseSecureResolvers bool // Default: true + CustomResolvers []string // Default: ["8.8.8.8:53", "1.1.1.1:53"] + RecordPrefix string // Default: "mcp-verify" +} +``` + +### DNSVerificationError + +```go +type DNSVerificationError struct { + Domain string + Token string + Message string + Cause error +} +``` + +## Security Considerations + +### Cryptographic Security +- Uses `crypto/rand` which provides cryptographically secure random numbers +- 128 bits provides 2^128 possible values (negligible collision probability) +- Suitable for cryptographic applications requiring unpredictable tokens + +### DNS Security +- Uses secure DNS resolvers (8.8.8.8, 1.1.1.1) by default to prevent DNS spoofing +- Implements retry logic for transient DNS failures +- Validates domain ownership through industry-standard DNS TXT records +- Supports DNSSEC-aware resolvers + +### Token Properties +- **Single-use**: Tokens should be used only once for verification +- **Time-limited**: Implement appropriate expiration policies +- **Secure transmission**: Always use HTTPS when transmitting tokens +- **Secure storage**: Store tokens securely on both client and server side + +## Usage Examples + +### Complete DNS Verification Workflow + +```go +package main + +import ( + "errors" + "fmt" + "log" + "time" + "github.com/modelcontextprotocol/registry/internal/verification" +) + +func verifyDomainOwnership(domain string) error { + // 1. Generate verification token + tokenInfo, err := verification.GenerateTokenWithInfo() + if err != nil { + return fmt.Errorf("failed to generate token: %w", err) + } + + // 2. Instruct user to add DNS record + fmt.Printf("Add this TXT record to %s:\n", domain) + fmt.Printf("Name: %s\n", domain) + fmt.Printf("Type: TXT\n") + fmt.Printf("Value: %s\n", tokenInfo.DNSRecord) + fmt.Println("Press Enter after adding the DNS record...") + fmt.Scanln() + + // 3. Verify the DNS record + result, err := verification.VerifyDNSRecord(domain, tokenInfo.Token) + if err != nil { + return fmt.Errorf("DNS verification failed: %w", err) + } + + if result.Success { + log.Printf("โœ… Domain %s verified successfully!", domain) + log.Printf("Verification completed in %s", result.Duration) + return nil + } else { + return fmt.Errorf("โŒ Domain verification failed: %s", result.Message) + } +} +``` + +### Custom DNS Configuration + +```go +func verifyWithCustomConfig(domain, token string) error { + config := &verification.DNSVerificationConfig{ + Timeout: 5 * time.Second, + MaxRetries: 2, + RetryDelay: 500 * time.Millisecond, + UseSecureResolvers: true, + CustomResolvers: []string{"1.1.1.1:53", "8.8.8.8:53"}, + } + + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + if err != nil { + return err + } + + log.Printf("Verification result: %+v", result) + return nil +} +``` + +### Custom Record Prefix + +You can configure a custom DNS record prefix instead of the default "mcp-verify": + +```go +func verifyWithCustomPrefix(domain, token string) error { + config := verification.DefaultDNSConfig() + config.RecordPrefix = "my-service-verify" + + // This will look for: my-service-verify= + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + if err != nil { + return err + } + + if result.Success { + log.Printf("โœ… Domain verified with custom prefix!") + } else { + log.Printf("โŒ Verification failed: %s", result.Message) + } + + return nil +} +``` + +### Error Handling and Retry Logic + +```go +func robustDNSVerification(domain, token string) error { + maxAttempts := 3 + + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Printf("DNS verification attempt %d/%d for domain %s", attempt, maxAttempts, domain) + + result, err := verification.VerifyDNSRecord(domain, token) + if err != nil { + var dnsErr *verification.DNSVerificationError + if errors.As(err, &dnsErr) { + log.Printf("DNS error: %s", dnsErr.Message) + if attempt < maxAttempts { + time.Sleep(time.Duration(attempt) * time.Second) + continue + } + } + return err + } + + if result.Success { + log.Printf("โœ… Domain verified on attempt %d", attempt) + return nil + } + + log.Printf("โŒ Verification failed: %s", result.Message) + if attempt < maxAttempts { + time.Sleep(time.Duration(attempt) * time.Second) + } + } + + return fmt.Errorf("domain verification failed after %d attempts", maxAttempts) +} +``` + +### HTTP-01 Challenge Setup +```go +func setupHTTPChallenge(domain string) error { + token, err := verification.GenerateVerificationToken() + if err != nil { + return err + } + + fmt.Printf("Serve the token at: https://%s/.well-known/mcp-challenge/%s\n", domain, token) + fmt.Printf("Content: %s\n", token) + + return nil +} +``` + +### Token String Comparison +```go +func validateUserToken(userToken, expectedToken string) bool { + // For verification, simply compare the token strings + // No format validation needed - just string comparison + return userToken == expectedToken +} +``` + +## Constants + +- `TokenLength`: 16 bytes (128 bits) - the entropy size of generated tokens + +## Error Handling + +### DNS Verification Errors + +The DNS verification functions can return various types of errors: + +- **Input validation errors**: Invalid domain or token format +- **Network errors**: DNS resolution failures, timeouts +- **Verification errors**: Token not found in DNS records + +```go +result, err := verification.VerifyDNSRecord(domain, token) +if err != nil { + var dnsErr *verification.DNSVerificationError + if errors.As(err, &dnsErr) { + log.Printf("DNS verification failed for domain %s: %s", dnsErr.Domain, dnsErr.Message) + if dnsErr.Cause != nil { + log.Printf("Underlying cause: %v", dnsErr.Cause) + } + } else { + log.Printf("Unexpected error: %v", err) + } + return err +} +``` + +### Token Generation Errors +The function returns errors in the following case: + +- `GenerateVerificationToken()`: When the system's entropy source is unavailable + +Always check for errors and handle them appropriately: + +```go +token, err := verification.GenerateVerificationToken() +if err != nil { + log.Printf("Failed to generate verification token: %v", err) + // Handle error appropriately (retry, fallback, etc.) + return err +} +``` + +## Performance + +The DNS verification system is designed for real-world performance: + +- **Token generation**: Sub-microsecond performance +- **DNS queries**: Typically 10-100ms depending on network conditions +- **Retry logic**: Exponential backoff prevents overwhelming DNS servers +- **Concurrent verification**: Safe for use in goroutines + +## Testing + +The package includes comprehensive tests covering: + +- Token generation and uniqueness +- Entropy validation (exactly 128 bits) +- Format validation +- URL and DNS safety +- DNS verification functionality +- Error handling scenarios +- Performance benchmarks + +Run tests with: +```bash +go test ./internal/verification -v +go test ./internal/verification -bench=. +``` + +## Integration + +This package is designed to integrate with the MCP Registry's domain verification system as specified in `server-name-verification.md`. It provides both token generation and DNS verification capabilities required for the dual-method verification approach. + +### Integration Points + +1. **Registry API**: Use for generating tokens when users claim domain namespaces +2. **Background verification**: Use for continuous verification of existing domains +3. **CLI tools**: Use for domain verification during package publishing +4. **Admin tools**: Use for debugging verification issues + +```` diff --git a/internal/verification/dns.go b/internal/verification/dns.go new file mode 100644 index 00000000..9dc1c2e2 --- /dev/null +++ b/internal/verification/dns.go @@ -0,0 +1,332 @@ +package verification + +import ( + "context" + "errors" + "fmt" + "log" + "net" + "strings" + "time" +) + +// DNSVerificationError represents errors that can occur during DNS verification +type DNSVerificationError struct { + Domain string + Token string + Message string + Cause error +} + +func (e *DNSVerificationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("DNS verification failed for domain %s: %s (caused by: %v)", e.Domain, e.Message, e.Cause) + } + return fmt.Sprintf("DNS verification failed for domain %s: %s", e.Domain, e.Message) +} + +func (e *DNSVerificationError) Unwrap() error { + return e.Cause +} + +// DNSVerificationResult represents the result of a DNS verification attempt +type DNSVerificationResult struct { + Success bool `json:"success"` + Domain string `json:"domain"` + Token string `json:"token"` + Message string `json:"message"` + TXTRecords []string `json:"txt_records,omitempty"` + Duration string `json:"duration"` +} + +// DNSVerificationConfig holds configuration for DNS verification +type DNSVerificationConfig struct { + // Timeout for DNS queries (default: 10 seconds) + Timeout time.Duration + + // MaxRetries for transient failures (default: 3) + MaxRetries int + + // RetryDelay base delay between retries (default: 1 second) + RetryDelay time.Duration + + // UseSecureResolvers enables use of secure DNS resolvers + UseSecureResolvers bool + + // CustomResolvers allows specifying custom DNS servers + CustomResolvers []string + + // RecordPrefix specifies the prefix for DNS TXT records (default: "mcp-verify") + RecordPrefix string + + // Resolver allows injecting a custom DNS resolver (primarily for testing) + Resolver DNSResolver +} + +// DefaultDNSConfig returns the default configuration for DNS verification +func DefaultDNSConfig() *DNSVerificationConfig { + return &DNSVerificationConfig{ + Timeout: 10 * time.Second, + MaxRetries: 3, + RetryDelay: 1 * time.Second, + UseSecureResolvers: true, + CustomResolvers: []string{"8.8.8.8:53", "1.1.1.1:53"}, // Google and Cloudflare DNS + RecordPrefix: "mcp-verify", + } +} + +// VerifyDNSRecord verifies domain ownership by checking for a specific TXT record +// containing the expected verification token. +// +// This function implements the DNS TXT record verification method described in +// the Server Name Verification system. It looks for a TXT record with the format: +// "=" where prefix defaults to "mcp-verify" +// +// Security considerations: +// - Uses secure DNS resolvers to prevent spoofing attacks +// - Implements retry logic with exponential backoff for transient failures +// - Validates token format before verification +// - Logs all verification attempts for audit purposes +// +// Parameters: +// - domain: The domain name to verify (e.g., "example.com") +// - expectedToken: The 128-bit token that should be present in the DNS record +// +// Returns: +// - DNSVerificationResult with verification status and details +// - An error if the verification process fails critically +// +// The default configuration uses "mcp-verify" as the record prefix. To use a custom +// prefix, use VerifyDNSRecordWithConfig with a configured DNSVerificationConfig. +// +// Example usage: +// +// result, err := VerifyDNSRecord("example.com", "TBeVXe_X4npM6p8vpzStnA") +// if err != nil { +// log.Printf("DNS verification error: %v", err) +// return err +// } +// if result.Success { +// log.Printf("Domain %s verified successfully", result.Domain) +// } else { +// log.Printf("Domain %s verification failed: %s", result.Domain, result.Message) +// } +func VerifyDNSRecord(domain, expectedToken string) (*DNSVerificationResult, error) { + return VerifyDNSRecordWithConfig(domain, expectedToken, DefaultDNSConfig()) +} + +// VerifyDNSRecordWithConfig performs DNS verification with custom configuration +func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) { + startTime := time.Now() + + // Input validation + if domain == "" { + return nil, &DNSVerificationError{ + Domain: domain, + Token: expectedToken, + Message: "domain cannot be empty", + } + } + + if expectedToken == "" { + return nil, &DNSVerificationError{ + Domain: domain, + Token: expectedToken, + Message: "token cannot be empty", + } + } + + // Validate token format + if !ValidateTokenFormat(expectedToken) { + return nil, &DNSVerificationError{ + Domain: domain, + Token: expectedToken, + Message: "invalid token format", + } + } + + // Normalize domain (remove trailing dots, convert to lowercase) + domain = strings.ToLower(strings.TrimSuffix(domain, ".")) + + log.Printf("Starting DNS verification for domain: %s with token: %s", domain, expectedToken) + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) + defer cancel() + + // Perform verification with retries + result, err := performDNSVerificationWithRetries(ctx, domain, expectedToken, config) + + // Calculate duration + duration := time.Since(startTime) + if result != nil { + result.Duration = duration.String() + } + + log.Printf("DNS verification completed for domain %s in %v: success=%t", + domain, duration, result != nil && result.Success) + + return result, err +} + +// performDNSVerificationWithRetries implements the retry logic for DNS verification +func performDNSVerificationWithRetries( + ctx context.Context, + domain, expectedToken string, + config *DNSVerificationConfig, +) (*DNSVerificationResult, error) { + var lastErr error + var lastResult *DNSVerificationResult + + retryDelay := config.RetryDelay + + for attempt := 0; attempt <= config.MaxRetries; attempt++ { + if attempt > 0 { + log.Printf("DNS verification retry %d/%d for domain %s after %v delay", + attempt+1, config.MaxRetries, domain, retryDelay) + + // Wait before retry with context cancellation support + timer := time.NewTimer(retryDelay) + select { + case <-timer.C: + // Timer fired normally, continue with retry + case <-ctx.Done(): + // Context canceled, stop timer to prevent leak + timer.Stop() + return nil, &DNSVerificationError{ + Domain: domain, + Token: expectedToken, + Message: "verification canceled", + Cause: ctx.Err(), + } + } + + // Exponential backoff + retryDelay *= 2 + } + + result, err := performDNSVerification(ctx, domain, expectedToken, config) + if err == nil { + return result, nil + } + + lastErr = err + lastResult = result + + // Check if error is retryable + if !IsRetryableDNSError(err) { + log.Printf("Non-retryable DNS error for domain %s: %v", domain, err) + break + } + + log.Printf("Retryable DNS error for domain %s (attempt %d/%d): %v", + domain, attempt+1, config.MaxRetries, err) + } + + // All retries exhausted + return lastResult, lastErr +} + +// performDNSVerification performs a single DNS verification attempt +func performDNSVerification(ctx context.Context, domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) { + // Get resolver (either injected or create default) + var resolver DNSResolver + if config.Resolver != nil { + resolver = config.Resolver + } else { + resolver = NewDefaultDNSResolver(config) + } + + // Query TXT records + txtRecords, err := resolver.LookupTXT(ctx, domain) + if err != nil { + dnsErr := &DNSVerificationError{ + Domain: domain, + Token: expectedToken, + Message: "failed to query DNS TXT records", + Cause: err, + } + + result := &DNSVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + Message: dnsErr.Message, + } + + return result, dnsErr + } + + log.Printf("Found %d TXT records for domain %s", len(txtRecords), domain) + + // Check for verification token + expectedRecord := fmt.Sprintf("%s=%s", config.RecordPrefix, expectedToken) + + for _, record := range txtRecords { + log.Printf("Checking TXT record: %s", record) + if record == expectedRecord { + result := &DNSVerificationResult{ + Success: true, + Domain: domain, + Token: expectedToken, + Message: "domain verification successful", + TXTRecords: txtRecords, + } + + log.Printf("DNS verification successful for domain %s", domain) + return result, nil + } + } + + // Token not found + result := &DNSVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + Message: fmt.Sprintf("verification token not found in DNS TXT records (expected: %s)", expectedRecord), + TXTRecords: txtRecords, + } + + log.Printf("DNS verification failed for domain %s: token not found", domain) + return result, nil +} + +// IsRetryableDNSError determines if a DNS error should be retried +func IsRetryableDNSError(err error) bool { + if err == nil { + return false + } + + // Use iterative approach to prevent stack overflow with deeply nested errors + const maxIterations = 100 + iterationCount := 0 + for err != nil { + // Prevent infinite loop in case of circular error chain + if iterationCount >= maxIterations { + log.Printf("Exceeded maximum error unwrapping iterations (%d); possible circular error chain", maxIterations) + return false + } + iterationCount++ + // Check for temporary network errors + var netErr *net.OpError + if errors.As(err, &netErr) { + return netErr.Temporary() + } + + // Check for context timeout (might be temporary) + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + // Check for DNS-specific temporary failures + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return dnsErr.Temporary() + } + + // Move to next error in chain + err = errors.Unwrap(err) + } + + return false +} diff --git a/internal/verification/dns_mock.go b/internal/verification/dns_mock.go new file mode 100644 index 00000000..7afbd934 --- /dev/null +++ b/internal/verification/dns_mock.go @@ -0,0 +1,99 @@ +package verification + +import ( + "context" + "fmt" + "time" +) + +// MockDNSResolver implements DNSResolver for testing +type MockDNSResolver struct { + // TXTRecords maps domain names to their TXT records + TXTRecords map[string][]string + + // Errors maps domain names to errors that should be returned + Errors map[string]error + + // Delay simulates DNS query latency + Delay time.Duration + + // CallCount tracks how many times LookupTXT was called + CallCount int + + // LastDomain tracks the last domain that was queried + LastDomain string +} + +func (m *MockDNSResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + m.CallCount++ + m.LastDomain = name + + // Simulate delay if configured + if m.Delay > 0 { + timer := time.NewTimer(m.Delay) + defer timer.Stop() + select { + case <-timer.C: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + // Return error if configured for this domain + if err, exists := m.Errors[name]; exists { + return nil, err + } + + // Return TXT records if configured + if records, exists := m.TXTRecords[name]; exists { + return records, nil + } + + // Default: return empty records (domain exists but no TXT records) + return []string{}, nil +} + +// Reset clears all state in the mock resolver +func (m *MockDNSResolver) Reset() { + m.CallCount = 0 + m.LastDomain = "" + if m.TXTRecords != nil { + for k := range m.TXTRecords { + delete(m.TXTRecords, k) + } + } + if m.Errors != nil { + for k := range m.Errors { + delete(m.Errors, k) + } + } +} + +// SetTXTRecord sets a TXT record for a domain +func (m *MockDNSResolver) SetTXTRecord(domain string, records ...string) { + if m.TXTRecords == nil { + m.TXTRecords = make(map[string][]string) + } + m.TXTRecords[domain] = records +} + +// SetError sets an error to be returned for a domain +func (m *MockDNSResolver) SetError(domain string, err error) { + if m.Errors == nil { + m.Errors = make(map[string]error) + } + m.Errors[domain] = err +} + +// SetVerificationToken is a convenience method to set up a valid verification token +func (m *MockDNSResolver) SetVerificationToken(domain, token string) { + m.SetTXTRecord(domain, fmt.Sprintf("mcp-verify=%s", token)) +} + +// NewMockDNSResolver creates a new mock DNS resolver +func NewMockDNSResolver() *MockDNSResolver { + return &MockDNSResolver{ + TXTRecords: make(map[string][]string), + Errors: make(map[string]error), + } +} diff --git a/internal/verification/dns_mock_test.go b/internal/verification/dns_mock_test.go new file mode 100644 index 00000000..e2433859 --- /dev/null +++ b/internal/verification/dns_mock_test.go @@ -0,0 +1,197 @@ +package verification_test + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/verification" +) + +const testDomain = "example.com" + +func TestVerifyDNSRecordWithMockSuccess(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + mockResolver := verification.NewMockDNSResolver() + mockResolver.SetVerificationToken(testDomain, token) + + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + config.Timeout = 1 * time.Second + + result, err := verification.VerifyDNSRecordWithConfig(testDomain, token, config) + + if err != nil { + t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) + } + + if result == nil { + t.Fatal("VerifyDNSRecord returned nil result") + } + + if !result.Success { + t.Errorf("Expected success=true, got success=%t, message=%s", result.Success, result.Message) + } + + if result.Domain != testDomain { + t.Errorf("Result domain = %s, want %s", result.Domain, testDomain) + } + + if result.Token != token { + t.Errorf("Result token = %s, want %s", result.Token, token) + } + + if mockResolver.CallCount != 1 { + t.Errorf("Expected 1 DNS call, got %d", mockResolver.CallCount) + } + + if mockResolver.LastDomain != testDomain { + t.Errorf("Expected query for %s, got %s", testDomain, mockResolver.LastDomain) + } +} + +func TestVerifyDNSRecordWithMockTokenNotFound(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + mockResolver := verification.NewMockDNSResolver() + mockResolver.SetTXTRecord(testDomain, "v=spf1 -all", "some-other-record") + + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + + result, err := verification.VerifyDNSRecordWithConfig(testDomain, token, config) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result == nil { + t.Fatal("Expected result but got nil") + } + + if result.Success { + t.Error("Expected verification to fail") + } + + if !strings.Contains(result.Message, "verification token not found") { + t.Errorf("Expected 'token not found' message, got: %s", result.Message) + } + + if len(result.TXTRecords) != 2 { + t.Errorf("Expected 2 TXT records, got %d", len(result.TXTRecords)) + } +} + +func TestVerifyDNSRecordWithMockDNSError(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + mockResolver := verification.NewMockDNSResolver() + mockResolver.SetError(testDomain, &net.DNSError{ + Err: "no such host", + Name: testDomain, + Server: "8.8.8.8:53", + IsTimeout: false, + IsTemporary: false, + }) + + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + config.MaxRetries = 0 + + result, err := verification.VerifyDNSRecordWithConfig(testDomain, token, config) + + var dnsErr *verification.DNSVerificationError + if !errors.As(err, &dnsErr) { + t.Errorf("Expected DNSVerificationError, got: %T", err) + } + + if result == nil { + t.Fatal("Expected result even on error") + } + + if result.Success { + t.Error("Expected verification to fail") + } + + if !strings.Contains(result.Message, "failed to query DNS TXT records") { + t.Errorf("Expected DNS query failure message, got: %s", result.Message) + } +} + +func TestVerifyDNSRecordWithMockTimeout(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + mockResolver := verification.NewMockDNSResolver() + mockResolver.Delay = 200 * time.Millisecond + mockResolver.SetVerificationToken(testDomain, token) + + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + config.Timeout = 50 * time.Millisecond + config.MaxRetries = 0 + + _, err = verification.VerifyDNSRecordWithConfig(testDomain, token, config) + + if err == nil { + t.Error("Expected timeout error") + } + + if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "timeout") { + t.Errorf("Expected timeout-related error, got: %v", err) + } +} + +func TestMockDNSResolverHelperMethods(t *testing.T) { + mock := verification.NewMockDNSResolver() + + token := "test-token-123" + mock.SetVerificationToken(testDomain, token) + + records, err := mock.LookupTXT(context.Background(), testDomain) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expected := fmt.Sprintf("mcp-verify=%s", token) + if len(records) != 1 || records[0] != expected { + t.Errorf("Expected [%s], got %v", expected, records) + } + + mock.CallCount = 5 + mock.LastDomain = "test.com" + mock.Reset() + + if mock.CallCount != 0 { + t.Errorf("Expected CallCount=0 after reset, got %d", mock.CallCount) + } + + if mock.LastDomain != "" { + t.Errorf("Expected LastDomain='' after reset, got %s", mock.LastDomain) + } + + records, err = mock.LookupTXT(context.Background(), testDomain) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(records) != 0 { + t.Errorf("Expected no records after reset, got %v", records) + } +} diff --git a/internal/verification/dns_resolver.go b/internal/verification/dns_resolver.go new file mode 100644 index 00000000..ae79d86c --- /dev/null +++ b/internal/verification/dns_resolver.go @@ -0,0 +1,52 @@ +package verification + +import ( + "context" + "fmt" + "net" +) + +// DNSResolver interface allows for dependency injection and testing +type DNSResolver interface { + LookupTXT(ctx context.Context, name string) ([]string, error) +} + +// DefaultDNSResolver wraps net.Resolver to implement our interface +type DefaultDNSResolver struct { + resolver *net.Resolver +} + +func (d *DefaultDNSResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + return d.resolver.LookupTXT(ctx, name) +} + +// NewDefaultDNSResolver creates a DNS resolver with the given configuration +// +//nolint:ireturn // Factory function returning interface is acceptable for dependency injection +func NewDefaultDNSResolver(config *DNSVerificationConfig) DNSResolver { + if config.UseSecureResolvers && len(config.CustomResolvers) > 0 { + // Create custom dialer for secure resolvers + dialer := &net.Dialer{ + Timeout: config.Timeout, + } + + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + // Use first available custom resolver + for _, resolverAddr := range config.CustomResolvers { + conn, err := dialer.DialContext(ctx, network, resolverAddr) + if err == nil { + return conn, nil + } + } + return nil, fmt.Errorf("all custom DNS resolvers failed") + }, + } + + return &DefaultDNSResolver{resolver: resolver} + } + + // Use system default resolver + return &DefaultDNSResolver{resolver: net.DefaultResolver} +} diff --git a/internal/verification/dns_test.go b/internal/verification/dns_test.go new file mode 100644 index 00000000..709a99fa --- /dev/null +++ b/internal/verification/dns_test.go @@ -0,0 +1,452 @@ +package verification_test + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/verification" +) + +func TestVerifyDNSRecordSuccess(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + domain := testDomain + + // Create mock resolver with the verification token + mockResolver := verification.NewMockDNSResolver() + mockResolver.SetVerificationToken(domain, token) + + // Use custom config with mock resolver + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + if err != nil { + t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) + } + + if result == nil { + t.Fatal("VerifyDNSRecord returned nil result") + } + + if !result.Success { + t.Errorf("Expected successful verification, got: %s", result.Message) + } + + if result.Domain != domain { + t.Errorf("Result domain = %s, want %s", result.Domain, domain) + } + + if result.Token != token { + t.Errorf("Result token = %s, want %s", result.Token, token) + } + + // Verify the mock was called + if mockResolver.CallCount != 1 { + t.Errorf("Expected 1 DNS call, got %d", mockResolver.CallCount) + } + + if mockResolver.LastDomain != domain { + t.Errorf("Expected DNS query for %s, got %s", domain, mockResolver.LastDomain) + } + + t.Logf("DNS verification result: %+v", result) +} + +func TestVerifyDNSRecordTokenNotFound(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + domain := testDomain + + // Create mock resolver with different TXT records (no verification token) + mockResolver := verification.NewMockDNSResolver() + mockResolver.SetTXTRecord(domain, "v=spf1 -all", "some-other-record") + + // Use custom config with mock resolver + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + if err != nil { + t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) + } + + if result == nil { + t.Fatal("VerifyDNSRecord returned nil result") + } + + if result.Success { + t.Error("Expected verification to fail when token is not found") + } + + if !strings.Contains(result.Message, "verification token not found") { + t.Errorf("Expected 'token not found' message, got: %s", result.Message) + } + + if result.Domain != domain { + t.Errorf("Result domain = %s, want %s", result.Domain, domain) + } + + if result.Token != token { + t.Errorf("Result token = %s, want %s", result.Token, token) + } + + // Verify TXT records are included in result + if len(result.TXTRecords) != 2 { + t.Errorf("Expected 2 TXT records in result, got %d", len(result.TXTRecords)) + } + + t.Logf("DNS verification result: %+v", result) +} + +func TestVerifyDNSRecordInvalidInputs(t *testing.T) { + tests := []struct { + name string + domain string + token string + expectError bool + errorContains string + }{ + { + name: "empty domain", + domain: "", + token: "validtoken123456789012", + expectError: true, + errorContains: "domain cannot be empty", + }, + { + name: "empty token", + domain: testDomain, + token: "", + expectError: true, + errorContains: "token cannot be empty", + }, + { + name: "invalid token format", + domain: testDomain, + token: "invalid-token!@#", + expectError: true, + errorContains: "invalid token format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := verification.VerifyDNSRecord(tt.domain, tt.token) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } else if !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("Error message %q does not contain %q", err.Error(), tt.errorContains) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + + t.Logf("Result: %+v, Error: %v", result, err) + }) + } +} + +func TestVerifyDNSRecordTokenFormatValidation(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + domain := testDomain + + // Create mock resolver with the verification token + mockResolver := verification.NewMockDNSResolver() + mockResolver.SetVerificationToken(domain, token) + + // Use custom config with mock resolver + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + + if err != nil { + var dnsErr *verification.DNSVerificationError + if errors.As(err, &dnsErr) { + if strings.Contains(dnsErr.Message, "invalid token format") { + t.Errorf("Unexpected token format validation error: %v", err) + } + } + } + + if result == nil { + t.Fatal("Expected result but got nil") + } + + if !result.Success { + t.Errorf("Expected successful verification, got: %s", result.Message) + } + + if result.Domain != domain { + t.Errorf("Result domain = %s, want %s", result.Domain, domain) + } + + if result.Token != token { + t.Errorf("Result token = %s, want %s", result.Token, token) + } +} + +func TestVerifyDNSRecordWithConfigTimeout(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + // Create mock resolver that simulates a timeout + mockResolver := verification.NewMockDNSResolver() + mockResolver.Delay = 200 * time.Millisecond // Longer than the config timeout + + config := &verification.DNSVerificationConfig{ + Timeout: 100 * time.Millisecond, + MaxRetries: 0, + RetryDelay: 0, + UseSecureResolvers: false, + CustomResolvers: []string{}, + Resolver: mockResolver, + } + + domain := testDomain + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + + if err == nil { + t.Error("Expected timeout error but got none") + } else { + t.Logf("DNS query failed as expected: %v", err) + // Verify it's a context timeout error + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("Expected context deadline exceeded error, got: %v", err) + } + } + + if result == nil { + t.Fatal("Expected result but got nil") + } + + if result.Duration == "" { + t.Error("Expected duration to be populated") + } + + t.Logf("Verification completed in: %s", result.Duration) +} + +func TestDefaultDNSConfig(t *testing.T) { + config := verification.DefaultDNSConfig() + + if config == nil { + t.Fatal("DefaultDNSConfig returned nil") + } + + if config.Timeout <= 0 { + t.Error("Default timeout should be positive") + } + + if config.MaxRetries < 0 { + t.Error("Default max retries should be non-negative") + } + + if config.RetryDelay <= 0 { + t.Error("Default retry delay should be positive") + } + + if !config.UseSecureResolvers { + t.Error("Default should use secure resolvers") + } + + if len(config.CustomResolvers) == 0 { + t.Error("Default should have custom resolvers configured") + } + + t.Logf("Default DNS config: %+v", config) +} + +func TestVerifyDNSRecordWithCustomPrefix(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + domain := testDomain + customPrefix := "my-custom-prefix" + + // Create mock resolver with custom prefix verification token + mockResolver := verification.NewMockDNSResolver() + customRecord := fmt.Sprintf("%s=%s", customPrefix, token) + mockResolver.SetTXTRecord(domain, customRecord) + + // Use custom config with custom record prefix + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + config.RecordPrefix = customPrefix + + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + if err != nil { + t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) + } + + if result == nil { + t.Fatal("VerifyDNSRecord returned nil result") + } + + if !result.Success { + t.Errorf("Expected successful verification with custom prefix, got: %s", result.Message) + } + + if result.Domain != domain { + t.Errorf("Result domain = %s, want %s", result.Domain, domain) + } + + if result.Token != token { + t.Errorf("Result token = %s, want %s", result.Token, token) + } + + // Verify the mock was called + if mockResolver.CallCount != 1 { + t.Errorf("Expected 1 DNS call, got %d", mockResolver.CallCount) + } + + if mockResolver.LastDomain != domain { + t.Errorf("Expected DNS query for %s, got %s", domain, mockResolver.LastDomain) + } + + t.Logf("DNS verification with custom prefix '%s' successful: %+v", customPrefix, result) +} + +func TestVerifyDNSRecordCustomPrefixFailsWithWrongRecord(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + domain := testDomain + customPrefix := "my-custom-prefix" + + // Create mock resolver with default prefix (should fail with custom prefix config) + mockResolver := verification.NewMockDNSResolver() + defaultRecord := fmt.Sprintf("mcp-verify=%s", token) + mockResolver.SetTXTRecord(domain, defaultRecord) + + // Use custom config with custom record prefix + config := verification.DefaultDNSConfig() + config.Resolver = mockResolver + config.RecordPrefix = customPrefix + + result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + if err != nil { + t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) + } + + if result == nil { + t.Fatal("VerifyDNSRecord returned nil result") + } + + if result.Success { + t.Error("Expected verification to fail when custom prefix doesn't match record") + } + + if !strings.Contains(result.Message, "verification token not found") { + t.Errorf("Expected 'token not found' message, got: %s", result.Message) + } + + t.Logf("DNS verification correctly failed with custom prefix when record has default prefix: %+v", result) +} + +func TestDNSVerificationError(t *testing.T) { + baseErr := errors.New("base network error") + dnsErr := &verification.DNSVerificationError{ + Domain: testDomain, + Token: "test-token", + Message: "DNS query failed", + Cause: baseErr, + } + + errMsg := dnsErr.Error() + if !strings.Contains(errMsg, testDomain) { + t.Errorf("Error message should contain domain: %s", errMsg) + } + + if !strings.Contains(errMsg, "DNS query failed") { + t.Errorf("Error message should contain message: %s", errMsg) + } + + if !strings.Contains(errMsg, "base network error") { + t.Errorf("Error message should contain cause: %s", errMsg) + } + + unwrapped := errors.Unwrap(dnsErr) + if !errors.Is(unwrapped, baseErr) { + t.Errorf("Unwrap should return base error, got: %v", unwrapped) + } +} + +func TestIsRetryableDNSError(t *testing.T) { + tests := []struct { + name string + err error + shouldRetry bool + }{ + { + name: "nil error", + err: nil, + shouldRetry: false, + }, + { + name: "context deadline exceeded", + err: context.DeadlineExceeded, + shouldRetry: true, + }, + { + name: "temporary DNS error", + err: &net.DNSError{Err: "server failure", IsTemporary: true}, + shouldRetry: true, + }, + { + name: "non-temporary DNS error", + err: &net.DNSError{Err: "no such host", IsTemporary: false}, + shouldRetry: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := verification.IsRetryableDNSError(tt.err) + if result != tt.shouldRetry { + t.Errorf("isRetryableDNSError(%v) = %t, want %t", tt.err, result, tt.shouldRetry) + } + }) + } +} + +func TestDNSRecordFormat(t *testing.T) { + tokenInfo, err := verification.GenerateTokenWithInfo() + if err != nil { + t.Fatalf("Failed to generate token info: %v", err) + } + + expectedFormat := "mcp-verify=" + tokenInfo.Token + if tokenInfo.DNSRecord != expectedFormat { + t.Errorf("DNS record format mismatch: got %s, want %s", tokenInfo.DNSRecord, expectedFormat) + } + + t.Logf("Expected DNS record format: %s", expectedFormat) + t.Logf("Generated DNS record format: %s", tokenInfo.DNSRecord) +} diff --git a/internal/verification/token.go b/internal/verification/token.go new file mode 100644 index 00000000..3430fe40 --- /dev/null +++ b/internal/verification/token.go @@ -0,0 +1,145 @@ +package verification + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "strings" +) + +// TokenLength defines the number of bytes for the verification token. +// 128 bits = 16 bytes provides cryptographically secure randomness +// suitable for domain ownership verification. +const TokenLength = 16 + +// TokenInfo contains a verification token and formatted strings for different verification methods +type TokenInfo struct { + // Token is the raw verification token + Token string `json:"token"` + + // DNSRecord is the formatted DNS TXT record value + DNSRecord string `json:"dns_record"` + + // HTTPPath is the formatted HTTP challenge path + HTTPPath string `json:"http_path"` +} + +// GenerateVerificationToken generates a cryptographically secure 128-bit (16 bytes) +// random token for domain ownership verification. The token is encoded using base64url +// (RFC 4648) which is both URL-safe and DNS TXT record safe. +// +// This function is designed for use in both DNS TXT record verification +// (mcp-verify=) and HTTP-01 web challenge verification +// (https://domain/.well-known/mcp-verify). +// +// Security considerations: +// - Uses crypto/rand for cryptographically secure random number generation +// - 128 bits provides 2^128 possible values, making collision probability negligible +// - Base64url encoding ensures compatibility with DNS and HTTP standards +// - Tokens should be treated as single-use and rotated regularly +// +// Returns: +// - A base64url-encoded token string suitable for verification +// - An error if the system's entropy source is unavailable +// +// Example usage: +// +// token, err := GenerateVerificationToken() +// if err != nil { +// return fmt.Errorf("failed to generate verification token: %w", err) +// } +// // Use token in DNS: mcp-verify= +// // Or HTTP: /.well-known/mcp-verify +func GenerateVerificationToken() (string, error) { + // Allocate byte slice for random data + randomBytes := make([]byte, TokenLength) + + // Generate cryptographically secure random bytes + // crypto/rand.Read uses the operating system's entropy source + _, err := rand.Read(randomBytes) + if err != nil { + return "", fmt.Errorf("failed to generate cryptographically secure random bytes: %w", err) + } + + // Encode using base64url (RFC 4648) for URL and DNS safety + // base64url encoding is URL-safe and doesn't contain characters + // that would be problematic in DNS TXT records or HTTP URLs + token := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(randomBytes) + + return token, nil +} + +// GenerateTokenWithInfo generates a verification token with additional metadata +// about how to use it for different verification methods. +// +// This function generates a token and returns it along with pre-formatted +// strings for DNS TXT records and HTTP challenge paths, making it easier +// for callers to implement verification workflows. +// +// Returns: +// - TokenInfo struct containing the token and formatted verification strings +// - An error if token generation fails +// +// Example usage: +// +// tokenInfo, err := GenerateTokenWithInfo() +// if err != nil { +// return fmt.Errorf("failed to generate token info: %w", err) +// } +// +// fmt.Printf("Add this DNS record: %s\n", tokenInfo.DNSRecord) +// fmt.Printf("Or serve content at: %s\n", tokenInfo.HTTPPath) +func GenerateTokenWithInfo() (*TokenInfo, error) { + token, err := GenerateVerificationToken() + if err != nil { + return nil, err + } + + return &TokenInfo{ + Token: token, + DNSRecord: fmt.Sprintf("mcp-verify=%s", token), + HTTPPath: fmt.Sprintf("/.well-known/mcp-challenge/%s", token), + }, nil +} + +// ValidateTokenFormat validates that a token string follows the expected format +// for MCP verification tokens (base64url encoding, no padding, 22 characters). +// +// This function verifies that: +// - Token is exactly 22 characters long (base64url encoding of 16 bytes) +// - Token contains only valid base64url characters (A-Z, a-z, 0-9, -, _) +// - Token contains no padding characters (=) +// +// Parameters: +// - token: The token string to validate +// +// Returns: +// - true if the token format is valid, false otherwise +func ValidateTokenFormat(token string) bool { + // Check length (22 characters for base64url encoding of 16 bytes) + if len(token) != 22 { + return false + } + + // Check for padding (shouldn't be present in base64url) + if strings.Contains(token, "=") { + return false + } + + // Check that all characters are valid base64url characters + for _, char := range token { + if !isValidBase64URLChar(char) { + return false + } + } + + return true +} + +// isValidBase64URLChar checks if a character is valid for base64url encoding +func isValidBase64URLChar(char rune) bool { + return (char >= 'A' && char <= 'Z') || + (char >= 'a' && char <= 'z') || + (char >= '0' && char <= '9') || + char == '-' || char == '_' +} diff --git a/internal/verification/token_test.go b/internal/verification/token_test.go new file mode 100644 index 00000000..bc466c66 --- /dev/null +++ b/internal/verification/token_test.go @@ -0,0 +1,289 @@ +package verification_test + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/modelcontextprotocol/registry/internal/verification" +) + +const ( + errMsgGenTokenIteration = "GenerateVerificationToken() error = %v, iteration %d" + errMsgGenToken = "GenerateVerificationToken() error = %v" + errMsgGenTokenNormal = "GenerateVerificationToken() should succeed in normal conditions: %v" + dnsRecordPrefix = "mcp-verify=" +) + +func TestGenerateVerificationToken(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Errorf("GenerateVerificationToken() error = %v, want nil", err) + return + } + + // Test token is not empty + if token == "" { + t.Error("GenerateVerificationToken() returned empty token") + } + + // Test token length (should be 22 characters for base64url encoding of 16 bytes) + expectedLength := 22 + if len(token) != expectedLength { + t.Errorf("GenerateVerificationToken() token length = %d, want %d", len(token), expectedLength) + } + + // Test token contains only base64url characters + for _, char := range token { + if !isValidBase64URLChar(char) { + t.Errorf("GenerateVerificationToken() token contains invalid character: %c", char) + } + } + + // Test token doesn't contain padding + if strings.Contains(token, "=") { + t.Error("GenerateVerificationToken() token should not contain padding") + } +} + +// isValidBase64URLChar checks if a character is valid for base64url encoding +func isValidBase64URLChar(char rune) bool { + return (char >= 'A' && char <= 'Z') || + (char >= 'a' && char <= 'z') || + (char >= '0' && char <= '9') || + char == '-' || char == '_' +} + +func TestGenerateVerificationTokenUniqueness(t *testing.T) { + // Generate multiple tokens and ensure they're unique + tokenCount := 1000 + tokens := make(map[string]bool) + + for i := 0; i < tokenCount; i++ { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenIteration, err, i) + } + + if tokens[token] { + t.Errorf("GenerateVerificationToken() generated duplicate token: %s", token) + } + tokens[token] = true + } + + if len(tokens) != tokenCount { + t.Errorf("Expected %d unique tokens, got %d", tokenCount, len(tokens)) + } +} + +func TestGenerateVerificationTokenEntropy(t *testing.T) { + // Test that generated tokens have exactly 128 bits (16 bytes) of entropy + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenToken, err) + } + + // Decode the base64url token to verify byte length + decoded, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(token) + if err != nil { + t.Fatalf("Failed to decode token: %v", err) + } + + expectedBytes := 16 + if len(decoded) != expectedBytes { + t.Errorf("Token entropy = %d bytes, want %d bytes", len(decoded), expectedBytes) + } +} + +func TestGenerateVerificationTokenErrorHandling(t *testing.T) { + // This test verifies that the function properly wraps errors from crypto/rand + // We can't easily mock crypto/rand.Read without causing fatal errors, + // so we test the error wrapping behavior indirectly + + // Test with valid input to ensure normal operation + token, err := verification.GenerateVerificationToken() + if err != nil { + // If this fails in a normal environment, there's likely a real issue + t.Errorf(errMsgGenTokenNormal, err) + } + + if token == "" { + t.Error("GenerateVerificationToken() should return non-empty token") + } + + // The error handling is tested by the fact that our function + // properly declares error returns and wraps rand.Read errors + // This is validated by the successful compilation and the above test +} + +func TestTokenConstants(t *testing.T) { + // Test that TokenLength is exactly 16 bytes (128 bits) + expectedLength := 16 + if verification.TokenLength != expectedLength { + t.Errorf("TokenLength = %d, want %d (128 bits)", verification.TokenLength, expectedLength) + } +} + +func TestTokenURLSafety(t *testing.T) { + // Generate multiple tokens and ensure they're URL-safe + for i := 0; i < 100; i++ { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenIteration, err, i) + } + + // Check that token doesn't contain URL-unsafe characters + unsafeChars := []string{"+", "/", "=", " ", "%", "&", "?", "#"} + for _, unsafe := range unsafeChars { + if strings.Contains(token, unsafe) { + t.Errorf("Token contains URL-unsafe character '%s': %s", unsafe, token) + } + } + } +} + +func TestTokenDNSSafety(t *testing.T) { + // Generate multiple tokens and ensure they're DNS TXT record safe + for i := 0; i < 100; i++ { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenIteration, err, i) + } + + // Check that token doesn't contain DNS-problematic characters + // DNS TXT records generally support alphanumeric and some symbols + unsafeChars := []string{" ", "\"", "\\", "\n", "\r", "\t"} + for _, unsafe := range unsafeChars { + if strings.Contains(token, unsafe) { + t.Errorf("Token contains DNS-unsafe character '%s': %s", unsafe, token) + } + } + + // Test full DNS record format + dnsRecord := dnsRecordPrefix + token + MaxDNSRecordLength := 255 + if len(dnsRecord) > MaxDNSRecordLength { + t.Errorf("DNS record too long (%d chars): %s", len(dnsRecord), dnsRecord) + } + } +} + +func TestDNSTXTRecordRFCCompliance(t *testing.T) { + // Test DNS TXT record format compliance according to RFC 1035 and RFC 1464 + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenToken, err) + } + + dnsRecord := dnsRecordPrefix + token + + // RFC 1035: DNS names and TXT records have specific length limitations + // TXT record data must not exceed 255 octets per string + if len(dnsRecord) > 255 { + t.Errorf("DNS TXT record exceeds 255 character limit: %d chars", len(dnsRecord)) + } + + // RFC 1464: TXT records should follow attribute=value format + if !strings.Contains(dnsRecord, "=") { + t.Error("DNS TXT record missing required '=' separator") + } + + parts := strings.SplitN(dnsRecord, "=", 2) + if len(parts) != 2 { + t.Error("DNS TXT record should have exactly one '=' separator") + } + + attribute := parts[0] + value := parts[1] + + // Validate attribute name (should be "mcp-verify") + expectedAttribute := strings.TrimSuffix(dnsRecordPrefix, "=") + if attribute != expectedAttribute { + t.Errorf("DNS TXT record attribute = %s, want %s", attribute, expectedAttribute) + } + + // Validate that value is our token + if value != token { + t.Errorf("DNS TXT record value = %s, want %s", value, token) + } + + // Test that the record contains only ASCII printable characters (RFC compliant) + for i, char := range dnsRecord { + if char < 32 || char > 126 { + t.Errorf("DNS TXT record contains non-ASCII printable character at position %d: %c (code %d)", i, char, char) + } + } +} + +func TestDNSTXTRecordSpecialCharacters(t *testing.T) { + // Test that DNS records handle RFC-compliant special characters correctly + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenToken, err) + } + + dnsRecord := dnsRecordPrefix + token + + // Characters that should NOT appear in our DNS records + prohibitedChars := []rune{ + 0, // NULL + 9, // TAB + 10, // LF + 13, // CR + 34, // Double quote + 92, // Backslash + 127, // DEL + } + + for _, prohibited := range prohibitedChars { + if strings.ContainsRune(dnsRecord, prohibited) { + t.Errorf("DNS record contains prohibited character: %c (code %d)", prohibited, prohibited) + } + } + + // Characters that SHOULD be allowed (base64url safe) + allowedChars := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_=" + for _, char := range dnsRecord { + if !strings.ContainsRune(allowedChars, char) { + t.Errorf("DNS record contains unexpected character: %c (code %d)", char, char) + } + } +} + +func TestDNSTXTRecordLength(t *testing.T) { + // Test DNS TXT record length constraints + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenToken, err) + } + + dnsRecord := dnsRecordPrefix + token + + // RFC 1035: TXT record strings are limited to 255 octets + maxTXTRecordLength := 255 + if len(dnsRecord) > maxTXTRecordLength { + t.Errorf("DNS TXT record length %d exceeds RFC limit of %d", len(dnsRecord), maxTXTRecordLength) + } + + // Calculate expected length: "mcp-verify=" (11 chars) + token (22 chars) = 33 chars + expectedLength := 11 + 22 // len("mcp-verify=") + token length + if len(dnsRecord) != expectedLength { + t.Errorf("DNS TXT record length %d, expected %d", len(dnsRecord), expectedLength) + } + + // Ensure we have reasonable margin below the limit + marginRequired := 50 // Leave room for future changes + if len(dnsRecord) > (maxTXTRecordLength - marginRequired) { + t.Errorf("DNS TXT record length %d too close to limit, needs %d char margin", len(dnsRecord), marginRequired) + } +} + +// Benchmark tests for performance +func BenchmarkGenerateVerificationToken(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := verification.GenerateVerificationToken() + if err != nil { + b.Fatalf(errMsgGenToken, err) + } + } +} diff --git a/server-name-verification.md b/server-name-verification.md new file mode 100644 index 00000000..9a64f07b --- /dev/null +++ b/server-name-verification.md @@ -0,0 +1,227 @@ +# Server Name Verification for MCP Metaregistry + +## Context and Problem Statement + +The MCP Metaregistry will allow MCP server publishers to use domain-scoped namespaces for their server entries (e.g. +`com.example/my-server`). We need a reliable way to ensure that a user claiming a domain-based namespace actually +owns (or is authorized to use) that domain. In other words, if someone publishes a server under `com.github/*`, we +must verify they control `github.com` to prevent impersonation or squatting. The solution should be secure, align +with industry best practices, and manageable long-term. + +## Decision Drivers + +- __Security and Authenticity:__ Only legitimate domain owners should be able to publish under that domain's namespace. + This prevents malicious actors from impersonating popular organizations. + +- __Industry Best Practice:__ Favor solutions known to be secure and commonly used for domain ownership proof + (minimize inventing new untested methods). + +- __Usability for Publishers:__ The verification process should be straightforward for developers and not require + excessive infrastructure (e.g., should work even if the domain isn't running a web server). + +- __Continuous Trust:__ The mechanism should not only verify ownership once, but also detect if ownership changes (e.g., + domain expires or is sold) and revoke publishing rights if necessary to protect integrity. + +- __Minimal External Dependencies:__ Rely primarily on the domain's DNS itself (which the owner already controls) + rather than third-party services, for simplicity and longevity. + +- __Organizational Use:__ Enable both individual users and organizations to verify domains (so that a team/org can + publish under a corporate domain once verified). + +- __Operational Maintainability:__ The solution should be possible to automate (for verification and periodic + re-checks) and monitor, with clear failure modes and recovery procedures. + +## Considered Options + +### Option 1: DNS TXT Record Verification + +The user adds a specified TXT record to their domain's DNS zone to prove control. This is a widely adopted method +(used by certificate authorities, cloud services, etc.) for domain ownership verification. + +#### Pros + +- __Highly secure__ requires direct access to domain DNS settings. +- __Independent__ of any web server, or HTTP content. +- __Industry-standard__ practice. +- DNS is __ubiquitous__. +- __Automate-able__ our service can query DNS anytime to validate. Continuous or repeated checks are straightforward + to implement by re-querying the DNS record. + +#### Cons + +- Requires the publisher to access and modify their DNS configuration, which may be non-intuitive for some users. +- DNS changes are __not instantaneous__. Propagation can take time (often minutes, sometimes hours), which could + delay verification. +- Without precautions, DNS lookups could be spoofed by an attacker (DNS poisoning) if not using secure resolvers or + DNSSEC. +- Keeping the TXT record in DNS long-term (for continuous verification) slightly "clutters" the domain's DNS zone, + though using a prefixed record minimizes any impact. + +### Option 2: HTTP-01 Web Challenge + +Provide a token that the user must serve via HTTP on a known URL (for example, hosting a file or response at +`http:///.well-known/mcp-verification/`). This approach is used by ACME (Let's Encrypt) for domain +validation. + +#### Pros + +- Fairly simple if the domain already hosts a website, the owner just drops a file or configures a response. +- Many developers are familiar with this from SSL certificate issuance. +- It doesn't require messing with DNS directly. +- Easy to automate if you have a webserver. +- Works with standard HTTP infrastructure. + +#### Cons + +- Not viable for domains that don't run an HTTP server or are not easily accessible on the internet. +- It fails for domains behind certain network restrictions (e.g., if port 80 is closed or filtered). +- Continuous monitoring would be complex. The registry would have to periodically re-fetch a URL and differentiate + between a temporarily down server vs. lost ownership. +- Introduces more points of failure (web hosting, redirects, etc.), whereas DNS is a more direct indicator of ownership. + +### Option 3: DNS CAA or Certificate-Based Methods + +Leverage the Certificate Authority Authorization (CAA) DNS record or possession of an SSL/TLS certificate for the +domain as proof. CAA records specify which CAs can issue certificates for a domain (ussed in SSL issuance control). +We could require a special CAA record or similar DNS record to prove ownership. + +#### Pros + +- If a domain owner can obtain a valid SSL certificate (which itself requires domain verification) or set a CAA, it + indirectly shows domain control. +- CAA is a DNS record, so it could be used in a similar way to TXT. +- Automatically checked by CAs, so some security-conscious domains already use it to restrict certificate issuance. + +#### Cons + +- CAA is not designed for arbitrary token storage or service-specific challenges. It only encodes which CAs are + allowed. Trying to repurpose it for our verification would be an abuse of its intended purpose and could conflict + with actual CAA usage. +- Not all domain owners set CAA, and those who do use it for security policies might be unwilling to change it for this. +- Using possession of an SSL certificate as proof is also problematic. It adds an extra step (obtaining a cert) and + still ultimately relies on DNS or email validation in the certificate process. + +### Option 4: OAuth-Based Domain Linking (e.g. via GitHub) + +Use a trusted third-party platform to vouch for domain ownership. For example, GitHub organizations allow domain +verification (with DNS) to display a "Verified" badge. We could accept a link between the user's GitHub account/org +and a domain as evidence, or use an OAuth flow with a provider that has the domain in email. + +#### Pros + +- In cases where the publisher is a company with an existing verified GitHub organization, this could save a step. + They may have already proven domain ownership to GitHub. Using an OAuth link or API, we might trust that + verification instead of asking for another DNS record. +- Offloads the verification to a known platform and might simplify the process for some users (no need to handle DNS + if already done elsewhere). + +#### Cons + +- Only works for a subset of users (e.g., those using GitHub and having verified domains there). +- It introduces an external dependency and potential single point of failure. +- Indirectly still using DNS for verification but are one step removed from the source of truth. + +### Option 5: Email Confirmation to Domain Admin + +Send a verification code via email to an address at the target domain (commonly used addresses like `admin@domain.com` +or WHOIS contact). This method is sometimes offered by CAs for domain validation. + +#### Pros + +- Does not require DNS changes or hosting files. +- If the domain owner actively uses an administrative email, it's a direct way to reach them. +- It could be automated by sending an email and awaiting a confirmation link or code input. + +#### Cons + +- Assumes standard email aliases (`admin@`, `webmaster@`, etc.) or accurate WHOIS contact emails, which may not + exist or be monitored. +- Automating the ingestion of the confirmation is less straightforward (it may require a manual step by the user to + click a link or paste a code). +- Operationally, running an email service and handling bounces/non-delivery adds complexity. +- Many domains have WHOIS privacy. + +## Proposed Solution + +We support two __complementary__ ownership-verification mechanisms: + +| Method | Best for publishers who... | Key strengths | Key limitations | +|:----------------------|:-----------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------| +| DNS TXT Record | Can edit DNS at their registrar / DNS host | - Industry-standard proof of control
- Works even if the domain has no web server
- Easy to re-check automatically | - Requires DNS access
- Propagation delay | +| HTTP-01 Web Challenge | Already run a web site / can deploy a static file but cannot touch DNS | - No registrar access needed (just drop a file or route)
- Familiar to developers from Let's Encrypt
- Near-instant verification (no DNS caching) | - Fails if the domain has no publicly reachable HTTP(S) service
- Adds a second control plane (web hosting) that must stay available | + +### Why offer both options? + +- __Wider coverage = better UX__\ + Between DNS and HTTP we cover virtually all real-world setups. Examples: a SaaS team on a locked-down corporate + DNS can still verify via HTTP; a bare domain that hosts no site can verify by TXT. + +- __Failsafe resilience__\ + If one control plain is down (DNS outage or web migration), the other can still validate (publish pipelines keep + moving). + +- __Layered security__\ + For maintainers who enable _both_ methods of verification, an attacker must compromise both DNS and web origin to + hijack the namespace. + +- __Consistent automation model__\ + Both rely on 128-bit random tokens and can be re-checked on every publish plus a nightly cron, so continuous trust + is preserved. + +### How it works + +1. __Token issuance:__ When a publisher first claims a custom domain namespace the registry generates a 128-bit + random token. + +2. __Prove control via either path:__ + - __DNS path:__ Add TXT record `mcp-verify=` to DNS. + - __HTTP path:__ Host a plain-text file whose body is the token at `https:///. + well-known/mcp-challenge/`. + +3. __Automated check:__ The CLI/server polls DNS or fetches the well-known URL; success in __either__ path marks the + domain verified for that user or organization. + +4. __Continuous verification:__ To guard against later ownership changes, the registry re-checks __both__ indicators: + - __Every publish__ immediately queries DNS and/or fetches the well-known file; publishing is allowed if at least + one token still matches. + - __Background job (run on a regular cadence)__ re-checks every verified domain using both DNS and HTTP tokens. The + job will apply a failure-tolerance policy. For example, if a domain fails the check three times in a row, it is + marked unverified and new publishes are blocked. After the second consecutive failure, maintainers receive a + warning; if the check fails a third time, they are notified again as the domain status is downgraded. This guards + against transient outages while still revoking trust when ownership indicators consistently disappear. + +This dual mechanism provides layered security, DNS is the gold-standard signal, while HTTP-01 offers a low-friction +alternative for teams that cannot touch DNS. Together they: + +- Cover nearly every hosting scenario (DNS-only, web-only, or both). + +- Let maintainers migrate from one method to the other without renaming packages. + +- Add resilience: if a DNS provider or web host is temporarily down, the other path still validates, keeping CI/CD + pipelines unblocked. + +By combining DNS and HTTP verification, and by continuously validating whichever token(s) are configured, the MCP +Metaregistry delivers high assurance of domain ownership while remaining flexible and developer-friendly. + +### Positive Consequences + +- High-confidence ownership with flexibility. DNS remains the gold-standard; HTTP-01 offers a low-friction + alternative when DNS edits are impossible. + +- Reduced onboarding friction. Developers pick the path of least resistance; fewer support tickets. + +- Operational robustness. Dual-path verification means fewer false blocks during provider outages. + +- Organizational friendliness. Either method can be performed once by an infra team and thereafter reused by all + org members. + +### Negative Consequences + +- Slightly more code and monitoring. We must implement and observe two verification paths instead of one, and store + two tokens per domain. + +- Extra edge-cases. Need clear rules for what happens if DNS passes but HTTP fails (and vice-versa). Policy: allow + publish if any passes; flag if both fail. + +- Web-server dependency for HTTP-01. Projects choosing only HTTP must keep the well-known file reachable; transient + 5xx outages could momentarily block publishes. Continuous checks mitigate but do not eliminate this risk. diff --git a/tests/database/validate-tokens.go b/tests/database/validate-tokens.go new file mode 100644 index 00000000..127a36e8 --- /dev/null +++ b/tests/database/validate-tokens.go @@ -0,0 +1,121 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/model" +) + +// This is a manual validation script to test MongoDB unique constraints +// Run this against a MongoDB instance to validate the unique index behavior +func main() { + fmt.Println("MongoDB Token Uniqueness Validation") + fmt.Println("===================================") + + // You can modify this connection string for your MongoDB instance + connectionURI := "mongodb://localhost:27017" + if connectionURI == "mongodb://localhost:27017" { + fmt.Println("โš ๏ธ Using default MongoDB URI. Update if needed.") + fmt.Println(" To test with a different MongoDB, update the connectionURI variable.") + fmt.Println() + } + + ctx := context.Background() + + // Connect to MongoDB + fmt.Println("๐Ÿ”Œ Connecting to MongoDB...") + db, err := database.NewMongoDB(ctx, connectionURI, "test_token_validation", "servers", "verification") + if err != nil { + log.Fatalf("Failed to connect to MongoDB: %v", err) + } + defer db.Close() + + fmt.Println("โœ… Connected successfully!") + + // Clean up any existing test data + fmt.Println("๐Ÿงน Cleaning up existing test data...") + // Note: In a real validation, you might want to preserve existing data + // This is just for testing purposes + + // Test 1: Basic token uniqueness + fmt.Println("\n๐Ÿ“ Test 1: Basic Token Uniqueness") + fmt.Println("----------------------------------") + + testToken := &model.VerificationToken{ + Token: "validation_token_" + fmt.Sprintf("%d", time.Now().Unix()), + CreatedAt: time.Now(), + } + + // Store token for first domain + fmt.Printf(" Storing token '%s' for domain1.test...\n", testToken.Token) + err = db.StoreVerificationToken(ctx, "domain1.test", testToken) + if err != nil { + fmt.Printf(" โŒ Unexpected error: %v\n", err) + } else { + fmt.Println(" โœ… First storage succeeded") + } + + // Try to store same token for different domain + fmt.Printf(" Storing same token '%s' for domain2.test...\n", testToken.Token) + err = db.StoreVerificationToken(ctx, "domain2.test", testToken) + if err != nil { + if database.ErrTokenAlreadyExists.Error() == err.Error() { + fmt.Println(" โœ… Correctly rejected duplicate token") + } else { + fmt.Printf(" โš ๏ธ Got error but not the expected one: %v\n", err) + } + } else { + fmt.Println(" โŒ Should have rejected duplicate token!") + } + + // Test 2: Different tokens should work + fmt.Println("\n๐Ÿ“ Test 2: Different Tokens Should Work") + fmt.Println("---------------------------------------") + + differentToken := &model.VerificationToken{ + Token: "different_token_" + fmt.Sprintf("%d", time.Now().Unix()), + CreatedAt: time.Now(), + } + + fmt.Printf(" Storing different token '%s' for domain2.test...\n", differentToken.Token) + err = db.StoreVerificationToken(ctx, "domain2.test", differentToken) + if err != nil { + fmt.Printf(" โŒ Unexpected error: %v\n", err) + } else { + fmt.Println(" โœ… Different token stored successfully") + } + + // Test 3: Verify stored tokens + fmt.Println("\n๏ฟฝ๏ฟฝ Test 3: Verify Stored Tokens") + fmt.Println("-------------------------------") + + tokens1, err := db.GetVerificationTokens(ctx, "domain1.test") + if err != nil { + fmt.Printf(" โŒ Error retrieving tokens for domain1: %v\n", err) + } else { + fmt.Printf(" Domain1 has %d pending token(s)\n", len(tokens1.PendingTokens)) + if len(tokens1.PendingTokens) > 0 { + fmt.Printf(" First token: %s\n", tokens1.PendingTokens[0].Token) + } + } + + tokens2, err := db.GetVerificationTokens(ctx, "domain2.test") + if err != nil { + fmt.Printf(" โŒ Error retrieving tokens for domain2: %v\n", err) + } else { + fmt.Printf(" Domain2 has %d pending token(s)\n", len(tokens2.PendingTokens)) + if len(tokens2.PendingTokens) > 0 { + fmt.Printf(" First token: %s\n", tokens2.PendingTokens[0].Token) + } + } + + fmt.Println("\n๐ŸŽ‰ Validation complete!") + fmt.Println("\n๐Ÿ’ก Tips:") + fmt.Println(" - If Test 1 shows duplicate rejection, MongoDB unique indexes are working") + fmt.Println(" - If Test 2 succeeds, different tokens are allowed") + fmt.Println(" - If Test 3 shows correct token counts, storage is working properly") +}