Skip to content

Commit b90d337

Browse files
fix(event_handler): respect middleware response
1 parent 5999346 commit b90d337

File tree

5 files changed

+690
-14
lines changed

5 files changed

+690
-14
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def __call__(
469469

470470
# Save CPU cycles by building middleware stack once
471471
if not self._middleware_stack_built:
472-
self._build_middleware_stack(router_middlewares=router_middlewares)
472+
self._build_middleware_stack(router_middlewares=router_middlewares, app=app)
473473

474474
# If debug is turned on then output the middleware stack to the console
475475
if app._debug:
@@ -487,7 +487,7 @@ def __call__(
487487
# Call the Middleware Wrapped _call_stack function handler with the app
488488
return self._middleware_stack(app)
489489

490-
def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]]) -> None:
490+
def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]], app) -> None:
491491
"""
492492
Builds the middleware stack for the handler by wrapping each
493493
handler in an instance of MiddlewareWrapper which is used to contain the state
@@ -505,7 +505,25 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]])
505505
The Route Middleware stack is processed in reverse order. This is so the stack of
506506
middleware handlers is applied in the order of being added to the handler.
507507
"""
508-
all_middlewares = router_middlewares + self.middlewares
508+
# Build middleware stack in the correct order for validation:
509+
# 1. Request validation middleware (first)
510+
# 2. Router middlewares + user middlewares (middle)
511+
# 3. Response validation middleware (before route handler)
512+
# 4. Route handler adapter (last)
513+
514+
all_middlewares = []
515+
516+
# Add request validation middleware first if validation is enabled
517+
if hasattr(app, "_request_validation_middleware"):
518+
all_middlewares.append(app._request_validation_middleware)
519+
520+
# Add user middlewares in the middle
521+
all_middlewares.extend(router_middlewares + self.middlewares)
522+
523+
# Add response validation middleware before the route handler if validation is enabled
524+
if hasattr(app, "_response_validation_middleware"):
525+
all_middlewares.append(app._response_validation_middleware)
526+
509527
logger.debug(f"Building middleware stack: {all_middlewares}")
510528

