diff --git a/api.go b/api.go index e93c88fd..99fbd02d 100644 --- a/api.go +++ b/api.go @@ -216,6 +216,10 @@ type Config struct { // for example if you need access to the path settings that may be changed // by the user after the defaults have been set. CreateHooks []func(Config) Config + + // RejectUnknownQueryParameters determines whether to reject requests + // containing query parameters not defined in the API spec. + RejectUnknownQueryParameters bool } // API represents a Huma API wrapping a specific router. @@ -259,6 +263,9 @@ type API interface { // See also `huma.Operation{}.Middlewares` for adding operation-specific // middleware at operation registration time. Middlewares() Middlewares + + // RejectUnknownQueryParameters indicates whether unknown query parameters should be rejected. + RejectUnknownQueryParameters() bool } // Format represents a request / response format. It is used to marshal and @@ -272,12 +279,13 @@ type Format struct { } type api struct { - config Config - adapter Adapter - formats map[string]Format - formatKeys []string - transformers []Transformer - middlewares Middlewares + config Config + adapter Adapter + formats map[string]Format + formatKeys []string + transformers []Transformer + middlewares Middlewares + rejectUnknownQueryParameters bool } func (a *api) Adapter() Adapter { @@ -354,6 +362,10 @@ func (a *api) Middlewares() Middlewares { return a.middlewares } +func (a *api) RejectUnknownQueryParameters() bool { + return a.rejectUnknownQueryParameters +} + // getAPIPrefix returns the API prefix from the first server URL in the OpenAPI // spec. If no server URL is set, then an empty string is returned. func getAPIPrefix(oapi *OpenAPI) string { @@ -385,10 +397,11 @@ func NewAPI(config Config, a Adapter) API { } newAPI := &api{ - config: config, - adapter: a, - formats: map[string]Format{}, - transformers: config.Transformers, + config: config, + adapter: a, + formats: map[string]Format{}, + transformers: config.Transformers, + rejectUnknownQueryParameters: config.RejectUnknownQueryParameters, } if config.OpenAPI == nil { diff --git a/huma.go b/huma.go index 6f98afaa..5ca32244 100644 --- a/huma.go +++ b/huma.go @@ -688,6 +688,60 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) var cookies map[string]*http.Cookie v := reflect.ValueOf(&input).Elem() + + // Reject unknown query parameters if config is set. + if api.RejectUnknownQueryParameters() && !op.SkipValidateParams { + u := ctx.URL() + q := u.Query() + + // Gather known parameters + deepObject prefixes. + knownParams := make(map[string]struct{}) + var deepPrefixes []string + + inputParams.Every(v, func(_ reflect.Value, p *paramFieldInfo) { + if p == nil || p.Loc != "query" { + return + } + + if p.Style == styleDeepObject { + deepPrefixes = append(deepPrefixes, p.Name+"[") + return + } + + knownParams[p.Name] = struct{}{} + }) + + // Validate all keys in the request. + for key := range q { + if _, ok := knownParams[key]; ok { + continue + } + + // Check it against deepPrefixes. + isDeep := false + for _, prefix := range deepPrefixes { + if strings.HasPrefix(key, prefix) { + isDeep = true + break + } + } + + if isDeep { + continue + } + + pb.Reset() + pb.Push("query") + pb.Push(key) + res.Add(pb, "", "unknown query parameter") + } + + if len(res.Errors) > 0 { + writeErr(api, ctx, &contextError{Code: http.StatusUnprocessableEntity, Msg: "validation failed", Errs: res.Errors}, *res) + return + } + } + inputParams.Every(v, func(f reflect.Value, p *paramFieldInfo) { f = reflect.Indirect(f) if f.Kind() == reflect.Invalid { diff --git a/huma_test.go b/huma_test.go index 4f10903a..d148e430 100644 --- a/huma_test.go +++ b/huma_test.go @@ -2469,6 +2469,102 @@ Content-Type: text/plain URL: "/one-of", Body: `[{"foo": "first"}, {"foo": "second"}]`, }, + { + Name: "reject-unknown-query-params", + Config: func() huma.Config { + cfg := huma.DefaultConfig("Test API", "1.0.0") + cfg.RejectUnknownQueryParameters = true + return cfg + }(), + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/test", + }, func(ctx context.Context, input *struct { + Known string `query:"known"` + }) (*struct{}, error) { + return nil, nil + }) + }, + Method: http.MethodGet, + URL: "/test?known=ok&unknown=bad", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code) + var body huma.ErrorModel + require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &body)) + assert.Equal(t, "validation failed", body.Detail) + found := false + for _, e := range body.Errors { + if e.Message == "unknown query parameter" && e.Location == "query.unknown" { + found = true + break + } + } + require.True(t, found, "expected unknown query parameter error for query.unknown; got: %v", body.Errors) + }, + }, + { + Name: "reject-unknown-query-params-deepobject-allowed", + Config: func() huma.Config { + cfg := huma.DefaultConfig("Test API", "1.0.0") + cfg.RejectUnknownQueryParameters = true + return cfg + }(), + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/test", + }, func(ctx context.Context, input *struct { + Test struct { + Int int `json:"int"` + String string `json:"string"` + } `query:"test,deepObject"` + }) (*struct{}, error) { + // Should parse and succeed, no validation error. + return nil, nil + }) + }, + Method: http.MethodGet, + URL: "/test?test[int]=1&test[string]=foo", + // No Assert: default check ensures status < 300. + }, + { + Name: "reject-unknown-query-params-deepobject-unknown", + Config: func() huma.Config { + cfg := huma.DefaultConfig("Test API", "1.0.0") + cfg.RejectUnknownQueryParameters = true + return cfg + }(), + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/test", + }, func(ctx context.Context, input *struct { + Test struct { + Int int `json:"int"` + String string `json:"string"` + } `query:"test,deepObject"` + }) (*struct{}, error) { + return nil, nil + }) + }, + Method: http.MethodGet, + URL: "/test?test[int]=1&test[string]=foo&test2[foo]=a", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code) + var body huma.ErrorModel + require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &body)) + assert.Equal(t, "validation failed", body.Detail) + found := false + for _, e := range body.Errors { + if e.Message == "unknown query parameter" && e.Location == "query.test2[foo]" { + found = true + break + } + } + require.True(t, found, "expected unknown query parameter error for query.test2[foo]; got: %v", body.Errors) + }, + }, { Name: "security-override-public", Register: func(t *testing.T, api huma.API) {