@@ -3,11 +3,84 @@ package api_test
33import (
44 "net/http"
55 "net/http/httptest"
6+ "strings"
67 "testing"
78
89 "github.com/modelcontextprotocol/registry/internal/api"
910)
1011
12+ func TestNulByteValidationMiddleware (t * testing.T ) {
13+ // Create a simple handler that returns "OK"
14+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
15+ w .WriteHeader (http .StatusOK )
16+ _ , _ = w .Write ([]byte ("OK" ))
17+ })
18+
19+ // Wrap with our middleware
20+ middleware := api .NulByteValidationMiddleware (handler )
21+
22+ t .Run ("normal path should pass through" , func (t * testing.T ) {
23+ req := httptest .NewRequest (http .MethodGet , "/v0/servers" , nil )
24+ w := httptest .NewRecorder ()
25+ middleware .ServeHTTP (w , req )
26+
27+ if w .Code != http .StatusOK {
28+ t .Errorf ("expected status %d, got %d" , http .StatusOK , w .Code )
29+ }
30+ })
31+
32+ t .Run ("path with query params should pass through" , func (t * testing.T ) {
33+ req := httptest .NewRequest (http .MethodGet , "/v0/servers?cursor=abc123" , nil )
34+ w := httptest .NewRecorder ()
35+ middleware .ServeHTTP (w , req )
36+
37+ if w .Code != http .StatusOK {
38+ t .Errorf ("expected status %d, got %d" , http .StatusOK , w .Code )
39+ }
40+ })
41+
42+ t .Run ("path with NUL byte should return 400" , func (t * testing.T ) {
43+ // Create request with NUL byte in path by manually setting URL
44+ req := httptest .NewRequest (http .MethodGet , "/v0/servers/test" , nil )
45+ req .URL .Path = "/v0/servers/\x00 "
46+ w := httptest .NewRecorder ()
47+ middleware .ServeHTTP (w , req )
48+
49+ if w .Code != http .StatusBadRequest {
50+ t .Errorf ("expected status %d, got %d" , http .StatusBadRequest , w .Code )
51+ }
52+ if ! strings .Contains (w .Body .String (), "URL path contains null bytes" ) {
53+ t .Errorf ("expected body to contain error message, got %q" , w .Body .String ())
54+ }
55+ })
56+
57+ t .Run ("query with NUL byte should return 400" , func (t * testing.T ) {
58+ // Create request with NUL byte in query by manually setting RawQuery
59+ req := httptest .NewRequest (http .MethodGet , "/v0/servers" , nil )
60+ req .URL .RawQuery = "cursor=\x00 "
61+ w := httptest .NewRecorder ()
62+ middleware .ServeHTTP (w , req )
63+
64+ if w .Code != http .StatusBadRequest {
65+ t .Errorf ("expected status %d, got %d" , http .StatusBadRequest , w .Code )
66+ }
67+ if ! strings .Contains (w .Body .String (), "query parameters contain null bytes" ) {
68+ t .Errorf ("expected body to contain error message, got %q" , w .Body .String ())
69+ }
70+ })
71+
72+ t .Run ("path with embedded NUL byte should return 400" , func (t * testing.T ) {
73+ req := httptest .NewRequest (http .MethodGet , "/v0/servers/test" , nil )
74+ req .URL .Path = "/v0/servers/test\x00 name"
75+ w := httptest .NewRecorder ()
76+ middleware .ServeHTTP (w , req )
77+
78+ if w .Code != http .StatusBadRequest {
79+ t .Errorf ("expected status %d, got %d" , http .StatusBadRequest , w .Code )
80+ }
81+ })
82+ }
83+
1184func TestTrailingSlashMiddleware (t * testing.T ) {
1285 // Create a simple handler that returns "OK"
1386 handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
0 commit comments