Skip to content

Commit 7f1abc4

Browse files
fix(event_handler): split OpenAPI validation to respect middleware returns (#7050)
* fix(event_handler): respect middleware response * removing duplicated code + coverage * Adding e2e tests
1 parent 4efab6c commit 7f1abc4

File tree

8 files changed

+582
-132
lines changed

8 files changed

+582
-132
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def __call__(
469469

470470
# Save CPU cycles by building middleware stack once
471471
if not self._middleware_stack_built:
472-
self._build_middleware_stack(router_middlewares=router_middlewares)
472+
self._build_middleware_stack(router_middlewares=router_middlewares, app=app)
473473

474474
# If debug is turned on then output the middleware stack to the console
475475
if app._debug:
@@ -487,7 +487,7 @@ def __call__(
487487
# Call the Middleware Wrapped _call_stack function handler with the app
488488
return self._middleware_stack(app)
489489

490-
def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]]) -> None:
490+
def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]], app) -> None:
491491
"""
492492
Builds the middleware stack for the handler by wrapping each
493493
handler in an instance of MiddlewareWrapper which is used to contain the state
@@ -505,7 +505,25 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]])
505505
The Route Middleware stack is processed in reverse order. This is so the stack of
506506
middleware handlers is applied in the order of being added to the handler.
507507
"""
508-
all_middlewares = router_middlewares + self.middlewares
508+
# Build middleware stack in the correct order for validation:
509+
# 1. Request validation middleware (first)
510+
# 2. Router middlewares + user middlewares (middle)
511+
# 3. Response validation middleware (before route handler)
512+
# 4. Route handler adapter (last)
513+
514+
all_middlewares = []
515+
516+
# Add request validation middleware first if validation is enabled
517+
if hasattr(app, "_request_validation_middleware"):
518+
all_middlewares.append(app._request_validation_middleware)
519+
520+
# Add user middlewares in the middle
521+
all_middlewares.extend(router_middlewares + self.middlewares)
522+
523+
# Add response validation middleware before the route handler if validation is enabled
524+
if hasattr(app, "_response_validation_middleware"):
525+
all_middlewares.append(app._response_validation_middleware)
526+
509527
logger.debug(f"Building middleware stack: {all_middlewares}")
510528

511529
# IMPORTANT:
@@ -1639,17 +1657,16 @@ def __init__(
16391657
self._json_body_deserializer = json_body_deserializer
16401658

16411659
if self._enable_validation:
1642-
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware
1643-
1644-
# Note the serializer argument: only use custom serializer if provided by the caller
1645-
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1646-
self.use(
1647-
[
1648-
OpenAPIValidationMiddleware(
1649-
validation_serializer=serializer,
1650-
has_response_validation_error=self._has_response_validation_error,
1651-
),
1652-
],
1660+
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import (
1661+
OpenAPIRequestValidationMiddleware,
1662+
OpenAPIResponseValidationMiddleware,
1663+
)
1664+
1665+
# Store validation middlewares to be added in the correct order later
1666+
self._request_validation_middleware = OpenAPIRequestValidationMiddleware()
1667+
self._response_validation_middleware = OpenAPIResponseValidationMiddleware(
1668+
validation_serializer=serializer,
1669+
has_response_validation_error=self._has_response_validation_error,
16531670
)
16541671

16551672
def _validate_response_validation_error_http_code(
@@ -2668,7 +2685,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
26682685
route=route,
26692686
)
26702687

2671-
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2688+
# OpenAPIResponseValidationMiddleware will only raise ResponseValidationError when
26722689
# 'self._response_validation_error_http_code' is not None or
26732690
# when route has custom_response_validation_http_code
26742691
if isinstance(exp, ResponseValidationError):

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 119 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -37,56 +37,20 @@
3737
APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"
3838

3939

40-
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
40+
class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler):
4141
"""
42-
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
43-
Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
44-
should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
42+
OpenAPI request validation middleware - validates only incoming requests.
4543
46-
Example
47-
--------
48-
49-
```python
50-
from pydantic import BaseModel
51-
52-
from aws_lambda_powertools.event_handler.api_gateway import (
53-
APIGatewayRestResolver,
54-
)
55-
56-
class Todo(BaseModel):
57-
name: str
58-
59-
app = APIGatewayRestResolver(enable_validation=True)
60-
61-
@app.get("/todos")
62-
def get_todos(): list[Todo]:
63-
return [Todo(name="hello world")]
64-
```
44+
This middleware should be used first in the middleware chain to validate
45+
requests before they reach user middlewares.
6546
"""
6647

67-
def __init__(
68-
self,
69-
validation_serializer: Callable[[Any], str] | None = None,
70-
has_response_validation_error: bool = False,
71-
):
72-
"""
73-
Initialize the OpenAPIValidationMiddleware.
74-
75-
Parameters
76-
----------
77-
validation_serializer : Callable, optional
78-
Optional serializer to use when serializing the response for validation.
79-
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
80-
81-
has_response_validation_error: bool, optional
82-
Optional flag used to distinguish between payload and validation errors.
83-
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
84-
"""
85-
self._validation_serializer = validation_serializer
86-
self._has_response_validation_error = has_response_validation_error
48+
def __init__(self):
49+
"""Initialize the request validation middleware."""
50+
pass
8751

8852
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
89-
logger.debug("OpenAPIValidationMiddleware handler")
53+
logger.debug("OpenAPIRequestValidationMiddleware handler")
9054

9155
route: Route = app.context["_route"]
9256

@@ -140,15 +104,111 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
140104
if errors:
141105
# Raise the validation errors
142106
raise RequestValidationError(_normalize_errors(errors))
107+
108+
# Re-write the route_args with the validated values
109+
app.context["_route_args"] = values
110+
111+
# Call the next middleware
112+
return next_middleware(app)
113+
114+
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
115+
"""
116+
Get the request body from the event, and parse it according to content type.
117+
"""
118+
content_type = app.current_event.headers.get("content-type", "").strip()
119+
120+
# Handle JSON content
121+
if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE):
122+
return self._parse_json_data(app)
123+
124+
# Handle URL-encoded form data
125+
elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
126+
return self._parse_form_data(app)
127+
143128
else:
144-
# Re-write the route_args with the validated values, and call the next middleware
145-
app.context["_route_args"] = values
129+
raise NotImplementedError("Only JSON body or Form() are supported")
146130

