Skip to content

Commit 34beb4d

Browse files
committed
support validate_by_name for pydantic BaseModels
1 parent 3f8c1ba commit 34beb4d

File tree

4 files changed

+39
-74
lines changed

4 files changed

+39
-74
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -889,13 +889,18 @@ def _get_basic_type_schema(param_type: type) -> dict[str, str]:
889889
"""
890890
Get basic OpenAPI schema for simple types
891891
"""
892-
if isinstance(int, param_type):
893-
return {"type": "integer"}
894-
elif isinstance(float, param_type):
895-
return {"type": "number"}
896-
elif isinstance(bool, param_type):
897-
return {"type": "boolean"}
898-
else:
892+
try:
893+
# Check bool before int, since bool is a subclass of int in Python
894+
if issubclass(param_type, bool):
895+
return {"type": "boolean"}
896+
elif issubclass(param_type, int):
897+
return {"type": "integer"}
898+
elif issubclass(param_type, float):
899+
return {"type": "number"}
900+
else:
901+
return {"type": "string"}
902+
except TypeError:
903+
# param_type may not be a type (e.g., typing.Optional[int]), fallback to string
899904
return {"type": "string"}
900905

901906
@staticmethod

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
88
from urllib.parse import parse_qs
99

10-
from pydantic import BaseModel, ValidationError
10+
from pydantic import BaseModel
1111

1212
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler
1313
from aws_lambda_powertools.event_handler.openapi.compat import (
@@ -69,8 +69,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6969
route.dependant.query_params,
7070
)
7171

72-
# Process query values (with Pydantic model support)
73-
query_values, query_errors = _request_params_to_args_with_pydantic_support(
72+
# Process query values
73+
query_values, query_errors = _request_params_to_args(
7474
route.dependant.query_params,
7575
query_string,
7676
)
@@ -81,8 +81,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
8181
route.dependant.header_params,
8282
)
8383

