Skip to content

Commit e42ceff

Browse files
committed
add response optional fields
1 parent abbc100 commit e42ceff

File tree

4 files changed

+159
-59
lines changed

4 files changed

+159
-59
lines changed

aws_lambda_powertools/event_handler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
1414
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
15-
from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver
15+
from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver, BedrockResponse
1616
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
1717
from aws_lambda_powertools.event_handler.lambda_function_url import (
1818
LambdaFunctionUrlResolver,
@@ -31,6 +31,7 @@
3131
"CORSConfig",
3232
"LambdaFunctionUrlResolver",
3333
"Response",
34+
"BedrockResponse",
3435
"VPCLatticeResolver",
3536
"VPCLatticeV2Resolver",
3637
]

aws_lambda_powertools/event_handler/bedrock_agent_function.py

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22

33
from typing import TYPE_CHECKING, Any
44

5-
from typing_extensions import override
6-
7-
from aws_lambda_powertools.event_handler.api_gateway import Response, ResponseBuilder
8-
95
if TYPE_CHECKING:
106
from collections.abc import Callable
117

@@ -19,7 +15,49 @@ class ResponseState(Enum):
1915
REPROMPT = "REPROMPT"
2016

2117

22-
class BedrockFunctionsResponseBuilder(ResponseBuilder):
18+
class BedrockResponse:
19+
"""Response class for Bedrock Agent Functions
20+
21+
Parameters
22+
----------
23+
body : Any, optional
24+
Response body
25+
session_attributes : dict[str, str] | None
26+
Session attributes to include in the response
27+
prompt_session_attributes : dict[str, str] | None
28+
Prompt session attributes to include in the response
29+
status_code : int
30+
Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE)
31+
32+
Examples
33+
--------
34+
```python
35+
@app.tool(description="Function that uses session attributes")
36+
def test_function():
37+
return BedrockResponse(
38+
body="Hello",
39+
session_attributes={"userId": "123"},
40+
prompt_session_attributes={"lastAction": "login"}
41+
)
42+
```
43+
"""
44+
45+
def __init__(
46+
self,
47+
body: Any = None,
48+
session_attributes: dict[str, str] | None = None,
49+
prompt_session_attributes: dict[str, str] | None = None,
50+
knowledge_bases: list[dict[str, Any]] | None = None,
51+
status_code: int = 200,
52+
) -> None:
53+
self.body = body
54+
self.session_attributes = session_attributes
55+
self.prompt_session_attributes = prompt_session_attributes
56+
self.knowledge_bases = knowledge_bases
57+
self.status_code = status_code
58+
59+
60+
class BedrockFunctionsResponseBuilder:
2361
"""
2462
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
2563
when using Bedrock Agent Functions.
@@ -28,30 +66,50 @@ class BedrockFunctionsResponseBuilder(ResponseBuilder):
2866
we override the build method.
2967
"""
3068

31-
@override
32-
def build(self, event: BedrockAgentFunctionEvent, *args) -> dict[str, Any]:
33-
"""Build the full response dict to be returned by the lambda"""
34-
self._route(event, None)
69+
def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None:
70+
self.result = result
71+
self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code
3572

36-
body = self.response.body
37-
if self.response.is_json() and not isinstance(self.response.body, str):
38-
body = self.serializer(body)
73+
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
74+
"""Build the full response dict to be returned by the lambda"""
75+
if isinstance(self.result, BedrockResponse):
76+
body = self.result.body
77+
session_attributes = self.result.session_attributes
78+
prompt_session_attributes = self.result.prompt_session_attributes
79+
knowledge_bases = self.result.knowledge_bases
80+
else:
81+
body = self.result
82+
session_attributes = None
83+
prompt_session_attributes = None
84+
knowledge_bases = None
3985

4086
response: dict[str, Any] = {
4187
"messageVersion": "1.0",
4288
"response": {
4389
"actionGroup": event.action_group,
4490
"function": event.function,
45-
"functionResponse": {"responseBody": {"TEXT": {"body": str(body)}}},
91+
"functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}},
4692
},
4793
}
4894

4995
# Add responseState if it's an error
50-
if self.response.status_code >= 400:
96+
if self.status_code >= 400:
5197
response["response"]["functionResponse"]["responseState"] = (
52-
ResponseState.REPROMPT.value if self.response.status_code == 400 else ResponseState.FAILURE.value
98+
ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value
5399
)
54100

101+
# Add session attributes if provided in response or maintain from input
102+
response.update(
103+
{
104+
"sessionAttributes": session_attributes or event.session_attributes or {},
105+
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {},
106+
},
107+
)
108+
109+
# Add knowledge bases configuration if provided
110+
if knowledge_bases:
111+
response["knowledgeBasesConfiguration"] = knowledge_bases
112+
55113
return response
56114

57115

@@ -127,26 +185,20 @@ def _resolve(self) -> dict[str, Any]:
127185
function_name = self.current_event.function
128186

