diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 407cd00781b..22d7ba91bcc 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -661,30 +661,36 @@ def _get_openapi_path( # noqa PLR0912 else: # Need to iterate to transform any 'model' into a 'schema' for content_type, payload in response["content"].items(): - new_payload: OpenAPIResponseContentSchema - # Case 2.1: the 'content' has a model if "model" in payload: # Find the model in the dependant's extra models + model_payload_typed = cast(OpenAPIResponseContentModel, payload) return_field = next( filter( - lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"], + lambda model: model.type_ is model_payload_typed["model"], self.dependant.response_extra_models, ), ) if not return_field: raise AssertionError("Model declared in custom responses was not found") - new_payload = self._openapi_operation_return( + model_payload = self._openapi_operation_return( param=return_field, model_name_map=model_name_map, field_mapping=field_mapping, ) + # Preserve existing fields like examples, encoding, etc. + new_payload: OpenAPIResponseContentSchema = {} + for key, value in payload.items(): + if key != "model": + new_payload[key] = value # type: ignore[literal-required] + new_payload.update(model_payload) # Add/override with model schema + # Case 2.2: the 'content' has a schema else: # Do nothing! We already have what we need! - new_payload = payload + new_payload = cast(OpenAPIResponseContentSchema, payload) response["content"][content_type] = new_payload diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index 61ac295f948..a4fb3662fb0 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -63,14 +63,32 @@ } +class OpenAPIResponseHeader(TypedDict, total=False): + """OpenAPI Response Header Object""" + + description: NotRequired[str] + schema: NotRequired[dict[str, Any]] + examples: NotRequired[dict[str, Any]] + style: NotRequired[str] + explode: NotRequired[bool] + allowReserved: NotRequired[bool] + deprecated: NotRequired[bool] + + class OpenAPIResponseContentSchema(TypedDict, total=False): schema: dict + examples: NotRequired[dict[str, Any]] + encoding: NotRequired[dict[str, Any]] -class OpenAPIResponseContentModel(TypedDict): +class OpenAPIResponseContentModel(TypedDict, total=False): model: Any + examples: NotRequired[dict[str, Any]] + encoding: NotRequired[dict[str, Any]] -class OpenAPIResponse(TypedDict): - description: str +class OpenAPIResponse(TypedDict, total=False): + description: str # Still required + headers: NotRequired[dict[str, OpenAPIResponseHeader]] content: NotRequired[dict[str, OpenAPIResponseContentSchema | OpenAPIResponseContentModel]] + links: NotRequired[dict[str, Any]] diff --git a/tests/functional/event_handler/_pydantic/test_openapi_responses.py b/tests/functional/event_handler/_pydantic/test_openapi_responses.py index 8c41651f803..71c7d186cbe 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_responses.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_responses.py @@ -1,5 +1,5 @@ from secrets import randbelow -from typing import Union +from typing import Optional, Union from pydantic import BaseModel @@ -237,3 +237,136 @@ def handler(): assert 200 in responses.keys() assert responses[200].description == "Successful Response" assert 422 not in responses.keys() + + +def test_openapi_response_with_headers(): + """Test that response headers are properly included in OpenAPI schema""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.get( + "/", + responses={ + 200: { + "description": "Successful Response", + "headers": { + "X-Rate-Limit": { + "description": "Rate limit header", + "schema": {"type": "integer"}, + }, + "X-Custom-Header": { + "description": "Custom header", + "schema": {"type": "string"}, + "examples": {"example1": "value1"}, + }, + }, + }, + }, + ) + def handler(): + return {"message": "hello"} + + schema = app.get_openapi_schema() + response_dict = schema.paths["/"].get.responses[200] + + # Verify headers are present + assert "headers" in response_dict + headers = response_dict["headers"] + + # Check X-Rate-Limit header + assert "X-Rate-Limit" in headers + assert headers["X-Rate-Limit"]["description"] == "Rate limit header" + assert headers["X-Rate-Limit"]["schema"]["type"] == "integer" + + # Check X-Custom-Header with examples + assert "X-Custom-Header" in headers + assert headers["X-Custom-Header"]["description"] == "Custom header" + assert headers["X-Custom-Header"]["schema"]["type"] == "string" + assert headers["X-Custom-Header"]["examples"]["example1"] == "value1" + + +def test_openapi_response_with_links(): + """Test that response links are properly included in OpenAPI schema""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.get( + "/users/{user_id}", + responses={ + 200: { + "description": "User details", + "links": { + "GetUserOrders": { + "operationId": "getUserOrders", + "parameters": {"userId": "$response.body#/id"}, + "description": "Get orders for this user", + }, + }, + }, + }, + ) + def get_user(user_id: str): + return {"id": user_id, "name": "John Doe"} + + schema = app.get_openapi_schema() + response = schema.paths["/users/{user_id}"].get.responses[200] + + # Verify links are present + links = response.links + + assert "GetUserOrders" in links + assert links["GetUserOrders"].operationId == "getUserOrders" + assert links["GetUserOrders"].parameters["userId"] == "$response.body#/id" + assert links["GetUserOrders"].description == "Get orders for this user" + + +def test_openapi_response_examples_preserved_with_model(): + """Test that examples are preserved when using model in response content""" + app = APIGatewayRestResolver(enable_validation=True) + + class UserResponse(BaseModel): + id: int + name: str + email: Optional[str] = None + + @app.get( + "/", + responses={ + 200: { + "description": "User response", + "content": { + "application/json": { + "model": UserResponse, + "examples": { + "example1": { + "summary": "Example 1", + "value": {"id": 1, "name": "John", "email": "john@example.com"}, + }, + "example2": { + "summary": "Example 2", + "value": {"id": 2, "name": "Jane"}, + }, + }, + }, + }, + }, + }, + ) + def handler() -> UserResponse: + return UserResponse(id=1, name="Test") + + schema = app.get_openapi_schema() + content = schema.paths["/"].get.responses[200].content["application/json"] + + # Verify model schema is present + assert content.schema_.ref == "#/components/schemas/UserResponse" + + # Verify examples are preserved + examples = content.examples + + assert "example1" in examples + assert examples["example1"].summary == "Example 1" + assert examples["example1"].value["id"] == 1 + assert examples["example1"].value["name"] == "John" + + assert "example2" in examples + assert examples["example2"].summary == "Example 2" + assert examples["example2"].value["id"] == 2