Skip to content

Commit 56920f8

Browse files
author
Avish Porwal
committed
Fixing unhandled NUL Bytes in API Requests
1 parent 9afbaac commit 56920f8

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

internal/api/handlers/v0/servers.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ type ServerVersionsInput struct {
4242
}
4343

4444
// RegisterServersEndpoints registers all server-related endpoints with a custom path prefix
45+
//
46+
//nolint:cyclop
4547
func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service.RegistryService) {
4648
// List servers endpoint
4749
huma.Register(api, huma.Operation{
@@ -52,6 +54,10 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service.
5254
Description: "Get a paginated list of MCP servers from the registry",
5355
Tags: []string{"servers"},
5456
}, func(ctx context.Context, input *ListServersInput) (*Response[apiv0.ServerListResponse], error) {
57+
if containsNULByte(input.Cursor) {
58+
return nil, huma.Error400BadRequest("Invalid cursor: NUL byte not allowed")
59+
}
60+
5561
// Build filter from input parameters
5662
filter := &database.ServerFilter{}
5763

@@ -119,12 +125,18 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service.
119125
if err != nil {
120126
return nil, huma.Error400BadRequest("Invalid server name encoding", err)
121127
}
128+
if containsNULByte(serverName) {
129+
return nil, huma.Error400BadRequest("Invalid server name: NUL byte not allowed")
130+
}
122131

123132
// URL-decode the version
124133
version, err := url.PathUnescape(input.Version)
125134
if err != nil {
126135
return nil, huma.Error400BadRequest("Invalid version encoding", err)
127136
}
137+
if containsNULByte(version) {
138+
return nil, huma.Error400BadRequest("Invalid version: NUL byte not allowed")
139+
}
128140

129141
var serverResponse *apiv0.ServerResponse
130142
// Handle "latest" as a special version
@@ -160,6 +172,9 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service.
160172
if err != nil {
161173
return nil, huma.Error400BadRequest("Invalid server name encoding", err)
162174
}
175+
if containsNULByte(serverName) {
176+
return nil, huma.Error400BadRequest("Invalid server name: NUL byte not allowed")
177+
}
163178

164179
// Get all versions for this server
165180
servers, err := registry.GetAllVersionsByServerName(ctx, serverName)
@@ -186,3 +201,7 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service.
186201
}, nil
187202
})
188203
}
204+
205+
func containsNULByte(s string) bool {
206+
return strings.IndexByte(s, 0) >= 0
207+
}

internal/api/handlers/v0/servers_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,8 @@ func TestServersEndpointEdgeCases(t *testing.T) {
461461
{"limit too high", "?limit=1000", http.StatusUnprocessableEntity, "validation failed"},
462462
{"negative limit", "?limit=-1", http.StatusUnprocessableEntity, "validation failed"},
463463
{"invalid updated_since format", "?updated_since=invalid", http.StatusBadRequest, "Invalid updated_since format"},
464+
{"cursor contains NUL", "?cursor=%00", http.StatusBadRequest, "Invalid cursor"},
465+
{"cursor contains non NUL", "?cursor=server", http.StatusOK, ""},
464466
{"future updated_since", "?updated_since=2030-01-01T00:00:00Z", http.StatusOK, ""},
465467
{"very old updated_since", "?updated_since=1990-01-01T00:00:00Z", http.StatusOK, ""},
466468
{"empty search parameter", "?search=", http.StatusOK, ""},
@@ -489,6 +491,16 @@ func TestServersEndpointEdgeCases(t *testing.T) {
489491
}
490492
})
491493

494+
t.Run("path parameter NUL byte rejected", func(t *testing.T) {
495+
req := httptest.NewRequest(http.MethodGet, "/v0/servers/%00/versions", nil)
496+
w := httptest.NewRecorder()
497+
498+
mux.ServeHTTP(w, req)
499+
500+
assert.Equal(t, http.StatusBadRequest, w.Code)
501+
assert.Contains(t, w.Body.String(), "Invalid server name")
502+
})
503+
492504
t.Run("response structure validation", func(t *testing.T) {
493505
req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil)
494506
w := httptest.NewRecorder()

0 commit comments

Comments
 (0)