Skip to content

Commit 8da7631

Browse files
Fix middleware validation
1 parent 8b9221a commit 8da7631

File tree

6 files changed

+94
-75
lines changed

6 files changed

+94
-75
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,35 @@ def build_allow_methods(methods: set[str]) -> str:
255255
return ",".join(sorted(methods))
256256

257257

258+
class BedrockResponse(Generic[ResponseT]):
259+
"""
260+
Contains the response body, status code, content type, and optional attributes
261+
for session management and knowledge base configuration.
262+
"""
263+
264+
def __init__(
265+
self,
266+
body: Any = None,
267+
status_code: int = 200,
268+
content_type: str = "application/json",
269+
session_attributes: dict[str, Any] | None = None,
270+
prompt_session_attributes: dict[str, Any] | None = None,
271+
knowledge_bases_configuration: list[dict[str, Any]] | None = None,
272+
) -> None:
273+
self.body = body
274+
self.status_code = status_code
275+
self.content_type = content_type
276+
self.session_attributes = session_attributes
277+
self.prompt_session_attributes = prompt_session_attributes
278+
self.knowledge_bases_configuration = knowledge_bases_configuration
279+
280+
def is_json(self) -> bool:
281+
"""
282+
Returns True if the response is JSON, based on the Content-Type.
283+
"""
284+
return True
285+
286+
258287
class Response(Generic[ResponseT]):
259288
"""Response data class that provides greater control over what is returned from the proxy event"""
260289

@@ -1474,7 +1503,10 @@ def __call__(self, app: ApiGatewayResolver) -> dict | tuple | Response:
14741503
return self.current_middleware(app, self.next_middleware)
14751504

14761505

1477-
def _registered_api_adapter(app: ApiGatewayResolver, next_middleware: Callable[..., Any]) -> dict | tuple | Response:
1506+
def _registered_api_adapter(
1507+
app: ApiGatewayResolver,
1508+
next_middleware: Callable[..., Any],
1509+
) -> dict | tuple | Response | BedrockResponse:
14781510
"""
14791511
Calls the registered API using the "_route_args" from the Resolver context to ensure the last call
14801512
in the chain will match the API route function signature and ensure that Powertools passes the API
@@ -2538,7 +2570,7 @@ def _call_route(self, route: Route, route_arguments: dict[str, str]) -> Response
25382570
self._reset_processed_stack()
25392571

25402572
return self._response_builder_class(
2541-
response=self._to_response(
2573+
response=self._to_response( # type: ignore[arg-type]
25422574
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
25432575
),
25442576
serializer=self._serializer,
@@ -2627,7 +2659,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
26272659

26282660
return None
26292661

2630-
def _to_response(self, result: dict | tuple | Response) -> Response:
2662+
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
26312663
"""Convert the route's result to a Response
26322664
26332665
3 main result types are supported:
@@ -2638,7 +2670,7 @@ def _to_response(self, result: dict | tuple | Response) -> Response:
26382670
- Response: returned as is, and allows for more flexibility
26392671
"""
26402672
status_code = HTTPStatus.OK
2641-
if isinstance(result, Response):
2673+
if isinstance(result, (Response, BedrockResponse)):
26422674
return result
26432675
elif isinstance(result, tuple) and len(result) == 2:
26442676
# Unpack result dict and status code from tuple
@@ -2971,8 +3003,9 @@ def _get_base_path(self) -> str:
29713003
# ALB doesn't have a stage variable, so we just return an empty string
29723004
return ""
29733005

3006+
# BedrockResponse is not used here but adding the same signature to keep strong typing
29743007
@override
2975-
def _to_response(self, result: dict | tuple | Response) -> Response:
3008+
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
29763009
"""Convert the route's result to a Response
29773010
29783011
ALB requires a non-null body otherwise it converts as HTTP 5xx

aws_lambda_powertools/event_handler/bedrock_agent.py

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aws_lambda_powertools.event_handler import ApiGatewayResolver
99
from aws_lambda_powertools.event_handler.api_gateway import (
1010
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
11+
BedrockResponse,
1112
ProxyEventType,
1213
ResponseBuilder,
1314
)
@@ -23,29 +24,6 @@
2324
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
2425

2526