129187
if function_name not in self._tools:
130-
return self._response_builder_class(
131-
Response(
132-
status_code=400, # Using 400 to trigger REPROMPT
188+
return BedrockFunctionsResponseBuilder(
189+
BedrockResponse(
133190
body=f"Function not found: {function_name}",
191+
status_code=400, # Using 400 to trigger REPROMPT
134192
),
135193
).build(self.current_event)
136194

137195
try:
138196
result = self._tools[function_name]["function"]()
139-
# Always wrap the result in a Response object
140-
if not isinstance(result, Response):
141-
result = Response(
142-
status_code=200, # Success
143-
body=result,
144-
)
145-
return self._response_builder_class(result).build(self.current_event)
197+
return BedrockFunctionsResponseBuilder(result).build(self.current_event)
146198
except Exception as e:
147-
return self._response_builder_class(
148-
Response(
149-
status_code=500, # Using 500 to trigger FAILURE
199+
return BedrockFunctionsResponseBuilder(
200+
BedrockResponse(
150201
body=f"Error: {str(e)}",
202+
status_code=500, # Using 500 to trigger FAILURE
151203
),
152204
).build(self.current_event)

tests/events/bedrockAgentFunctionEvent.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
},
99
"sessionId": "123456789123458",
1010
"sessionAttributes": {
11-
"employeeId": "EMP123",
12-
"department": "Engineering"
11+
"employeeId": "EMP123"
1312
},
1413
"promptSessionAttributes": {
1514
"lastInteraction": "2024-02-01T15:30:00Z",

tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
import json
3+
from typing import Any
44

55
import pytest
66

7-
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, Response, content_types
7+
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockResponse
88
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
99
from tests.functional.utils import load_event
1010

@@ -69,32 +69,6 @@ def test_bedrock_agent_function_not_found():
6969
assert result["response"]["functionResponse"]["responseState"] == "REPROMPT"
7070

7171

72-
def test_bedrock_agent_function_with_response_object():
73-
# GIVEN a Bedrock Agent Function resolver
74-
app = BedrockAgentFunctionResolver()
75-
76-
@app.tool(description="Returns a Response object")
77-
def test_function():
78-
return Response(
79-
status_code=200,
80-
content_type=content_types.APPLICATION_JSON,
81-
body={"message": "Hello from Response"},
82-
)
83-
84-
# WHEN calling the event handler
85-
raw_event = load_event("bedrockAgentFunctionEvent.json")
86-
raw_event["function"] = "test_function"
87-
result = app.resolve(raw_event, {})
88-
89-
# THEN process event correctly with Response object
90-
assert result["messageVersion"] == "1.0"
91-
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
92-
assert result["response"]["function"] == "test_function"
93-
response_body = result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
94-
assert json.loads(response_body) == {"message": "Hello from Response"}
95-
assert "responseState" not in result["response"]["functionResponse"] # Success has no state
96-
97-
9872
def test_bedrock_agent_function_registration():
9973
# GIVEN a Bedrock Agent Function resolver
10074
app = BedrockAgentFunctionResolver()
@@ -148,3 +122,77 @@ def test_function():
148122
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
149123
assert result["response"]["function"] == "customName"
150124
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from custom named function"
125+
126+
127+
def test_bedrock_agent_function_with_session_attributes():
128+
# GIVEN a Bedrock Agent Function resolver
129+
app = BedrockAgentFunctionResolver()
130+
131+
@app.tool(description="Function that uses session attributes")
132+
def test_function() -> dict[str, Any]:
133+
return BedrockResponse(
134+
body="Hello",
135+
session_attributes={"userId": "123"},
136+
prompt_session_attributes={"lastAction": "login"},
137+
)
138+
139+
# WHEN calling the event handler
140+
raw_event = load_event("bedrockAgentFunctionEvent.json")
141+
raw_event["function"] = "test_function"
142+
raw_event["parameters"] = []
143+
result = app.resolve(raw_event, {})
144+
145+
# THEN include session attributes in response
146+
assert result["messageVersion"] == "1.0"
147+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello"
148+
assert result["sessionAttributes"] == {"userId": "123"}
149+
assert result["promptSessionAttributes"] == {"lastAction": "login"}
150+
151+
152+
def test_bedrock_agent_function_with_error_response():
153+
# GIVEN a Bedrock Agent Function resolver
154+
app = BedrockAgentFunctionResolver()
155+
156+
@app.tool(description="Function that returns error")
157+
def test_function() -> dict[str, Any]:
158+
return BedrockResponse(
159+
body="Invalid input",
160+
status_code=400, # This will trigger REPROMPT
161+
session_attributes={"error": "true"},
162+
)
163+
164+
# WHEN calling the event handler
165+
raw_event = load_event("bedrockAgentFunctionEvent.json")
166+
raw_event["function"] = "test_function"
167+
raw_event["parameters"] = []
168+
result = app.resolve(raw_event, {})
169+
170+
# THEN include error state and session attributes
171+
assert result["response"]["functionResponse"]["responseState"] == "REPROMPT"
172+
assert result["sessionAttributes"] == {"error": "true"}
173+
174+
175+
def test_bedrock_agent_function_with_knowledge_bases():
176+
# GIVEN a Bedrock Agent Function resolver
177+
app = BedrockAgentFunctionResolver()
178+
179+
@app.tool(description="Returns response with knowledge bases config")
180+
def test_function() -> dict[Any]:
181+
return BedrockResponse(
182+
knowledge_bases=[
183+
{
184+
"knowledgeBaseId": "kb1",
185+
"retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 5}},
186+
},
187+
],
188+
)
189+
190+
# WHEN calling the event handler
191+
raw_event = load_event("bedrockAgentFunctionEvent.json")
192+
raw_event["function"] = "test_function"
193+
result = app.resolve(raw_event, {})
194+
195+
# THEN include knowledge bases in response
196+
assert "knowledgeBasesConfiguration" in result
197+
assert len(result["knowledgeBasesConfiguration"]) == 1
198+
assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb1"

0 commit comments

Comments
 (0)