Skip to content

Commit c2659aa

Browse files
committed
fix: reduce cognitive complexity
1 parent 9c056fe commit c2659aa

File tree

1 file changed

+54
-38
lines changed

1 file changed

+54
-38
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import dataclasses
44
import json
55
import logging
6-
from copy import deepcopy
76
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
87
from urllib.parse import parse_qs
98

@@ -314,12 +313,12 @@ def _prepare_response_content(
314313
def _request_params_to_args(
315314
required_params: Sequence[ModelField],
316315
received_params: Mapping[str, Any],
317-
) -> tuple[dict[str, Any], list[Any]]:
316+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
318317
"""
319318
Convert the request params to a dictionary of values using validation, and returns a list of errors.
320319
"""
321-
values = {}
322-
errors = []
320+
values: dict[str, Any] = {}
321+
errors: list[dict[str, Any]] = []
323322

324323
for field in required_params:
325324
field_info = field.field_info
@@ -328,16 +327,12 @@ def _request_params_to_args(
328327
if not isinstance(field_info, Param):
329328
raise AssertionError(f"Expected Param field_info, got {field_info}")
330329

331-
value = received_params.get(field.alias)
332-
333330
loc = (field_info.in_.value, field.alias)
331+
value = received_params.get(field.alias)
334332

335333
# If we don't have a value, see if it's required or has a default
336334
if value is None:
337-
if field.required:
338-
errors.append(get_missing_field_error(loc=loc))
339-
else:
340-
values[field.name] = deepcopy(field.default)
335+
_handle_missing_field_value(field, values, errors, loc)
341336
continue
342337

343338
# Finally, validate the value
@@ -363,43 +358,64 @@ def _request_body_to_args(
363358
)
364359

365360
for field in required_params:
366-
# This sets the location to:
367-
# { "user": { object } } if field.alias == user
368-
# { { object } if field_alias is omitted
369-
loc: tuple[str, ...] = ("body", field.alias)
370-
if field_alias_omitted:
371-
loc = ("body",)
372-
373-
value: Any | None = None
361+
loc = _get_body_field_location(field, field_alias_omitted)
362+
value = _extract_field_value_from_body(field, received_body, loc, errors)
374363

375-
# Now that we know what to look for, try to get the value from the received body
376-
if received_body is not None:
377-
try:
378-
value = received_body.get(field.alias)
379-
except AttributeError:
380-
errors.append(get_missing_field_error(loc))
381-
continue
382-
383-
# Determine if the field is required
364+
# If we don't have a value, see if it's required or has a default
384365
if value is None:
385-
if field.required:
386-
errors.append(get_missing_field_error(loc))
387-
else:
388-
values[field.name] = deepcopy(field.default)
366+
_handle_missing_field_value(field, values, errors, loc)
389367
continue
390368

391-
# Normalize lists for non-sequence fields
392-
if isinstance(value, list) and not is_sequence_field(field):
393-
value = value[0]
394-
395-
# MAINTENANCE: Handle byte and file fields
396-
397-
# Finally, validate the value
369+
value = _normalize_field_value(field, value)
398370
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
399371

400372
return values, errors
401373

402374

375+
def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]:
376+
"""Get the location tuple for a body field based on whether the field alias is omitted."""
377+
if field_alias_omitted:
378+
return ("body",)
379+
return ("body", field.alias)
380+
381+
382+
def _extract_field_value_from_body(
383+
field: ModelField,
384+
received_body: dict[str, Any] | None,
385+
loc: tuple[str, ...],
386+
errors: list[dict[str, Any]],
387+
) -> Any | None:
388+
"""Extract field value from the received body, handling potential AttributeError."""
389+
if received_body is None:
390+
return None
391+
392+
try:
393+
return received_body.get(field.alias)
394+
except AttributeError:
395+
errors.append(get_missing_field_error(loc))
396+
return None
397+
398+
399+
def _handle_missing_field_value(
400+
field: ModelField,
401+
values: dict[str, Any],
402+
errors: list[dict[str, Any]],
403+
loc: tuple[str, ...],
404+
) -> None:
405+
"""Handle the case when a field value is missing."""
406+
if field.required:
407+
errors.append(get_missing_field_error(loc))
408+
else:
409+
values[field.name] = field.get_default()
410+
411+
412+
def _normalize_field_value(field: ModelField, value: Any) -> Any:
413+
"""Normalize field value, converting lists to single values for non-sequence fields."""
414+
if isinstance(value, list) and not is_sequence_field(field):
415+
return value[0]
416+
return value
417+
418+
403419
def _validate_field(
404420
*,
405421
field: ModelField,

0 commit comments

Comments
 (0)