147-
# Call the handler by calling the next middleware
148-
response = next_middleware(app)
131+
def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
132+
"""Parse JSON data from the request body."""
133+
try:
134+
return app.current_event.json_body
135+
except json.JSONDecodeError as e:
136+
raise RequestValidationError(
137+
[
138+
{
139+
"type": "json_invalid",
140+
"loc": ("body", e.pos),
141+
"msg": "JSON decode error",
142+
"input": {},
143+
"ctx": {"error": e.msg},
144+
},
145+
],
146+
body=e.doc,
147+
) from e
149148

150-
# Process the response
151-
return self._handle_response(route=route, response=response)
149+
def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
150+
"""Parse URL-encoded form data from the request body."""
151+
try:
152+
body = app.current_event.decoded_body or ""
153+
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
154+
parsed = parse_qs(body, keep_blank_values=True)
155+
156+
result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
157+
return result
158+
159+
except Exception as e: # pragma: no cover
160+
raise RequestValidationError( # pragma: no cover
161+
[
162+
{
163+
"type": "form_invalid",
164+
"loc": ("body",),
165+
"msg": "Form data parsing error",
166+
"input": {},
167+
"ctx": {"error": str(e)},
168+
},
169+
],
170+
) from e
171+
172+
173+
class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler):
174+
"""
175+
OpenAPI response validation middleware - validates only outgoing responses.
176+
177+
This middleware should be used last in the middleware chain to validate
178+
responses only from route handlers, not from user middlewares.
179+
"""
180+
181+
def __init__(
182+
self,
183+
validation_serializer: Callable[[Any], str] | None = None,
184+
has_response_validation_error: bool = False,
185+
):
186+
"""
187+
Initialize the response validation middleware.
188+
189+
Parameters
190+
----------
191+
validation_serializer : Callable, optional
192+
Optional serializer to use when serializing the response for validation.
193+
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
194+
195+
has_response_validation_error: bool, optional
196+
Optional flag used to distinguish between payload and validation errors.
197+
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
198+
"""
199+
self._validation_serializer = validation_serializer
200+
self._has_response_validation_error = has_response_validation_error
201+
202+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
203+
logger.debug("OpenAPIResponseValidationMiddleware handler")
204+
205+
route: Route = app.context["_route"]
206+
207+
# Call the next middleware (should be the route handler)
208+
response = next_middleware(app)
209+
210+
# Process the response
211+
return self._handle_response(route=route, response=response)
152212

