Skip to content

Commit 44d80f8

Browse files
committed
add response
1 parent a3765f0 commit 44d80f8

File tree

2 files changed

+113
-53
lines changed

2 files changed

+113
-53
lines changed

aws_lambda_powertools/event_handler/bedrock_agent_function.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,64 @@
22

33
from typing import TYPE_CHECKING, Any
44

5+
from typing_extensions import override
6+
7+
from aws_lambda_powertools.event_handler.api_gateway import Response, ResponseBuilder
8+
59
if TYPE_CHECKING:
610
from collections.abc import Callable
711

12+
from enum import Enum
13+
814
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
915

1016

17+
class ResponseState(Enum):
18+
FAILURE = "FAILURE"
19+
REPROMPT = "REPROMPT"
20+
21+
22+
class BedrockFunctionsResponseBuilder(ResponseBuilder):
23+
"""
24+
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
25+
when using Bedrock Agent Functions.
26+
27+
Since the payload format is different from the standard API Gateway Proxy event,
28+
we override the build method.
29+
"""
30+
31+
@override
32+
def build(self, event: BedrockAgentFunctionEvent, *args) -> dict[str, Any]:
33+
"""Build the full response dict to be returned by the lambda"""
34+
self._route(event, None)
35+
36+
body = self.response.body
37+
if self.response.is_json() and not isinstance(self.response.body, str):
38+
body = self.serializer(body)
39+
40+
response: dict[str, Any] = {
41+
"messageVersion": "1.0",
42+
"response": {
43+
"actionGroup": event.action_group,
44+
"function": event.function,
45+
"functionResponse": {"responseBody": {"TEXT": {"body": str(body)}}},
46+
},
47+
}
48+
49+
# Add responseState if it's an error
50+
if self.response.status_code >= 400:
51+
response["response"]["functionResponse"]["responseState"] = (
52+
ResponseState.REPROMPT.value if self.response.status_code == 400 else ResponseState.FAILURE.value
53+
)
54+
55+
return response
56+
57+
1158
class BedrockAgentFunctionResolver:
1259
"""Bedrock Agent Function resolver that handles function definitions
1360
1461
Examples
1562
--------
16-
Simple example with a custom lambda handler
17-
1863
```python
1964
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver
2065
@@ -33,6 +78,7 @@ def lambda_handler(event, context):
3378
def __init__(self) -> None:
3479
self._tools: dict[str, dict[str, Any]] = {}
3580
self.current_event: BedrockAgentFunctionEvent | None = None
81+
self._response_builder_class = BedrockFunctionsResponseBuilder
3682

3783
def tool(self, description: str | None = None) -> Callable:
3884
"""Decorator to register a tool function"""
@@ -67,32 +113,28 @@ def _resolve(self) -> dict[str, Any]:
67113
raise ValueError("No event to process")
68114

69115
function_name = self.current_event.function
70-
action_group = self.current_event.action_group
71116

72117
if function_name not in self._tools:
73-
return self._create_response(
74-
action_group=action_group,
75-
function_name=function_name,
76-
result=f"Function not found: {function_name}",
77-
)
118+
return self._response_builder_class(
119+
Response(
120+
status_code=400, # Using 400 to trigger REPROMPT
121+
body=f"Function not found: {function_name}",
122+
),
123+
).build(self.current_event)
78124

79125
try:
80126
result = self._tools[function_name]["function"]()
81-
return self._create_response(action_group=action_group, function_name=function_name, result=result)
127+
# Always wrap the result in a Response object
128+
if not isinstance(result, Response):
129+
result = Response(
130+
status_code=200, # Success
131+
body=result,
132+
)
133+
return self._response_builder_class(result).build(self.current_event)
82134
except Exception as e:
83-
return self._create_response(
84-
action_group=action_group,
85-
function_name=function_name,
86-
result=f"Error: {str(e)}",
87-
)
88-
89-
def _create_response(self, action_group: str, function_name: str, result: Any) -> dict[str, Any]:
90-
"""Create response in Bedrock Agent format"""
91-
return {
92-
"messageVersion": "1.0",
93-
"response": {
94-
"actionGroup": action_group,
95-
"function": function_name,
96-
"functionResponse": {"responseBody": {"TEXT": {"body": str(result)}}},
97-
},
98-
}
135+
return self._response_builder_class(
136+
Response(
137+
status_code=500, # Using 500 to trigger FAILURE
138+
body=f"Error: {str(e)}",
139+
),
140+
).build(self.current_event)

tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,34 @@
11
from __future__ import annotations
22

3+
import json
4+
35
import pytest
46

5-
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver
7+
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, Response, content_types
68
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
79
from tests.functional.utils import load_event
810

911

10-
def test_bedrock_agent_function():
12+
def test_bedrock_agent_function_with_string_response():
1113
# GIVEN a Bedrock Agent Function resolver
1214
app = BedrockAgentFunctionResolver()
1315

14-
@app.tool(description="Gets the current time")
15-
def get_current_time():
16+
@app.tool(description="Returns a string")
17+
def test_function():
1618
assert isinstance(app.current_event, BedrockAgentFunctionEvent)
17-
return "2024-02-01T12:00:00Z"
19+
return "Hello from string"
1820

1921
# WHEN calling the event handler
2022
raw_event = load_event("bedrockAgentFunctionEvent.json")
21-
raw_event["function"] = "get_current_time" # ensure function name matches
23+
raw_event["function"] = "test_function"
2224
result = app.resolve(raw_event, {})
2325

24-
# THEN process event correctly
26+
# THEN process event correctly with string response
2527
assert result["messageVersion"] == "1.0"
2628
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
27-
assert result["response"]["function"] == "get_current_time"
28-
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "2024-02-01T12:00:00Z"
29+
assert result["response"]["function"] == "test_function"
30+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from string"
31+
assert "responseState" not in result["response"]["functionResponse"] # Success has no state
2932

3033

3134
def test_bedrock_agent_function_with_error():
@@ -46,29 +49,53 @@ def error_function():
4649
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
4750
assert result["response"]["function"] == "error_function"
4851
assert "Error: Something went wrong" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
52+
assert result["response"]["functionResponse"]["responseState"] == "FAILURE"
4953

5054

5155
def test_bedrock_agent_function_not_found():
5256
# GIVEN a Bedrock Agent Function resolver
5357
app = BedrockAgentFunctionResolver()
5458

55-
@app.tool(description="Test function")
56-
def test_function():
57-
return "test"
58-
5959
# WHEN calling the event handler with a non-existent function
6060
raw_event = load_event("bedrockAgentFunctionEvent.json")
6161
raw_event["function"] = "nonexistent_function"
6262
result = app.resolve(raw_event, {})
6363

64-
# THEN return function not found response
64+
# THEN return function not found response with REPROMPT state
6565
assert result["messageVersion"] == "1.0"
6666
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
6767
assert result["response"]["function"] == "nonexistent_function"
6868
assert "Function not found" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
69+
assert result["response"]["functionResponse"]["responseState"] == "REPROMPT"
6970

7071

71-
def test_bedrock_agent_function_missing_description():
72+
def test_bedrock_agent_function_with_response_object():
73+
# GIVEN a Bedrock Agent Function resolver
74+
app = BedrockAgentFunctionResolver()
75+
76+
@app.tool(description="Returns a Response object")
77+
def test_function():
78+
return Response(
79+
status_code=200,
80+
content_type=content_types.APPLICATION_JSON,
81+
body={"message": "Hello from Response"},
82+
)
83+
84+
# WHEN calling the event handler
85+
raw_event = load_event("bedrockAgentFunctionEvent.json")
86+
raw_event["function"] = "test_function"
87+
result = app.resolve(raw_event, {})
88+
89+
# THEN process event correctly with Response object
90+
assert result["messageVersion"] == "1.0"
91+
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
92+
assert result["response"]["function"] == "test_function"
93+
response_body = result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
94+
assert json.loads(response_body) == {"message": "Hello from Response"}
95+
assert "responseState" not in result["response"]["functionResponse"] # Success has no state
96+
97+
98+
def test_bedrock_agent_function_registration():
7299
# GIVEN a Bedrock Agent Function resolver
73100
app = BedrockAgentFunctionResolver()
74101

@@ -80,32 +107,23 @@ def test_bedrock_agent_function_missing_description():
80107
def test_function():
81108
return "test"
82109

83-
84-
def test_bedrock_agent_function_duplicate_registration():
85-
# GIVEN a Bedrock Agent Function resolver
86-
app = BedrockAgentFunctionResolver()
87-
88110
# WHEN registering the same function twice
111+
# THEN raise ValueError
89112
@app.tool(description="First registration")
90-
def test_function():
113+
def duplicate_function():
91114
return "test"
92115

93-
# THEN raise ValueError on second registration
94-
with pytest.raises(ValueError, match="Tool 'test_function' already registered"):
116+
with pytest.raises(ValueError, match="Tool 'duplicate_function' already registered"):
95117

96118
@app.tool(description="Second registration")
97-
def test_function(): # noqa: F811
119+
def duplicate_function(): # noqa: F811
98120
return "test"
99121

100122

101123
def test_bedrock_agent_function_invalid_event():
102124
# GIVEN a Bedrock Agent Function resolver
103125
app = BedrockAgentFunctionResolver()
104126

105-
@app.tool(description="Test function")
106-
def test_function():
107-
return "test"
108-
109127
# WHEN calling with invalid event
110128
# THEN raise ValueError
111129
with pytest.raises(ValueError, match="Missing required field"):

0 commit comments

Comments
 (0)