511529
# IMPORTANT:
@@ -1639,17 +1657,16 @@ def __init__(
16391657
self._json_body_deserializer = json_body_deserializer
16401658

16411659
if self._enable_validation:
1642-
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware
1643-
1644-
# Note the serializer argument: only use custom serializer if provided by the caller
1645-
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1646-
self.use(
1647-
[
1648-
OpenAPIValidationMiddleware(
1649-
validation_serializer=serializer,
1650-
has_response_validation_error=self._has_response_validation_error,
1651-
),
1652-
],
1660+
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import (
1661+
OpenAPIRequestValidationMiddleware,
1662+
OpenAPIResponseValidationMiddleware,
1663+
)
1664+
1665+
# Store validation middlewares to be added in the correct order later
1666+
self._request_validation_middleware = OpenAPIRequestValidationMiddleware()
1667+
self._response_validation_middleware = OpenAPIResponseValidationMiddleware(
1668+
validation_serializer=serializer,
1669+
has_response_validation_error=self._has_response_validation_error,
16531670
)
16541671

16551672
def _validate_response_validation_error_http_code(

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,280 @@
3737
APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"
3838

3939

40+
class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler):
41+
"""
42+
OpenAPI request validation middleware - validates only incoming requests.
43+
44+
This middleware should be used first in the middleware chain to validate
45+
requests before they reach user middlewares.
46+
"""
47+
48+
def __init__(self):
49+
"""Initialize the request validation middleware."""
50+
pass
51+
52+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
53+
logger.debug("OpenAPIRequestValidationMiddleware handler")
54+
55+
route: Route = app.context["_route"]
56+
57+
values: dict[str, Any] = {}
58+
errors: list[Any] = []
59+
60+
# Process path values, which can be found on the route_args
61+
path_values, path_errors = _request_params_to_args(
62+
route.dependant.path_params,
63+
app.context["_route_args"],
64+
)
65+
66+
# Normalize query values before validate this
67+
query_string = _normalize_multi_query_string_with_param(
68+
app.current_event.resolved_query_string_parameters,
69+
route.dependant.query_params,
70+
)
71+
72+
# Process query values
73+
query_values, query_errors = _request_params_to_args(
74+
route.dependant.query_params,
75+
query_string,
76+
)
77+
78+
# Normalize header values before validate this
79+
headers = _normalize_multi_header_values_with_param(
80+
app.current_event.resolved_headers_field,
81+
route.dependant.header_params,
82+
)
83+
84+
# Process header values
85+
header_values, header_errors = _request_params_to_args(
86+
route.dependant.header_params,
87+
headers,
88+
)
89+
90+
values.update(path_values)
91+
values.update(query_values)
92+
values.update(header_values)
93+
errors += path_errors + query_errors + header_errors
94+
95+
# Process the request body, if it exists
96+
if route.dependant.body_params:
97+
(body_values, body_errors) = _request_body_to_args(
98+
required_params=route.dependant.body_params,
99+
received_body=self._get_body(app),
100+
)
101+
values.update(body_values)
102+
errors.extend(body_errors)
103+
104+
if errors:
105+
# Raise the validation errors
106+
raise RequestValidationError(_normalize_errors(errors))
107+
108+
# Re-write the route_args with the validated values
109+
app.context["_route_args"] = values
110+
111+
# Call the next middleware
112+
return next_middleware(app)
113+
114+
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
115+
"""
116+
Get the request body from the event, and parse it according to content type.
117+
"""
118+
content_type = app.current_event.headers.get("content-type", "").strip()
119+
120+
# Handle JSON content
121+
if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE):
122+
return self._parse_json_data(app)
123+
124+
# Handle URL-encoded form data
125+
elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
126+
return self._parse_form_data(app)
127+
128+
else:
129+
raise NotImplementedError("Only JSON body or Form() are supported")
130+
131+
def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
132+
"""Parse JSON data from the request body."""
133+
try:
134+
return app.current_event.json_body
135+
except json.JSONDecodeError as e:
136+
raise RequestValidationError(
137+
[
138+
{
139+
"type": "json_invalid",
140+
"loc": ("body", e.pos),
141+
"msg": "JSON decode error",
142+
"input": {},
143+
"ctx": {"error": e.msg},
144+
},
145+
],
146+
body=e.doc,
147+
) from e
148+
149+
def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
150+
"""Parse URL-encoded form data from the request body."""
151+
try:
152+
body = app.current_event.decoded_body or ""
153+
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
154+
parsed = parse_qs(body, keep_blank_values=True)
155+
156+
result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
157+
return result
158+
159+
except Exception as e: # pragma: no cover
160+
raise RequestValidationError( # pragma: no cover
161+
[
162+
{
163+
"type": "form_invalid",
164+
"loc": ("body",),
165+
"msg": "Form data parsing error",
166+
"input": {},
167+
"ctx": {"error": str(e)},
168+
},
169+
],
170+
) from e
171+
172+
173+
class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler):
174+
"""
175+
OpenAPI response validation middleware - validates only outgoing responses.
176+
177+
This middleware should be used last in the middleware chain to validate
178+
responses only from route handlers, not from user middlewares.
179+
"""
180+
181+
def __init__(
182+
self,
183+
validation_serializer: Callable[[Any], str] | None = None,
184+
has_response_validation_error: bool = False,
185+
):
186+
"""
187+
Initialize the response validation middleware.
188+
189+
Parameters
190+
----------
191+
validation_serializer : Callable, optional
192+
Optional serializer to use when serializing the response for validation.
193+
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
194+
195+
has_response_validation_error: bool, optional
196+
Optional flag used to distinguish between payload and validation errors.
197+
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
198+
"""
199+
self._validation_serializer = validation_serializer
200+
self._has_response_validation_error = has_response_validation_error
201+
202+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
203+
logger.debug("OpenAPIResponseValidationMiddleware handler")
204+
205+
route: Route = app.context["_route"]
206+
207+
# Call the next middleware (should be the route handler)
208+
response = next_middleware(app)
209+
210+
# Process the response
211+
return self._handle_response(route=route, response=response)
212+
213+
def _handle_response(self, *, route: Route, response: Response):
214+
# Process the response body if it exists
215+
if response.body and response.is_json():
216+
response.body = self._serialize_response(
217+
field=route.dependant.return_param,
218+
response_content=response.body,
219+
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
220+
)
221+
222+
return response
223+
224+
def _serialize_response(
225+
self,
226+
*,
227+
field: ModelField | None = None,
228+
response_content: Any,
229+
include: IncEx | None = None,
230+
exclude: IncEx | None = None,
231+
by_alias: bool = True,
232+
exclude_unset: bool = False,
233+
exclude_defaults: bool = False,
234+
exclude_none: bool = False,
235+
has_route_custom_response_validation: bool = False,
236+
) -> Any:
237+
"""
238+
Serialize the response content according to the field type.
239+
"""
240+
if field:
241+
errors: list[dict[str, Any]] = []
242+
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
243+
if errors:
244+
# route-level validation must take precedence over app-level
245+
if has_route_custom_response_validation:
246+
raise ResponseValidationError(
247+
errors=_normalize_errors(errors),
248+
body=response_content,
249+
source="route",
250+
)
251+
if self._has_response_validation_error:
252+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
253+
254+
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
255+
256+
if hasattr(field, "serialize"):
257+
return field.serialize(
258+
value,
259+
include=include,
260+
exclude=exclude,
261+
by_alias=by_alias,
262+
exclude_unset=exclude_unset,
263+
exclude_defaults=exclude_defaults,
264+
exclude_none=exclude_none,
265+
)
266+
return jsonable_encoder(
267+
value,
268+
include=include,
269+
exclude=exclude,
270+
by_alias=by_alias,
271+
exclude_unset=exclude_unset,
272+
exclude_defaults=exclude_defaults,
273+
exclude_none=exclude_none,
274+
custom_serializer=self._validation_serializer,
275+
)
276+
else:
277+
# Just serialize the response content returned from the handler.
278+
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)
279+
280+
def _prepare_response_content(
281+
self,
282+
res: Any,
283+
*,
284+
exclude_unset: bool,
285+
exclude_defaults: bool = False,
286+
exclude_none: bool = False,
287+
) -> Any:
288+
"""
289+
Prepares the response content for serialization.
290+
"""
291+
if isinstance(res, BaseModel):
292+
return _model_dump(
293+
res,
294+
by_alias=True,
295+
exclude_unset=exclude_unset,
296+
exclude_defaults=exclude_defaults,
297+
exclude_none=exclude_none,
298+
)
299+
elif isinstance(res, list):
300+
return [
301+
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
302+
for item in res
303+
]
304+
elif isinstance(res, dict):
305+
return {
306+
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
307+
for k, v in res.items()
308+
}
309+
elif dataclasses.is_dataclass(res):
310+
return dataclasses.asdict(res) # type: ignore[arg-type]
311+
return res
312+
313+
40314
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
41315
"""
42316
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the

docs/core/event_handler/api_gateway.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,20 @@ As a practical example, let's refactor our correlation ID middleware so it accep
928928
!!! note "Class-based **vs** function-based middlewares"
929929
When registering a middleware, we expect a callable in both cases. For class-based middlewares, `BaseMiddlewareHandler` is doing the work of calling your `handler` method with the correct parameters, hence why we expect an instance of it.
930930

931+
#### Middleware and data validation
932+
933+
When you enable data validation with `enable_validation=True`, we split validation into two separate middlewares:
934+
935+
1. **Request validation** runs first to validate incoming data
936+
2. **Your middlewares** run in the middle and can return early responses
937+
3. **Response validation** runs last, only for route handler responses
938+
939+
This ensures your middlewares can return early responses (401, 403, 429, etc.) without triggering validation errors, while still validating actual route handler responses for data integrity.
940+
941+
```python hl_lines="5 11 23 36" title="Middleware early returns work seamlessly with validation"
942+
--8<-- "examples/event_handler_rest/src/middleware_and_data_validation.py"
943+
```
944+
931945
#### Native middlewares
932946

933947
These are native middlewares that may become native features depending on customer demand.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
4+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
5+
6+
app = APIGatewayRestResolver(enable_validation=True)
7+
8+
9+
def auth_middleware(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
10+
# This 401 response won't trigger validation errors
11+
return Response(status_code=401, content_type="application/json", body="{}")
12+
13+
14+
app.use(middlewares=[auth_middleware])
15+
16+
17+
@app.get("/protected")
18+
def protected_route() -> list[str]:
19+
# Only this response will be validated against OpenAPI schema
20+
return ["protected", "route"]

0 commit comments

Comments
 (0)