Skip to content

Commit 28244e6

Browse files
committed
Move validators to a single package
Signed-off-by: Radoslav Dimitrov <[email protected]>
1 parent b24fa47 commit 28244e6

File tree

8 files changed

+617
-635
lines changed

8 files changed

+617
-635
lines changed

internal/api/handlers/v0/edit.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/modelcontextprotocol/registry/internal/validators"
1616
apiv1 "github.com/modelcontextprotocol/registry/pkg/api/v1"
1717
"github.com/modelcontextprotocol/registry/pkg/model"
18-
"github.com/modelcontextprotocol/registry/pkg/validation"
1918
)
2019

2120
// EditServerInput represents the input for editing a server
@@ -56,7 +55,7 @@ func RegisterEditEndpoints(api huma.API, registry service.RegistryService, cfg *
5655
}
5756

5857
// Validate that only allowed extension fields are present
59-
if err := validation.ValidatePublishRequestExtensions(input.RawBody); err != nil {
58+
if err := validators.ValidatePublishRequestExtensions(input.RawBody); err != nil {
6059
return nil, huma.Error400BadRequest("Invalid request format", err)
6160
}
6261

internal/api/handlers/v0/publish.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"github.com/modelcontextprotocol/registry/internal/service"
1313
"github.com/modelcontextprotocol/registry/internal/validators"
1414
apiv1 "github.com/modelcontextprotocol/registry/pkg/api/v1"
15-
"github.com/modelcontextprotocol/registry/pkg/validation"
1615
)
1716

1817
// PublishServerInput represents the input for publishing a server
@@ -52,7 +51,7 @@ func RegisterPublishEndpoint(api huma.API, registry service.RegistryService, cfg
5251
}
5352