84-
# Process header values (with Pydantic model support)
85-
header_values, header_errors = _request_params_to_args_with_pydantic_support(
84+
# Process header values
85+
header_values, header_errors = _request_params_to_args(
8686
route.dependant.header_params,
8787
headers,
8888
)
@@ -311,7 +311,7 @@ def _prepare_response_content(
311311
return res # pragma: no cover
312312

313313

314-
def _request_params_to_args_with_pydantic_support(
314+
def _request_params_to_args(
315315
required_params: Sequence[ModelField],
316316
received_params: Mapping[str, Any],
317317
) -> tuple[dict[str, Any], list[Any]]:
@@ -330,71 +330,24 @@ def _request_params_to_args_with_pydantic_support(
330330
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
331331

332332
if isinstance(field_info, (Query, Header)) and lenient_issubclass(field_info.annotation, BaseModel):
333-
# Handle Pydantic model - use the same approach as _request_body_to_args
334-
loc = (field_info.in_.value, field.alias)
335-
336-
# Get the raw data for the Pydantic model
337-
value = received_params.get(field.alias)
338-
339-
if value is None:
340-
if field.required:
341-
errors.append(get_missing_field_error(loc))
342-
else:
343-
values[field.name] = deepcopy(field.default)
344-
continue
345-
333+
pass
334+
elif isinstance(field_info, Param):
335+
pass
346336
else:
347-
# Regular parameter processing (existing logic)
348-
if not isinstance(field_info, Param):
349-
raise AssertionError(f"Expected Param field_info, got {field_info}")
350-
351-
value = received_params.get(field.alias)
352-
loc = (field_info.in_.value, field.alias)
353-
354-
if value is None:
355-
if field.required:
356-
errors.append(get_missing_field_error(loc=loc))
357-
else:
358-
values[field.name] = deepcopy(field.default)
359-
continue
360-
361-
# Use _validate_field like _request_body_to_args does
362-
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
363-
return values, errors
364-
365-
366-
def _request_params_to_args(
367-
required_params: Sequence[ModelField],
368-
received_params: Mapping[str, Any],
369-
) -> tuple[dict[str, Any], list[Any]]:
370-
"""
371-
Convert the request params to a dictionary of values using validation, and returns a list of errors.
372-
"""
373-
values = {}
374-
errors = []
375-
376-
for field in required_params:
377-
field_info = field.field_info
378-
379-
# To ensure early failure, we check if it's not an instance of Param.
380-
if not isinstance(field_info, Param):
381337
raise AssertionError(f"Expected Param field_info, got {field_info}")
382338

383339
value = received_params.get(field.alias)
384-
385340
loc = (field_info.in_.value, field.alias)
386341

387-
# If we don't have a value, see if it's required or has a default
388342
if value is None:
389343
if field.required:
390344
errors.append(get_missing_field_error(loc=loc))
391345
else:
392346
values[field.name] = deepcopy(field.default)
393347
continue
394348

395-
# Finally, validate the value
349+
# Use _validate_field like _request_body_to_args does
396350
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
397-
398351
return values, errors
399352

400353

@@ -529,7 +482,13 @@ def _normalize_multi_query_string_with_param(
529482
try:
530483
model_data[field_alias] = query_string[field_alias][0]
531484
except KeyError:
532-
pass
485+
if model_class.model_config.get("validate_by_name") or model_class.model_config.get(
486+
"populate_by_name",
487+
):
488+
try:
489+
model_data[field_alias] = query_string[field_name][0]
490+
except KeyError:
491+
pass
533492

534493
# Store the collected data under the param alias
535494
resolved_query_string[param.alias] = model_data

tests/functional/event_handler/_pydantic/test_openapi_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_openapi_pydantic_header_params():
8080
class HeaderParams(BaseModel):
8181
authorization: str = Field(description="Authorization token")
8282
user_agent: str = Field(default="PowerTools/1.0", description="User agent")
83-
accept_language: Optional[str] = Field(default=None, alias="accept-language", description="Language preference")
83+
language: Optional[str] = Field(default=None, alias="accept-language", description="Language preference")
8484

8585
@app.get("/protected")
8686
def protected_handler(headers: Annotated[HeaderParams, Header()]):
@@ -101,7 +101,7 @@ def protected_handler(headers: Annotated[HeaderParams, Header()]):
101101
# Check individual parameters
102102
param_names = [param.name for param in get_operation.parameters]
103103
assert "authorization" in param_names
104-
assert "user_agent" in param_names
104+
assert "user-agent" in param_names # headers are always spinal-case
105105
assert "accept-language" in param_names # Should use alias
106106

107107
# Check parameter details

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from enum import Enum
55
from pathlib import PurePath
6-
from typing import List, Optional, Tuple
6+
from typing import Dict, List, Optional, Tuple
77

88
import pytest
99
from pydantic import BaseModel, Field
@@ -2265,7 +2265,7 @@ def test_validate_pydantic_query_params_with_config_dict_and_validators(gw_event
22652265
"""Test that Pydantic models with ConfigDict, aliases, and validators work correctly"""
22662266
from typing import Any
22672267

2268-
from pydantic import UUID4, AfterValidator, Base64UrlStr, ConfigDict, StringConstraints, alias_generators
2268+
from pydantic import AfterValidator, Base64UrlStr, ConfigDict, StringConstraints, alias_generators
22692269

22702270
del gw_event["multiValueHeaders"]
22712271
del gw_event["multiValueQueryStringParameters"]
@@ -2284,7 +2284,7 @@ class QuerySimple(BaseModel):
22842284
search_id: str
22852285

22862286
@app.get("/query-model-simple")
2287-
def query_model(params: Annotated[QuerySimple, Query()]) -> dict[str, Any]:
2287+
def query_model(params: Annotated[QuerySimple, Query()]) -> Dict[str, Any]:
22882288
return {
22892289
"fullName": params.full_name,
22902290
"nextToken": params.next_token,
@@ -2299,11 +2299,12 @@ class QueryAdvanced(BaseModel):
22992299
model_config = ConfigDict(
23002300
alias_generator=alias_generators.to_camel,
23012301
validate_by_alias=True,
2302+
validate_by_name=True,
23022303
serialize_by_alias=True,
23032304
)
23042305

23052306
@app.get("/query-model-advanced")
2306-
def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> dict[str, Any]:
2307+
def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> Dict[str, Any]:
23072308
return params.model_dump()
23082309

23092310
# Test QuerySimple with validators
@@ -2353,11 +2354,11 @@ def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> dict[str,
23532354
assert body["nextToken"] == "dGVzdA=="
23542355
assert body["id"] == "search-456"
23552356

2356-
# Test QueryAdvanced with snake_case field names (should also work due to populate_by_name behavior)
2357+
# Test QueryAdvanced with snake_case field names due to validate_by_name=True
23572358
gw_event["queryStringParameters"] = {
23582359
"full_name": "Snake Case Test", # snake_case field name
23592360
"next_token": "dGVzdA==", # snake_case field name
2360-
"id": "search-789", # explicit alias
2361+
"search_id": "search-789", # snake_case field name
23612362
}
23622363

23632364
gw_event["path"] = "/query-model-advanced"
@@ -2366,7 +2367,7 @@ def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> dict[str,
23662367

23672368
body = json.loads(result["body"])
23682369
assert body["fullName"] == "Snake Case Test"
2369-
assert body["nextToken"] == "token789"
2370+
assert body["nextToken"] == "dGVzdA=="
23702371
assert body["id"] == "search-789"
23712372

23722373
# Test QueryAdvanced validation error (full_name too short)

0 commit comments

Comments
 (0)