Skip to content

feat(event_handler): Cookie parameter support in OpenAPI #7165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Cookie,
Dependant,
Form,
Header,
Expand Down Expand Up @@ -275,7 +276,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return False
elif is_scalar_field(field=param_field):
return False
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
elif isinstance(param_field.field_info, (Query, Header, Cookie)) and is_scalar_sequence_field(param_field):
return False
else:
if not isinstance(param_field.field_info, Body):
Expand Down
129 changes: 129 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,135 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"


class Cookie(Param):
"""
A class used internally to represent a cookie parameter in the cookie header.
"""

in_ = ParamTypes.cookie

def __init__(
self,
default: Any = _Unset,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
openapi_examples: dict[str, Example] | None = None,
deprecated: bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
"""
Constructs a new Query param.

Parameters
----------
default: Any
The default value of the parameter
default_factory: Callable[[], Any], optional
Callable that will be called when a default value is needed for this field
annotation: Any, optional
The type annotation of the parameter
alias: str, optional
The public name of the field
alias_priority: int, optional
Priority of the alias. This affects whether an alias generator is used
validation_alias: str | AliasPath | AliasChoices | None, optional
Alias to be used for validation only
serialization_alias: str | AliasPath | AliasChoices | None, optional
Alias to be used for serialization only
title: str, optional
The title of the parameter
description: str, optional
The description of the parameter
gt: float, optional
Only applies to numbers, required the field to be "greater than"
ge: float, optional
Only applies to numbers, required the field to be "greater than or equal"
lt: float, optional
Only applies to numbers, required the field to be "less than"
le: float, optional
Only applies to numbers, required the field to be "less than or equal"
min_length: int, optional
Only applies to strings, required the field to have a minimum length
max_length: int, optional
Only applies to strings, required the field to have a maximum length
pattern: str, optional
Only applies to strings, requires the field match against a regular expression pattern string
discriminator: str, optional
Parameter field name for discriminating the type in a tagged union
strict: bool, optional
Enables Pydantic's strict mode for the field
multiple_of: float, optional
Only applies to numbers, requires the field to be a multiple of the given value
allow_inf_nan: bool, optional
Only applies to numbers, requires the field to allow infinity and NaN values
max_digits: int, optional
Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
decimal_places: int, optional
Only applies to Decimals, requires the field to have at most a number of decimal places
examples: list[Any], optional
A list of examples for the parameter
deprecated: bool, optional
If `True`, the parameter will be marked as deprecated
include_in_schema: bool, optional
If `False`, the parameter will be excluded from the generated OpenAPI schema
json_schema_extra: dict[str, Any], optional
Extra values to include in the generated OpenAPI schema
"""
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)


class Path(Param):
"""
A class used internally to represent a path parameter in a path operation.
Expand Down
21 changes: 21 additions & 0 deletions tests/functional/event_handler/_pydantic/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Cookie,
Form,
Header,
Param,
Expand Down Expand Up @@ -100,6 +101,26 @@ def handler(user_id: str, include_extra: bool = False):
assert parameter.schema_.title == "Include Extra"


def test_openapi_with_cookie_params():
app = APIGatewayRestResolver()

@app.get("/menu", summary="Get food items", operation_id="GetMenu", description="Get food items")
def handler(country: Annotated[str, Cookie(examples=[Example(summary="Country", value="🇧🇷")])]):
print(country)
raise NotImplementedError()

schema = app.get_openapi_schema()
get = schema.paths["/menu"].get

param = get.parameters[0]
assert param.name == "country"
assert param.in_ == ParameterInType.cookie

example = Example(**param.schema_.examples[0])
assert example.summary == "Country"
assert example.value == "🇧🇷"


def test_openapi_with_custom_params():
app = APIGatewayRestResolver()

Expand Down