Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ type Config struct {
// for example if you need access to the path settings that may be changed
// by the user after the defaults have been set.
CreateHooks []func(Config) Config

// RejectUnknownQueryParameters determines whether to reject requests
// containing query parameters not defined in the API spec.
RejectUnknownQueryParameters bool
}

// API represents a Huma API wrapping a specific router.
Expand Down Expand Up @@ -259,6 +263,9 @@ type API interface {
// See also `huma.Operation{}.Middlewares` for adding operation-specific
// middleware at operation registration time.
Middlewares() Middlewares

// RejectUnknownQueryParameters indicates whether unknown query parameters should be rejected.
RejectUnknownQueryParameters() bool
}

// Format represents a request / response format. It is used to marshal and
Expand All @@ -272,12 +279,13 @@ type Format struct {
}

type api struct {
config Config
adapter Adapter
formats map[string]Format
formatKeys []string
transformers []Transformer
middlewares Middlewares
config Config
adapter Adapter
formats map[string]Format
formatKeys []string
transformers []Transformer
middlewares Middlewares
rejectUnknownQueryParameters bool
}

func (a *api) Adapter() Adapter {
Expand Down Expand Up @@ -354,6 +362,10 @@ func (a *api) Middlewares() Middlewares {
return a.middlewares
}

func (a *api) RejectUnknownQueryParameters() bool {
return a.rejectUnknownQueryParameters
}

// getAPIPrefix returns the API prefix from the first server URL in the OpenAPI
// spec. If no server URL is set, then an empty string is returned.
func getAPIPrefix(oapi *OpenAPI) string {
Expand Down Expand Up @@ -385,10 +397,11 @@ func NewAPI(config Config, a Adapter) API {
}

newAPI := &api{
config: config,
adapter: a,
formats: map[string]Format{},
transformers: config.Transformers,
config: config,
adapter: a,
formats: map[string]Format{},
transformers: config.Transformers,
rejectUnknownQueryParameters: config.RejectUnknownQueryParameters,
}

if config.OpenAPI == nil {
Expand Down
54 changes: 54 additions & 0 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,60 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
var cookies map[string]*http.Cookie

v := reflect.ValueOf(&input).Elem()

// Reject unknown query parameters if config is set.
if api.RejectUnknownQueryParameters() && !op.SkipValidateParams {
u := ctx.URL()
q := u.Query()

// Gather known parameters + deepObject prefixes.
knownParams := make(map[string]struct{})
var deepPrefixes []string

inputParams.Every(v, func(_ reflect.Value, p *paramFieldInfo) {
if p == nil || p.Loc != "query" {
return
}

if p.Style == styleDeepObject {
deepPrefixes = append(deepPrefixes, p.Name+"[")
return
}

knownParams[p.Name] = struct{}{}
})

// Validate all keys in the request.
for key := range q {
if _, ok := knownParams[key]; ok {
continue
}

// Check it against deepPrefixes.
isDeep := false
for _, prefix := range deepPrefixes {
if strings.HasPrefix(key, prefix) {
isDeep = true
break
}
}

if isDeep {
continue
}

pb.Reset()
pb.Push("query")
pb.Push(key)
res.Add(pb, "", "unknown query parameter")
}

if len(res.Errors) > 0 {
writeErr(api, ctx, &contextError{Code: http.StatusUnprocessableEntity, Msg: "validation failed", Errs: res.Errors}, *res)
return
}
}

inputParams.Every(v, func(f reflect.Value, p *paramFieldInfo) {
f = reflect.Indirect(f)
if f.Kind() == reflect.Invalid {
Expand Down
96 changes: 96 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2469,6 +2469,102 @@ Content-Type: text/plain
URL: "/one-of",
Body: `[{"foo": "first"}, {"foo": "second"}]`,
},
{
Name: "reject-unknown-query-params",
Config: func() huma.Config {
cfg := huma.DefaultConfig("Test API", "1.0.0")
cfg.RejectUnknownQueryParameters = true
return cfg
}(),
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/test",
}, func(ctx context.Context, input *struct {
Known string `query:"known"`
}) (*struct{}, error) {
return nil, nil
})
},
Method: http.MethodGet,
URL: "/test?known=ok&unknown=bad",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code)
var body huma.ErrorModel
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &body))
assert.Equal(t, "validation failed", body.Detail)
found := false
for _, e := range body.Errors {
if e.Message == "unknown query parameter" && e.Location == "query.unknown" {
found = true
break
}
}
require.True(t, found, "expected unknown query parameter error for query.unknown; got: %v", body.Errors)
},
},
{
Name: "reject-unknown-query-params-deepobject-allowed",
Config: func() huma.Config {
cfg := huma.DefaultConfig("Test API", "1.0.0")
cfg.RejectUnknownQueryParameters = true
return cfg
}(),
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/test",
}, func(ctx context.Context, input *struct {
Test struct {
Int int `json:"int"`
String string `json:"string"`
} `query:"test,deepObject"`
}) (*struct{}, error) {
// Should parse and succeed, no validation error.
return nil, nil
})
},
Method: http.MethodGet,
URL: "/test?test[int]=1&test[string]=foo",
// No Assert: default check ensures status < 300.
},
{
Name: "reject-unknown-query-params-deepobject-unknown",
Config: func() huma.Config {
cfg := huma.DefaultConfig("Test API", "1.0.0")
cfg.RejectUnknownQueryParameters = true
return cfg
}(),
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/test",
}, func(ctx context.Context, input *struct {
Test struct {
Int int `json:"int"`
String string `json:"string"`
} `query:"test,deepObject"`
}) (*struct{}, error) {
return nil, nil
})
},
Method: http.MethodGet,
URL: "/test?test[int]=1&test[string]=foo&test2[foo]=a",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code)
var body huma.ErrorModel
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &body))
assert.Equal(t, "validation failed", body.Detail)
found := false
for _, e := range body.Errors {
if e.Message == "unknown query parameter" && e.Location == "query.test2[foo]" {
found = true
break
}
}
require.True(t, found, "expected unknown query parameter error for query.test2[foo]; got: %v", body.Errors)
},
},
{
Name: "security-override-public",
Register: func(t *testing.T, api huma.API) {
Expand Down
Loading