5453
// Validate that only allowed extension fields are present
55-
if err := validation.ValidatePublishRequestExtensions(input.RawBody); err != nil {
54+
if err := validators.ValidatePublishRequestExtensions(input.RawBody); err != nil {
5655
return nil, huma.Error400BadRequest("Invalid request format", err)
5756
}
5857

internal/service/fake_service.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import (
66

77
"github.com/google/uuid"
88
"github.com/modelcontextprotocol/registry/internal/database"
9+
"github.com/modelcontextprotocol/registry/internal/validators"
910
apiv1 "github.com/modelcontextprotocol/registry/pkg/api/v1"
1011
"github.com/modelcontextprotocol/registry/pkg/model"
11-
"github.com/modelcontextprotocol/registry/pkg/validation"
1212
)
1313

1414
// fakeRegistryService implements RegistryService interface with an in-memory database
@@ -114,17 +114,17 @@ func (s *fakeRegistryService) Publish(req apiv1.PublishRequest) (*apiv1.ServerRe
114114
defer cancel()
115115

116116
// Validate the request
117-
if err := validation.ValidatePublisherExtensions(req); err != nil {
117+
if err := validators.ValidatePublisherExtensions(req); err != nil {
118118
return nil, err
119119
}
120120

121121
// Validate server name exists
122-
if _, err := validation.ParseServerName(req.Server); err != nil {
122+
if _, err := validators.ParseServerName(req.Server); err != nil {
123123
return nil, err
124124
}
125125

126126
// Extract publisher extensions from request
127-
publisherExtensions := validation.ExtractPublisherExtensions(req)
127+
publisherExtensions := validators.ExtractPublisherExtensions(req)
128128

129129
// Create registry metadata for fake service (always marks as latest)
130130
now := time.Now()
@@ -152,17 +152,17 @@ func (s *fakeRegistryService) EditServer(id string, req apiv1.PublishRequest) (*
152152
defer cancel()
153153

154154
// Validate the request
155-
if err := validation.ValidatePublisherExtensions(req); err != nil {
155+
if err := validators.ValidatePublisherExtensions(req); err != nil {
156156
return nil, err
157157
}
158158

159159
// Validate server name exists and format
160-
if _, err := validation.ParseServerName(req.Server); err != nil {
160+
if _, err := validators.ParseServerName(req.Server); err != nil {
161161
return nil, err
162162
}
163163

164164
// Extract publisher extensions from request
165-
publisherExtensions := validation.ExtractPublisherExtensions(req)
165+
publisherExtensions := validators.ExtractPublisherExtensions(req)
166166

167167
// Update server in database
168168
serverRecord, err := s.db.UpdateServer(ctx, id, req.Server, publisherExtensions)

internal/service/registry_service.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import (
1010

1111
"github.com/google/uuid"
1212
"github.com/modelcontextprotocol/registry/internal/database"
13+
"github.com/modelcontextprotocol/registry/internal/validators"
1314
apiv1 "github.com/modelcontextprotocol/registry/pkg/api/v1"
1415
"github.com/modelcontextprotocol/registry/pkg/model"
15-
"github.com/modelcontextprotocol/registry/pkg/validation"
1616
)
1717

1818
const maxServerVersionsPerServer = 10000
@@ -133,12 +133,12 @@ func (s *registryServiceImpl) Publish(req apiv1.PublishRequest) (*apiv1.ServerRe
133133
defer cancel()
134134

135135
// Validate the request
136-
if err := validation.ValidatePublisherExtensions(req); err != nil {
136+
if err := validators.ValidatePublisherExtensions(req); err != nil {
137137
return nil, err
138138
}
139139

140140
// Validate server name exists and format
141-
if _, err := validation.ParseServerName(req.Server); err != nil {
141+
if _, err := validators.ParseServerName(req.Server); err != nil {
142142
return nil, err
143143
}
144144

@@ -150,7 +150,7 @@ func (s *registryServiceImpl) Publish(req apiv1.PublishRequest) (*apiv1.ServerRe
150150
}
151151

152152
// Validate reverse-DNS namespace matching for remote URLs
153-
if err := validation.ValidateRemoteNamespaceMatch(req.Server); err != nil {
153+
if err := validators.ValidateRemoteNamespaceMatch(req.Server); err != nil {
154154
return nil, err
155155
}
156156

@@ -201,7 +201,7 @@ func (s *registryServiceImpl) Publish(req apiv1.PublishRequest) (*apiv1.ServerRe
201201
}
202202

203203
// Extract publisher extensions from request
204-
publisherExtensions := validation.ExtractPublisherExtensions(req)
204+
publisherExtensions := validators.ExtractPublisherExtensions(req)
205205

206206
// Create registry metadata with service-determined values
207207
registryMetadata := apiv1.RegistryExtensions{
@@ -261,12 +261,12 @@ func (s *registryServiceImpl) EditServer(id string, req apiv1.PublishRequest) (*
261261
defer cancel()
262262

263263
// Validate the request
264-
if err := validation.ValidatePublisherExtensions(req); err != nil {
264+
if err := validators.ValidatePublisherExtensions(req); err != nil {
265265
return nil, err
266266
}
267267

268268
// Validate server name exists and format
269-
if _, err := validation.ParseServerName(req.Server); err != nil {
269+
if _, err := validators.ParseServerName(req.Server); err != nil {
270270
return nil, err
271271
}
272272

@@ -278,7 +278,7 @@ func (s *registryServiceImpl) EditServer(id string, req apiv1.PublishRequest) (*
278278
}
279279

280280
// Validate reverse-DNS namespace matching for remote URLs
281-
if err := validation.ValidateRemoteNamespaceMatch(req.Server); err != nil {
281+
if err := validators.ValidateRemoteNamespaceMatch(req.Server); err != nil {
282282
return nil, err
283283
}
284284

@@ -288,7 +288,7 @@ func (s *registryServiceImpl) EditServer(id string, req apiv1.PublishRequest) (*
288288
}
289289

290290
// Extract publisher extensions from request
291-
publisherExtensions := validation.ExtractPublisherExtensions(req)
291+
publisherExtensions := validators.ExtractPublisherExtensions(req)
292292

293293
// Update server in database
294294
serverRecord, err := s.db.UpdateServer(ctx, id, req.Server, publisherExtensions)

internal/validators/validators.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package validators
22

33
import (
4+
"encoding/json"
45
"fmt"
6+
"net/url"
7+
"slices"
8+
"strings"
59

10+
apiv1 "github.com/modelcontextprotocol/registry/pkg/api/v1"
611
"github.com/modelcontextprotocol/registry/pkg/model"
712
)
813

@@ -122,3 +127,163 @@ func (ov *ObjectValidator) Validate(obj *model.ServerJSON) error {
122127
}
123128
return nil
124129
}
130+
131+
// ValidatePublisherExtensions validates that publisher extensions are within size limits
132+
func ValidatePublisherExtensions(req apiv1.PublishRequest) error {
133+
const maxExtensionSize = 4 * 1024 // 4KB limit
134+
135+
// Check size limit for x-publisher extension
136+
if req.XPublisher != nil {
137+
extensionsJSON, err := json.Marshal(req.XPublisher)
138+
if err != nil {
139+
return fmt.Errorf("failed to marshal x-publisher extension: %w", err)
140+
}
141+
if len(extensionsJSON) > maxExtensionSize {
142+
return fmt.Errorf("x-publisher extension exceeds 4KB limit (%d bytes)", len(extensionsJSON))
143+
}
144+
}
145+
146+
return nil
147+
}
148+
149+
// ValidatePublishRequestExtensions validates that only allowed extension fields are present
150+
func ValidatePublishRequestExtensions(requestData []byte) error {
151+
// Parse the raw JSON to check for unknown fields
152+
var rawRequest map[string]interface{}
153+
if err := json.Unmarshal(requestData, &rawRequest); err != nil {
154+
return fmt.Errorf("failed to parse request JSON: %w", err)
155+
}
156+
157+
// Define allowed top-level fields
158+
allowedFields := map[string]bool{
159+
"server": true,
160+
"x-publisher": true,
161+
}
162+
163+
// Check for any disallowed fields
164+
var invalidFields []string
165+
for field := range rawRequest {
166+
if !allowedFields[field] {
167+
invalidFields = append(invalidFields, field)
168+
}
169+
}
170+
171+
if len(invalidFields) > 0 {
172+
return fmt.Errorf("invalid extension fields: %v. Only 'server' and 'x-publisher' fields are allowed", invalidFields)
173+
}
174+
175+
return nil
176+
}
177+
178+
// ExtractPublisherExtensions extracts publisher extensions from a apiv1.PublishRequest
179+
func ExtractPublisherExtensions(req apiv1.PublishRequest) map[string]interface{} {
180+
publisherExtensions := make(map[string]interface{})
181+
if req.XPublisher != nil {
182+
// Copy fields directly, avoiding double nesting
183+
for k, v := range req.XPublisher {
184+
publisherExtensions[k] = v
185+
}
186+
}
187+
return publisherExtensions
188+
}
189+
190+
// ParseServerName extracts the server name from a model.ServerJSON for validation purposes
191+
func ParseServerName(serverDetail model.ServerJSON) (string, error) {
192+
name := serverDetail.Name
193+
if name == "" {
194+
return "", fmt.Errorf("server name is required and must be a string")
195+
}
196+
197+
// Validate format: dns-namespace/name
198+
if !strings.Contains(name, "/") {
199+
return "", fmt.Errorf("server name must be in format 'dns-namespace/name' (e.g., 'com.example.api/server')")
200+
}
201+
202+
parts := strings.SplitN(name, "/", 2)
203+
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
204+
return "", fmt.Errorf("server name must be in format 'dns-namespace/name' with non-empty namespace and name parts")
205+
}
206+
207+
return name, nil
208+
}
209+
210+
// ValidateRemoteNamespaceMatch validates that remote URLs match the reverse-DNS namespace
211+
func ValidateRemoteNamespaceMatch(serverDetail model.ServerJSON) error {
212+
namespace := serverDetail.Name
213+
214+
for _, remote := range serverDetail.Remotes {
215+
if err := validateRemoteURLMatchesNamespace(remote.URL, namespace); err != nil {
216+
return fmt.Errorf("remote URL %s does not match namespace %s: %w", remote.URL, namespace, err)
217+
}
218+
}
219+
220+
return nil
221+
}
222+
223+
// validateRemoteURLMatchesNamespace checks if a remote URL's hostname matches the publisher domain from the namespace
224+
func validateRemoteURLMatchesNamespace(remoteURL, namespace string) error {
225+
// Parse the URL to extract the hostname
226+
parsedURL, err := url.Parse(remoteURL)
227+
if err != nil {
228+
return fmt.Errorf("invalid URL format: %w", err)
229+
}
230+
231+
hostname := parsedURL.Hostname()
232+
if hostname == "" {
233+
return fmt.Errorf("URL must have a valid hostname")
234+
}
235+
236+
// Skip validation for localhost and local development URLs
237+
if hostname == "localhost" || strings.HasSuffix(hostname, ".localhost") || hostname == "127.0.0.1" {
238+
return nil
239+
}
240+
241+
// Extract publisher domain from reverse-DNS namespace
242+
publisherDomain := extractPublisherDomainFromNamespace(namespace)
243+
if publisherDomain == "" {
244+
return fmt.Errorf("invalid namespace format: cannot extract domain from %s", namespace)
245+
}
246+
247+
// Check if the remote URL hostname matches the publisher domain or is a subdomain
248+
if !isValidHostForDomain(hostname, publisherDomain) {
249+
return fmt.Errorf("remote URL host %s does not match publisher domain %s", hostname, publisherDomain)
250+
}
251+
252+
return nil
253+
}
254+
255+
// extractPublisherDomainFromNamespace converts reverse-DNS namespace to normal domain format
256+
// e.g., "com.example" -> "example.com"
257+
func extractPublisherDomainFromNamespace(namespace string) string {
258+
// Extract the namespace part before the first slash
259+
namespacePart := namespace
260+
if slashIdx := strings.Index(namespace, "/"); slashIdx != -1 {
261+
namespacePart = namespace[:slashIdx]
262+
}
263+
264+
// Split into parts and reverse them to get normal domain format
265+
parts := strings.Split(namespacePart, ".")
266+
if len(parts) < 2 {
267+
return ""
268+
}
269+
270+
// Reverse the parts to convert from reverse-DNS to normal domain
271+
slices.Reverse(parts)
272+
273+
return strings.Join(parts, ".")
274+
}
275+
276+
// isValidHostForDomain checks if a hostname is the domain or a subdomain of the publisher domain
277+
func isValidHostForDomain(hostname, publisherDomain string) bool {
278+
// Exact match
279+
if hostname == publisherDomain {
280+
return true
281+
}
282+
283+
// Subdomain match - hostname should end with "." + publisherDomain
284+
if strings.HasSuffix(hostname, "."+publisherDomain) {
285+
return true
286+
}
287+
288+
return false
289+
}

0 commit comments

Comments
 (0)