26-
class BedrockResponse:
27-
"""
28-
Contains the response body, status code, content type, and optional attributes
29-
for session management and knowledge base configuration.
30-
"""
31-
32-
def __init__(
33-
self,
34-
body: Any = None,
35-
status_code: int = 200,
36-
content_type: str = "application/json",
37-
session_attributes: dict[str, Any] | None = None,
38-
prompt_session_attributes: dict[str, Any] | None = None,
39-
knowledge_bases_configuration: list[dict[str, Any]] | None = None,
40-
) -> None:
41-
self.body = body
42-
self.status_code = status_code
43-
self.content_type = content_type
44-
self.session_attributes = session_attributes
45-
self.prompt_session_attributes = prompt_session_attributes
46-
self.knowledge_bases_configuration = knowledge_bases_configuration
47-
48-
4927
class BedrockResponseBuilder(ResponseBuilder):
5028
"""
5129
Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents.
@@ -55,18 +33,9 @@ class BedrockResponseBuilder(ResponseBuilder):
5533

5634
@override
5735
def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
58-
"""Build the full response dict to be returned by the lambda"""
59-
self._route(event, None)
60-
61-
bedrock_response = None
62-
if isinstance(self.response.body, dict) and "body" in self.response.body:
63-
bedrock_response = BedrockResponse(**self.response.body)
64-
body = bedrock_response.body
65-
else:
66-
body = self.response.body
67-
68-
if self.response.is_json() and not isinstance(body, str):
69-
body = self.serializer(body)
36+
body = self.response.body
37+
if self.response.is_json() and not isinstance(self.response.body, str):
38+
body = self.serializer(self.response.body)
7039

7140
response = {
7241
"messageVersion": "1.0",
@@ -84,13 +53,15 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
8453
}
8554

8655
# Add Bedrock-specific attributes
87-
if bedrock_response:
88-
if bedrock_response.session_attributes:
89-
response["sessionAttributes"] = bedrock_response.session_attributes
90-
if bedrock_response.prompt_session_attributes:
91-
response["promptSessionAttributes"] = bedrock_response.prompt_session_attributes
92-
if bedrock_response.knowledge_bases_configuration:
93-
response["knowledgeBasesConfiguration"] = bedrock_response.knowledge_bases_configuration # type: ignore
56+
if isinstance(self.response, BedrockResponse):
57+
if self.response.session_attributes:
58+
response["sessionAttributes"] = self.response.session_attributes
59+
60+
if self.response.prompt_session_attributes:
61+
response["promptSessionAttributes"] = self.response.prompt_session_attributes
62+
63+
if self.response.knowledge_bases_configuration:
64+
response["knowledgeBasesConfiguration"] = self.response.knowledge_bases_configuration
9465

9566
return response
9667

docs/core/event_handler/bedrock_agents.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -321,20 +321,18 @@ You can enable user confirmation with Bedrock Agents to have your application as
321321
--8<-- "examples/event_handler_bedrock_agents/src/enabling_user_confirmation.py"
322322
```
323323

324-
### Fine grained responses
325-
326-
`BedrockResponse` class that provides full control over Bedrock Agent responses.
324+
1. Add an openapi extension
327325

328-
You can use this class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response).
326+
### Fine grained responses
329327

330328
???+ info "Note"
331329
The default response only includes the essential fields to keep the payload size minimal, as AWS Lambda has a maximum response size of 25 KB.
332330

333-
```python title="bedrockresponse.py" title="Customzing your Bedrock Response"
334-
--8<-- "examples/event_handler_bedrock_agents/src/bedrockresponse.py"
335-
```
331+
You can use `BedrockResponse` class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response){target="_blank"}.
336332

337-
1. Add an openapi extension
333+
```python title="working_with_bedrockresponse.py" title="Customzing your Bedrock Response" hl_lines="5 16"
334+
--8<-- "examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py"
335+
```
338336

339337
## Testing your code
340338

examples/event_handler_bedrock_agents/src/bedrockresponse.py

Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from http import HTTPStatus
2+
3+
from aws_lambda_powertools import Logger, Tracer
4+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
5+
from aws_lambda_powertools.event_handler.api_gateway import BedrockResponse
6+
from aws_lambda_powertools.utilities.typing import LambdaContext
7+
8+
tracer = Tracer()
9+
logger = Logger()
10+
app = BedrockAgentResolver()
11+
12+
13+
@app.get("/return_with_session", description="Returns a hello world with session attributes")
14+
@tracer.capture_method
15+
def hello_world():
16+
return BedrockResponse(
17+
status_code=HTTPStatus.OK.value,
18+
body={"message": "Hello from Bedrock!"},
19+
session_attributes={"user_id": "123"},
20+
prompt_session_attributes={"context": "testing"},
21+
knowledge_bases_configuration=[
22+
{
23+
"knowledgeBaseId": "kb-123",
24+
"retrievalConfiguration": {
25+
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
26+
},
27+
},
28+
],
29+
)
30+
31+
32+
@logger.inject_lambda_context
33+
@tracer.capture_lambda_handler
34+
def lambda_handler(event: dict, context: LambdaContext):
35+
return app.resolve(event, context)

tests/functional/event_handler/_pydantic/test_bedrock_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def test_bedrock_agent_with_partial_bedrock_response():
266266
app = BedrockAgentResolver()
267267

268268
@app.get("/claims", description="Gets claims")
269-
def claims():
269+
def claims() -> Dict[str, Any]:
270270
return BedrockResponse(
271271
body={"message": "test"},
272272
session_attributes={"user_id": "123"},
@@ -289,7 +289,7 @@ def test_bedrock_agent_with_different_attributes_combination():
289289
app = BedrockAgentResolver()
290290

291291
@app.get("/claims", description="Gets claims")
292-
def claims():
292+
def claims() -> Dict[str, Any]:
293293
return BedrockResponse(
294294
body={"message": "test"},
295295
prompt_session_attributes={"context": "testing"},

0 commit comments

Comments
 (0)