diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 98a8740a74f..32453eba121 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -13,6 +13,7 @@ ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, + Cookie, Dependant, Form, Header, @@ -275,7 +276,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): + elif isinstance(param_field.field_info, (Query, Header, Cookie)) and is_scalar_sequence_field(param_field): return False else: if not isinstance(param_field.field_info, Body): diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8fc8d0becfa..f4e0b72fb1e 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -236,6 +236,135 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.default})" +class Cookie(Param): + """ + A class used internally to represent a cookie parameter in the cookie header. + """ + + in_ = ParamTypes.cookie + + def __init__( + self, + default: Any = _Unset, + *, + default_factory: Callable[[], Any] | None = _Unset, + annotation: Any | None = None, + alias: str | None = None, + alias_priority: int | None = _Unset, + validation_alias: str | None = None, + serialization_alias: str | None = None, + title: str | None = None, + description: str | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + discriminator: str | None = None, + strict: bool | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + examples: list[Any] | None = None, + openapi_examples: dict[str, Example] | None = None, + deprecated: bool | None = None, + include_in_schema: bool = True, + json_schema_extra: dict[str, Any] | None = None, + **extra: Any, + ): + """ + Constructs a new Query param. + + Parameters + ---------- + default: Any + The default value of the parameter + default_factory: Callable[[], Any], optional + Callable that will be called when a default value is needed for this field + annotation: Any, optional + The type annotation of the parameter + alias: str, optional + The public name of the field + alias_priority: int, optional + Priority of the alias. This affects whether an alias generator is used + validation_alias: str | AliasPath | AliasChoices | None, optional + Alias to be used for validation only + serialization_alias: str | AliasPath | AliasChoices | None, optional + Alias to be used for serialization only + title: str, optional + The title of the parameter + description: str, optional + The description of the parameter + gt: float, optional + Only applies to numbers, required the field to be "greater than" + ge: float, optional + Only applies to numbers, required the field to be "greater than or equal" + lt: float, optional + Only applies to numbers, required the field to be "less than" + le: float, optional + Only applies to numbers, required the field to be "less than or equal" + min_length: int, optional + Only applies to strings, required the field to have a minimum length + max_length: int, optional + Only applies to strings, required the field to have a maximum length + pattern: str, optional + Only applies to strings, requires the field match against a regular expression pattern string + discriminator: str, optional + Parameter field name for discriminating the type in a tagged union + strict: bool, optional + Enables Pydantic's strict mode for the field + multiple_of: float, optional + Only applies to numbers, requires the field to be a multiple of the given value + allow_inf_nan: bool, optional + Only applies to numbers, requires the field to allow infinity and NaN values + max_digits: int, optional + Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal. + decimal_places: int, optional + Only applies to Decimals, requires the field to have at most a number of decimal places + examples: list[Any], optional + A list of examples for the parameter + deprecated: bool, optional + If `True`, the parameter will be marked as deprecated + include_in_schema: bool, optional + If `False`, the parameter will be excluded from the generated OpenAPI schema + json_schema_extra: dict[str, Any], optional + Extra values to include in the generated OpenAPI schema + """ + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + openapi_examples=openapi_examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + class Path(Param): """ A class used internally to represent a path parameter in a path operation. diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 19b5287d66a..fa9ec489aa4 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -14,6 +14,7 @@ ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, + Cookie, Form, Header, Param, @@ -100,6 +101,26 @@ def handler(user_id: str, include_extra: bool = False): assert parameter.schema_.title == "Include Extra" +def test_openapi_with_cookie_params(): + app = APIGatewayRestResolver() + + @app.get("/menu", summary="Get food items", operation_id="GetMenu", description="Get food items") + def handler(country: Annotated[str, Cookie(examples=[Example(summary="Country", value="🇧🇷")])]): + print(country) + raise NotImplementedError() + + schema = app.get_openapi_schema() + get = schema.paths["/menu"].get + + param = get.parameters[0] + assert param.name == "country" + assert param.in_ == ParameterInType.cookie + + example = Example(**param.schema_.examples[0]) + assert example.summary == "Country" + assert example.value == "🇧🇷" + + def test_openapi_with_custom_params(): app = APIGatewayRestResolver()