153213
def _handle_response(self, *, route: Route, response: Response):
154214
# Process the response body if it exists
@@ -228,85 +288,27 @@ def _prepare_response_content(
228288
"""
229289
Prepares the response content for serialization.
230290
"""
231-
if isinstance(res, BaseModel):
232-
return _model_dump(
291+
if isinstance(res, BaseModel): # pragma: no cover
292+
return _model_dump( # pragma: no cover
233293
res,
234294
by_alias=True,
235295
exclude_unset=exclude_unset,
236296
exclude_defaults=exclude_defaults,
237297
exclude_none=exclude_none,
238298
)
239-
elif isinstance(res, list):
240-
return [
299+
elif isinstance(res, list): # pragma: no cover
300+
return [ # pragma: no cover
241301
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
242302
for item in res
243303
]
244-
elif isinstance(res, dict):
245-
return {
304+
elif isinstance(res, dict): # pragma: no cover
305+
return { # pragma: no cover
246306
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
247307
for k, v in res.items()
248308
}
249-
elif dataclasses.is_dataclass(res):
250-
return dataclasses.asdict(res) # type: ignore[arg-type]
251-
return res
252-
253-
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
254-
"""
255-
Get the request body from the event, and parse it according to content type.
256-
"""
257-
content_type = app.current_event.headers.get("content-type", "").strip()
258-
259-
# Handle JSON content
260-
if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE):
261-
return self._parse_json_data(app)
262-
263-
# Handle URL-encoded form data
264-
elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
265-
return self._parse_form_data(app)
266-
267-
else:
268-
raise NotImplementedError("Only JSON body or Form() are supported")
269-
270-
def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
271-
"""Parse JSON data from the request body."""
272-
try:
273-
return app.current_event.json_body
274-
except json.JSONDecodeError as e:
275-
raise RequestValidationError(
276-
[
277-
{
278-
"type": "json_invalid",
279-
"loc": ("body", e.pos),
280-
"msg": "JSON decode error",
281-
"input": {},
282-
"ctx": {"error": e.msg},
283-
},
284-
],
285-
body=e.doc,
286-
) from e
287-
288-
def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
289-
"""Parse URL-encoded form data from the request body."""
290-
try:
291-
body = app.current_event.decoded_body or ""
292-
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
293-
parsed = parse_qs(body, keep_blank_values=True)
294-
295-
result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
296-
return result
297-
298-
except Exception as e: # pragma: no cover
299-
raise RequestValidationError( # pragma: no cover
300-
[
301-
{
302-
"type": "form_invalid",
303-
"loc": ("body",),
304-
"msg": "Form data parsing error",
305-
"input": {},
306-
"ctx": {"error": str(e)},
307-
},
308-
],
309-
) from e
309+
elif dataclasses.is_dataclass(res): # pragma: no cover
310+
return dataclasses.asdict(res) # type: ignore[arg-type] # pragma: no cover
311+
return res # pragma: no cover
310312

311313

312314
def _request_params_to_args(

docs/core/event_handler/api_gateway.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,20 @@ As a practical example, let's refactor our correlation ID middleware so it accep
960960
!!! note "Class-based **vs** function-based middlewares"
961961
When registering a middleware, we expect a callable in both cases. For class-based middlewares, `BaseMiddlewareHandler` is doing the work of calling your `handler` method with the correct parameters, hence why we expect an instance of it.
962962

963+
#### Middleware and data validation
964+
965+
When you enable data validation with `enable_validation=True`, we split validation into two separate middlewares that run around your own middlewares, in this order:
966+
967+
1. **Request validation** runs first to validate incoming data
968+
2. **Your middlewares** run in the middle and can return early responses
969+
3. **Response validation** runs last, immediately around the route handler, so it only validates route handler responses
970+
971+
This ensures your middlewares can return early responses (401, 403, 429, etc.) without triggering validation errors, while still validating actual route handler responses for data integrity.
972+
973+
```python hl_lines="5 11 23 36" title="Middleware early returns work seamlessly with validation"
974+
--8<-- "examples/event_handler_rest/src/middleware_and_data_validation.py"
975+
```
976+
963977
#### Native middlewares
964978

965979
These are native middlewares that may become native features depending on customer demand.

0 commit comments

Comments
 (0)