Skip to content

Commit 2933225

Browse files
removing duplicated code + coverage
1 parent b90d337 commit 2933225

File tree

2 files changed

+10
-282
lines changed

2 files changed

+10
-282
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2685,7 +2685,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
26852685
route=route,
26862686
)
26872687

2688-
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2688+
# OpenAPIResponseValidationMiddleware will only raise ResponseValidationError when
26892689
# 'self._response_validation_error_http_code' is not None or
26902690
# when route has custom_response_validation_http_code
26912691
if isinstance(exp, ResponseValidationError):

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 9 additions & 281 deletions
Original file line numberDiff line numberDiff line change
@@ -288,299 +288,27 @@ def _prepare_response_content(
288288
"""
289289
Prepares the response content for serialization.
290290
"""
291-
if isinstance(res, BaseModel):
292-
return _model_dump(
291+
if isinstance(res, BaseModel): # pragma: no cover
292+
return _model_dump( # pragma: no cover
293293
res,
294294
by_alias=True,
295295
exclude_unset=exclude_unset,
296296
exclude_defaults=exclude_defaults,
297297
exclude_none=exclude_none,
298298
)
299-
elif isinstance(res, list):
300-
return [
299+
elif isinstance(res, list): # pragma: no cover
300+
return [ # pragma: no cover
301301
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
302302
for item in res
303303
]
304-
elif isinstance(res, dict):
305-
return {
304+
elif isinstance(res, dict): # pragma: no cover
305+
return { # pragma: no cover
306306
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
307307
for k, v in res.items()
308308
}
309-
elif dataclasses.is_dataclass(res):
310-
return dataclasses.asdict(res) # type: ignore[arg-type]
311-
return res
312-
313-
314-
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
315-
"""
316-
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
317-
Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
318-
should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
319-
320-
Example
321-
--------
322-
323-
```python
324-
from pydantic import BaseModel
325-
326-
from aws_lambda_powertools.event_handler.api_gateway import (
327-
APIGatewayRestResolver,
328-
)
329-
330-
class Todo(BaseModel):
331-
name: str
332-
333-
app = APIGatewayRestResolver(enable_validation=True)
334-
335-
@app.get("/todos")
336-
def get_todos(): list[Todo]:
337-
return [Todo(name="hello world")]
338-
```
339-
"""
340-
341-
def __init__(
342-
self,
343-
validation_serializer: Callable[[Any], str] | None = None,
344-
has_response_validation_error: bool = False,
345-
):
346-
"""
347-
Initialize the OpenAPIValidationMiddleware.
348-
349-
Parameters
350-
----------
351-
validation_serializer : Callable, optional
352-
Optional serializer to use when serializing the response for validation.
353-
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
354-
355-
has_response_validation_error: bool, optional
356-
Optional flag used to distinguish between payload and validation errors.
357-
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
358-
"""
359-
self._validation_serializer = validation_serializer
360-
self._has_response_validation_error = has_response_validation_error
361-
362-
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
363-
logger.debug("OpenAPIValidationMiddleware handler")
364-
365-
route: Route = app.context["_route"]
366-
367-
values: dict[str, Any] = {}
368-
errors: list[Any] = []
369-
370-
# Process path values, which can be found on the route_args
371-
path_values, path_errors = _request_params_to_args(
372-
route.dependant.path_params,
373-
app.context["_route_args"],
374-
)
375-
376-
# Normalize query values before validate this
377-
query_string = _normalize_multi_query_string_with_param(
378-
app.current_event.resolved_query_string_parameters,
379-
route.dependant.query_params,
380-
)
381-
382-
# Process query values
383-
query_values, query_errors = _request_params_to_args(
384-
route.dependant.query_params,
385-
query_string,
386-
)
387-
388-
# Normalize header values before validate this
389-
headers = _normalize_multi_header_values_with_param(
390-
app.current_event.resolved_headers_field,
391-
route.dependant.header_params,
392-
)
393-
394-
# Process header values
395-
header_values, header_errors = _request_params_to_args(
396-
route.dependant.header_params,
397-
headers,
398-
)
399-
400-
values.update(path_values)
401-
values.update(query_values)
402-
values.update(header_values)
403-
errors += path_errors + query_errors + header_errors
404-
405-
# Process the request body, if it exists
406-
if route.dependant.body_params:
407-
(body_values, body_errors) = _request_body_to_args(
408-
required_params=route.dependant.body_params,
409-
received_body=self._get_body(app),
410-
)
411-
values.update(body_values)
412-
errors.extend(body_errors)
413-
414-
if errors:
415-
# Raise the validation errors
416-
raise RequestValidationError(_normalize_errors(errors))
417-
else:
418-
# Re-write the route_args with the validated values, and call the next middleware
419-
app.context["_route_args"] = values
420-
421-
# Call the handler by calling the next middleware
422-
response = next_middleware(app)
423-
424-
# Process the response
425-
return self._handle_response(route=route, response=response)
426-
427-
def _handle_response(self, *, route: Route, response: Response):
428-
# Process the response body if it exists
429-
if response.body and response.is_json():
430-
response.body = self._serialize_response(
431-
field=route.dependant.return_param,
432-
response_content=response.body,
433-
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
434-
)
435-
436-
return response
437-
438-
def _serialize_response(
439-
self,
440-
*,
441-
field: ModelField | None = None,
442-
response_content: Any,
443-
include: IncEx | None = None,
444-
exclude: IncEx | None = None,
445-
by_alias: bool = True,
446-
exclude_unset: bool = False,
447-
exclude_defaults: bool = False,
448-
exclude_none: bool = False,
449-
has_route_custom_response_validation: bool = False,
450-
) -> Any:
451-
"""
452-
Serialize the response content according to the field type.
453-
"""
454-
if field:
455-
errors: list[dict[str, Any]] = []
456-
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
457-
if errors:
458-
# route-level validation must take precedence over app-level
459-
if has_route_custom_response_validation:
460-
raise ResponseValidationError(
461-
errors=_normalize_errors(errors),
462-
body=response_content,
463-
source="route",
464-
)
465-
if self._has_response_validation_error:
466-
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
467-
468-
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
469-
470-
if hasattr(field, "serialize"):
471-
return field.serialize(
472-
value,
473-
include=include,
474-
exclude=exclude,
475-
by_alias=by_alias,
476-
exclude_unset=exclude_unset,
477-
exclude_defaults=exclude_defaults,
478-
exclude_none=exclude_none,
479-
)
480-
return jsonable_encoder(
481-
value,
482-
include=include,
483-
exclude=exclude,
484-
by_alias=by_alias,
485-
exclude_unset=exclude_unset,
486-
exclude_defaults=exclude_defaults,
487-
exclude_none=exclude_none,
488-
custom_serializer=self._validation_serializer,
489-
)
490-
else:
491-
# Just serialize the response content returned from the handler.
492-
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)
493-
494-
def _prepare_response_content(
495-
self,
496-
res: Any,
497-
*,
498-
exclude_unset: bool,
499-
exclude_defaults: bool = False,
500-
exclude_none: bool = False,
501-
) -> Any:
502-
"""
503-
Prepares the response content for serialization.
504-
"""
505-
if isinstance(res, BaseModel):
506-
return _model_dump(
507-
res,
508-
by_alias=True,
509-
exclude_unset=exclude_unset,
510-
exclude_defaults=exclude_defaults,
511-
exclude_none=exclude_none,
512-
)
513-
elif isinstance(res, list):
514-
return [
515-
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
516-
for item in res
517-
]
518-
elif isinstance(res, dict):
519-
return {
520-
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
521-
for k, v in res.items()
522-
}
523-
elif dataclasses.is_dataclass(res):
524-
return dataclasses.asdict(res) # type: ignore[arg-type]
525-
return res
526-
527-
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
528-
"""
529-
Get the request body from the event, and parse it according to content type.
530-
"""
531-
content_type = app.current_event.headers.get("content-type", "").strip()
532-
533-
# Handle JSON content
534-
if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE):
535-
return self._parse_json_data(app)
536-
537-
# Handle URL-encoded form data
538-
elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
539-
return self._parse_form_data(app)
540-
541-
else:
542-
raise NotImplementedError("Only JSON body or Form() are supported")
543-
544-
def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
545-
"""Parse JSON data from the request body."""
546-
try:
547-
return app.current_event.json_body
548-
except json.JSONDecodeError as e:
549-
raise RequestValidationError(
550-
[
551-
{
552-
"type": "json_invalid",
553-
"loc": ("body", e.pos),
554-
"msg": "JSON decode error",
555-
"input": {},
556-
"ctx": {"error": e.msg},
557-
},
558-
],
559-
body=e.doc,
560-
) from e
561-
562-
def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
563-
"""Parse URL-encoded form data from the request body."""
564-
try:
565-
body = app.current_event.decoded_body or ""
566-
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
567-
parsed = parse_qs(body, keep_blank_values=True)
568-
569-
result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
570-
return result
571-
572-
except Exception as e: # pragma: no cover
573-
raise RequestValidationError( # pragma: no cover
574-
[
575-
{
576-
"type": "form_invalid",
577-
"loc": ("body",),
578-
"msg": "Form data parsing error",
579-
"input": {},
580-
"ctx": {"error": str(e)},
581-
},
582-
],
583-
) from e
309+
elif dataclasses.is_dataclass(res): # pragma: no cover
310+
return dataclasses.asdict(res) # type: ignore[arg-type] # pragma: no cover
311+
return res # pragma: no cover
584312

585313

586314
def _request_params_to_args(

0 commit comments

Comments
 (0)