Skip to content

Commit 28440e8

Browse files
committed
feat(event_handler): Cookie parameter support in OpenAPI
1 parent 5999346 commit 28440e8

File tree

3 files changed

+152
-1
lines changed

3 files changed

+152
-1
lines changed

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from aws_lambda_powertools.event_handler.openapi.params import (
1515
Body,
16+
Cookie,
1617
Dependant,
1718
Form,
1819
Header,
@@ -275,7 +276,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
275276
return False
276277
elif is_scalar_field(field=param_field):
277278
return False
278-
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
279+
elif isinstance(param_field.field_info, (Query, Header, Cookie)) and is_scalar_sequence_field(param_field):
279280
return False
280281
else:
281282
if not isinstance(param_field.field_info, Body):

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,135 @@ def __repr__(self) -> str:
236236
return f"{self.__class__.__name__}({self.default})"
237237

238238

239+
class Cookie(Param):
240+
"""
241+
A class used internally to represent a cookie parameter in the cookie header.
242+
"""
243+
244+
in_ = ParamTypes.cookie
245+
246+
def __init__(
247+
self,
248+
default: Any = _Unset,
249+
*,
250+
default_factory: Callable[[], Any] | None = _Unset,
251+
annotation: Any | None = None,
252+
alias: str | None = None,
253+
alias_priority: int | None = _Unset,
254+
validation_alias: str | None = None,
255+
serialization_alias: str | None = None,
256+
title: str | None = None,
257+
description: str | None = None,
258+
gt: float | None = None,
259+
ge: float | None = None,
260+
lt: float | None = None,
261+
le: float | None = None,
262+
min_length: int | None = None,
263+
max_length: int | None = None,
264+
pattern: str | None = None,
265+
discriminator: str | None = None,
266+
strict: bool | None = _Unset,
267+
multiple_of: float | None = _Unset,
268+
allow_inf_nan: bool | None = _Unset,
269+
max_digits: int | None = _Unset,
270+
decimal_places: int | None = _Unset,
271+
examples: list[Any] | None = None,
272+
openapi_examples: dict[str, Example] | None = None,
273+
deprecated: bool | None = None,
274+
include_in_schema: bool = True,
275+
json_schema_extra: dict[str, Any] | None = None,
276+
**extra: Any,
277+
):
278+
"""
279+
Constructs a new Query param.
280+
281+
Parameters
282+
----------
283+
default: Any
284+
The default value of the parameter
285+
default_factory: Callable[[], Any], optional
286+
Callable that will be called when a default value is needed for this field
287+
annotation: Any, optional
288+
The type annotation of the parameter
289+
alias: str, optional
290+
The public name of the field
291+
alias_priority: int, optional
292+
Priority of the alias. This affects whether an alias generator is used
293+
validation_alias: str | AliasPath | AliasChoices | None, optional
294+
Alias to be used for validation only
295+
serialization_alias: str | AliasPath | AliasChoices | None, optional
296+
Alias to be used for serialization only
297+
title: str, optional
298+
The title of the parameter
299+
description: str, optional
300+
The description of the parameter
301+
gt: float, optional
302+
Only applies to numbers, required the field to be "greater than"
303+
ge: float, optional
304+
Only applies to numbers, required the field to be "greater than or equal"
305+
lt: float, optional
306+
Only applies to numbers, required the field to be "less than"
307+
le: float, optional
308+
Only applies to numbers, required the field to be "less than or equal"
309+
min_length: int, optional
310+
Only applies to strings, required the field to have a minimum length
311+
max_length: int, optional
312+
Only applies to strings, required the field to have a maximum length
313+
pattern: str, optional
314+
Only applies to strings, requires the field match against a regular expression pattern string
315+
discriminator: str, optional
316+
Parameter field name for discriminating the type in a tagged union
317+
strict: bool, optional
318+
Enables Pydantic's strict mode for the field
319+
multiple_of: float, optional
320+
Only applies to numbers, requires the field to be a multiple of the given value
321+
allow_inf_nan: bool, optional
322+
Only applies to numbers, requires the field to allow infinity and NaN values
323+
max_digits: int, optional
324+
Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
325+
decimal_places: int, optional
326+
Only applies to Decimals, requires the field to have at most a number of decimal places
327+
examples: list[Any], optional
328+
A list of examples for the parameter
329+
deprecated: bool, optional
330+
If `True`, the parameter will be marked as deprecated
331+
include_in_schema: bool, optional
332+
If `False`, the parameter will be excluded from the generated OpenAPI schema
333+
json_schema_extra: dict[str, Any], optional
334+
Extra values to include in the generated OpenAPI schema
335+
"""
336+
super().__init__(
337+
default=default,
338+
default_factory=default_factory,
339+
annotation=annotation,
340+
alias=alias,
341+
alias_priority=alias_priority,
342+
validation_alias=validation_alias,
343+
serialization_alias=serialization_alias,
344+
title=title,
345+
description=description,
346+
gt=gt,
347+
ge=ge,
348+
lt=lt,
349+
le=le,
350+
min_length=min_length,
351+
max_length=max_length,
352+
pattern=pattern,
353+
discriminator=discriminator,
354+
strict=strict,
355+
multiple_of=multiple_of,
356+
allow_inf_nan=allow_inf_nan,
357+
max_digits=max_digits,
358+
decimal_places=decimal_places,
359+
deprecated=deprecated,
360+
examples=examples,
361+
openapi_examples=openapi_examples,
362+
include_in_schema=include_in_schema,
363+
json_schema_extra=json_schema_extra,
364+
**extra,
365+
)
366+
367+
239368
class Path(Param):
240369
"""
241370
A class used internally to represent a path parameter in a path operation.

tests/functional/event_handler/_pydantic/test_openapi_params.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from aws_lambda_powertools.event_handler.openapi.params import (
1616
Body,
17+
Cookie,
1718
Form,
1819
Header,
1920
Param,
@@ -100,6 +101,26 @@ def handler(user_id: str, include_extra: bool = False):
100101
assert parameter.schema_.title == "Include Extra"
101102

102103

104+
def test_openapi_with_cookie_params():
105+
app = APIGatewayRestResolver()
106+
107+
@app.get("/menu", summary="Get food items", operation_id="GetMenu", description="Get food items")
108+
def handler(country: Annotated[str, Cookie(examples=[Example(summary="Country", value="🇧🇷")])]):
109+
print(country)
110+
raise NotImplementedError()
111+
112+
schema = app.get_openapi_schema()
113+
get = schema.paths["/menu"].get
114+
115+
param = get.parameters[0]
116+
assert param.name == "country"
117+
assert param.in_ == ParameterInType.cookie
118+
119+
example = Example(**param.schema_.examples[0])
120+
assert example.summary == "Country"
121+
assert example.value == "🇧🇷"
122+
123+
103124
def test_openapi_with_custom_params():
104125
app = APIGatewayRestResolver()
105126

0 commit comments

Comments
 (0)