Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,30 +661,36 @@ def _get_openapi_path( # noqa PLR0912
else:
# Need to iterate to transform any 'model' into a 'schema'
for content_type, payload in response["content"].items():
new_payload: OpenAPIResponseContentSchema

# Case 2.1: the 'content' has a model
if "model" in payload:
# Find the model in the dependant's extra models
model_payload_typed = cast(OpenAPIResponseContentModel, payload)
return_field = next(
filter(
lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"],
lambda model: model.type_ is model_payload_typed["model"],
self.dependant.response_extra_models,
),
)
if not return_field:
raise AssertionError("Model declared in custom responses was not found")

new_payload = self._openapi_operation_return(
model_payload = self._openapi_operation_return(
param=return_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
)

# Preserve existing fields like examples, encoding, etc.
new_payload: OpenAPIResponseContentSchema = {}
for key, value in payload.items():
if key != "model":
new_payload[key] = value # type: ignore[literal-required]
new_payload.update(model_payload) # Add/override with model schema

# Case 2.2: the 'content' has a schema
else:
# Do nothing! We already have what we need!
new_payload = payload
new_payload = cast(OpenAPIResponseContentSchema, payload)

response["content"][content_type] = new_payload

Expand Down
24 changes: 21 additions & 3 deletions aws_lambda_powertools/event_handler/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,32 @@
}


class OpenAPIResponseHeader(TypedDict, total=False):
"""OpenAPI Response Header Object"""

description: NotRequired[str]
schema: NotRequired[dict[str, Any]]
examples: NotRequired[dict[str, Any]]
style: NotRequired[str]
explode: NotRequired[bool]
allowReserved: NotRequired[bool]
deprecated: NotRequired[bool]


class OpenAPIResponseContentSchema(TypedDict, total=False):
schema: dict
examples: NotRequired[dict[str, Any]]
encoding: NotRequired[dict[str, Any]]


class OpenAPIResponseContentModel(TypedDict):
class OpenAPIResponseContentModel(TypedDict, total=False):
model: Any
examples: NotRequired[dict[str, Any]]
encoding: NotRequired[dict[str, Any]]


class OpenAPIResponse(TypedDict):
description: str
class OpenAPIResponse(TypedDict, total=False):
description: str # Still required
headers: NotRequired[dict[str, OpenAPIResponseHeader]]
content: NotRequired[dict[str, OpenAPIResponseContentSchema | OpenAPIResponseContentModel]]
links: NotRequired[dict[str, Any]]
135 changes: 134 additions & 1 deletion tests/functional/event_handler/_pydantic/test_openapi_responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from secrets import randbelow
from typing import Union
from typing import Optional, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -237,3 +237,136 @@ def handler():
assert 200 in responses.keys()
assert responses[200].description == "Successful Response"
assert 422 not in responses.keys()


def test_openapi_response_with_headers():
"""Test that response headers are properly included in OpenAPI schema"""
app = APIGatewayRestResolver(enable_validation=True)

@app.get(
"/",
responses={
200: {
"description": "Successful Response",
"headers": {
"X-Rate-Limit": {
"description": "Rate limit header",
"schema": {"type": "integer"},
},
"X-Custom-Header": {
"description": "Custom header",
"schema": {"type": "string"},
"examples": {"example1": "value1"},
},
},
},
},
)
def handler():
return {"message": "hello"}

schema = app.get_openapi_schema()
response_dict = schema.paths["/"].get.responses[200]

# Verify headers are present
assert "headers" in response_dict
headers = response_dict["headers"]

# Check X-Rate-Limit header
assert "X-Rate-Limit" in headers
assert headers["X-Rate-Limit"]["description"] == "Rate limit header"
assert headers["X-Rate-Limit"]["schema"]["type"] == "integer"

# Check X-Custom-Header with examples
assert "X-Custom-Header" in headers
assert headers["X-Custom-Header"]["description"] == "Custom header"
assert headers["X-Custom-Header"]["schema"]["type"] == "string"
assert headers["X-Custom-Header"]["examples"]["example1"] == "value1"


def test_openapi_response_with_links():
"""Test that response links are properly included in OpenAPI schema"""
app = APIGatewayRestResolver(enable_validation=True)

@app.get(
"/users/{user_id}",
responses={
200: {
"description": "User details",
"links": {
"GetUserOrders": {
"operationId": "getUserOrders",
"parameters": {"userId": "$response.body#/id"},
"description": "Get orders for this user",
},
},
},
},
)
def get_user(user_id: str):
return {"id": user_id, "name": "John Doe"}

schema = app.get_openapi_schema()
response = schema.paths["/users/{user_id}"].get.responses[200]

# Verify links are present
links = response.links

assert "GetUserOrders" in links
assert links["GetUserOrders"].operationId == "getUserOrders"
assert links["GetUserOrders"].parameters["userId"] == "$response.body#/id"
assert links["GetUserOrders"].description == "Get orders for this user"


def test_openapi_response_examples_preserved_with_model():
"""Test that examples are preserved when using model in response content"""
app = APIGatewayRestResolver(enable_validation=True)

class UserResponse(BaseModel):
id: int
name: str
email: Optional[str] = None

@app.get(
"/",
responses={
200: {
"description": "User response",
"content": {
"application/json": {
"model": UserResponse,
"examples": {
"example1": {
"summary": "Example 1",
"value": {"id": 1, "name": "John", "email": "[email protected]"},
},
"example2": {
"summary": "Example 2",
"value": {"id": 2, "name": "Jane"},
},
},
},
},
},
},
)
def handler() -> UserResponse:
return UserResponse(id=1, name="Test")

schema = app.get_openapi_schema()
content = schema.paths["/"].get.responses[200].content["application/json"]

# Verify model schema is present
assert content.schema_.ref == "#/components/schemas/UserResponse"

# Verify examples are preserved
examples = content.examples

assert "example1" in examples
assert examples["example1"].summary == "Example 1"
assert examples["example1"].value["id"] == 1
assert examples["example1"].value["name"] == "John"

assert "example2" in examples
assert examples["example2"].summary == "Example 2"
assert examples["example2"].value["id"] == 2