Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"encoding/json"
"log"
"net/http"
"strings"
Expand All @@ -17,26 +18,54 @@ import (
"github.com/modelcontextprotocol/registry/internal/telemetry"
)

// NulByteValidationMiddleware rejects requests containing NUL bytes in URL path or query parameters
// This prevents PostgreSQL encoding errors and returns a proper 400 Bad Request
// NulByteValidationMiddleware rejects requests containing NUL bytes in URL path or query parameters.
// This prevents PostgreSQL encoding errors (SQLSTATE 22021) and returns a proper 400 Bad Request.
// Checks for both literal NUL bytes (\x00) and URL-encoded form (%00).
func NulByteValidationMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check URL path for NUL bytes
if strings.ContainsRune(r.URL.Path, '\x00') {
http.Error(w, "Invalid request: URL path contains null bytes", http.StatusBadRequest)
// Check URL path for literal NUL bytes or URL-encoded %00
// Path needs %00 check because handlers call url.PathUnescape() which would decode it
if containsNulByte(r.URL.Path) {
writeErrorResponse(w, http.StatusBadRequest, "Invalid request: URL path contains null bytes")
return
}

// Check raw query string for NUL bytes
if strings.ContainsRune(r.URL.RawQuery, '\x00') {
http.Error(w, "Invalid request: query parameters contain null bytes", http.StatusBadRequest)
// Check raw query string for literal NUL bytes or URL-encoded %00
if containsNulByte(r.URL.RawQuery) {
writeErrorResponse(w, http.StatusBadRequest, "Invalid request: query parameters contain null bytes")
return
}

next.ServeHTTP(w, r)
})
}

// writeErrorResponse writes a JSON error response using huma's ErrorModel format
// for consistency with the rest of the API.
func writeErrorResponse(w http.ResponseWriter, status int, detail string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)

errModel := &huma.ErrorModel{
Title: http.StatusText(status),
Status: status,
Detail: detail,
}
_ = json.NewEncoder(w).Encode(errModel)
}

// containsNulByte checks if a string contains a NUL byte, either as a literal \x00
// or URL-encoded as %00.
func containsNulByte(s string) bool {
// Check for literal NUL byte
if strings.ContainsRune(s, '\x00') {
return true
}
// Check for URL-encoded NUL byte (%00)
// Using Contains directly since %00 has no case variation (both hex digits are 0)
return strings.Contains(s, "%00")
}

// TrailingSlashMiddleware redirects requests with trailing slashes to their canonical form
func TrailingSlashMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
88 changes: 88 additions & 0 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func TestNulByteValidationMiddleware(t *testing.T) {
if !strings.Contains(w.Body.String(), "URL path contains null bytes") {
t.Errorf("expected body to contain error message, got %q", w.Body.String())
}
// Verify JSON response format
if w.Header().Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %q", w.Header().Get("Content-Type"))
}
})

t.Run("query with NUL byte should return 400", func(t *testing.T) {
Expand Down Expand Up @@ -79,6 +83,90 @@ func TestNulByteValidationMiddleware(t *testing.T) {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})

t.Run("query with URL-encoded NUL byte (%00) should return 400", func(t *testing.T) {
// This is the exact case from issue #862: ?cursor=%00
req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=%00", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
if !strings.Contains(w.Body.String(), "query parameters contain null bytes") {
t.Errorf("expected body to contain error message, got %q", w.Body.String())
}
})

t.Run("query with URL-encoded NUL byte followed by text should return 400", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=%00test", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})

t.Run("query with embedded URL-encoded NUL byte should return 400", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=abc%00def", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})

t.Run("query with double-encoded NUL byte (%2500) should pass through", func(t *testing.T) {
// %2500 decodes to %00 (literal string), not a NUL byte
// This is intentionally allowed - double-decoding is the caller's responsibility
req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=%2500", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

// This should pass - %2500 is not a NUL byte injection attempt
// When decoded once: %2500 -> %00 (the string "%00", not a NUL byte)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d (double-encoded should pass)", http.StatusOK, w.Code)
}
})

t.Run("query with valid percent-encoding should pass through", func(t *testing.T) {
// Ensure we don't false-positive on valid encodings like %20 (space)
req := httptest.NewRequest(http.MethodGet, "/v0/servers?search=hello%20world", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
})

t.Run("path with URL-encoded NUL byte (%00) should return 400", func(t *testing.T) {
// Handlers call url.PathUnescape() which would decode %00 to \x00
req := httptest.NewRequest(http.MethodGet, "/v0/servers/%00/versions", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
if !strings.Contains(w.Body.String(), "URL path contains null bytes") {
t.Errorf("expected body to contain error message, got %q", w.Body.String())
}
})

t.Run("path with URL-encoded NUL byte among other encodings should return 400", func(t *testing.T) {
// %0a is newline, %00 is NUL - should still catch the NUL
req := httptest.NewRequest(http.MethodGet, "/v0/servers/test%0a%00name/versions", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
}

func TestTrailingSlashMiddleware(t *testing.T) {
Expand Down
Loading