Skip to content

Commit bed8f3f

Browse files
committed
create resolver
1 parent 41bc401 commit bed8f3f

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

aws_lambda_powertools/event_handler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
1414
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
15+
from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver
1516
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
1617
from aws_lambda_powertools.event_handler.lambda_function_url import (
1718
LambdaFunctionUrlResolver,
@@ -26,6 +27,7 @@
2627
"ALBResolver",
2728
"ApiGatewayResolver",
2829
"BedrockAgentResolver",
30+
"BedrockAgentFunctionResolver",
2931
"CORSConfig",
3032
"LambdaFunctionUrlResolver",
3133
"Response",
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import Callable
7+
8+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
9+
10+
11+
class BedrockAgentFunctionResolver:
12+
"""Bedrock Agent Function resolver that handles function definitions
13+
14+
Examples
15+
--------
16+
Simple example with a custom lambda handler
17+
18+
```python
19+
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver
20+
21+
app = BedrockAgentFunctionResolver()
22+
23+
@app.tool(description="Gets the current UTC time")
24+
def get_current_time():
25+
from datetime import datetime
26+
return datetime.utcnow().isoformat()
27+
28+
def lambda_handler(event, context):
29+
return app.resolve(event, context)
30+
```
31+
"""
32+
def __init__(self) -> None:
33+
self._tools: dict[str, dict[str, Any]] = {}
34+
self.current_event: BedrockAgentFunctionEvent | None = None
35+
36+
def tool(self, description: str | None = None) -> Callable:
37+
"""Decorator to register a tool function"""
38+
def decorator(func: Callable) -> Callable:
39+
if not description:
40+
raise ValueError("Tool description is required")
41+
42+
function_name = func.__name__
43+
if function_name in self._tools:
44+
raise ValueError(f"Tool '{function_name}' already registered")
45+
46+
self._tools[function_name] = {
47+
"function": func,
48+
"description": description,
49+
}
50+
return func
51+
return decorator
52+
53+
def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]:
54+
"""Resolves the function call from Bedrock Agent event"""
55+
try:
56+
self.current_event = BedrockAgentFunctionEvent(event)
57+
return self._resolve()
58+
except KeyError as e:
59+
raise ValueError(f"Missing required field: {str(e)}")
60+
61+
def _resolve(self) -> dict[str, Any]:
62+
"""Internal resolution logic"""
63+
function_name = self.current_event.function
64+
action_group = self.current_event.action_group
65+
66+
if function_name not in self._tools:
67+
return self._create_response(
68+
action_group=action_group,
69+
function_name=function_name,
70+
result=f"Function not found: {function_name}"
71+
)
72+
73+
try:
74+
result = self._tools[function_name]["function"]()
75+
return self._create_response(
76+
action_group=action_group,
77+
function_name=function_name,
78+
result=result
79+
)
80+
except Exception as e:
81+
return self._create_response(
82+
action_group=action_group,
83+
function_name=function_name,
84+
result=f"Error: {str(e)}"
85+
)
86+
87+
def _create_response(self, action_group: str, function_name: str, result: Any) -> dict[str, Any]:
88+
"""Create response in Bedrock Agent format"""
89+
return {
90+
"messageVersion": "1.0",
91+
"response": {
92+
"actionGroup": action_group,
93+
"function": function_name,
94+
"functionResponse": {
95+
"responseBody": {
96+
"TEXT": {
97+
"body": str(result)
98+
}
99+
}
100+
}
101+
}
102+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver
5+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
6+
from tests.functional.utils import load_event
7+
8+
9+
def test_bedrock_agent_function():
10+
# GIVEN a Bedrock Agent Function resolver
11+
app = BedrockAgentFunctionResolver()
12+
13+
@app.tool(description="Gets the current time")
14+
def get_current_time():
15+
assert isinstance(app.current_event, BedrockAgentFunctionEvent)
16+
return "2024-02-01T12:00:00Z"
17+
18+
# WHEN calling the event handler
19+
raw_event = load_event("bedrockAgentFunctionEvent.json")
20+
raw_event["function"] = "get_current_time" # ensure function name matches
21+
result = app.resolve(raw_event, {})
22+
23+
# THEN process event correctly
24+
assert result["messageVersion"] == "1.0"
25+
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
26+
assert result["response"]["function"] == "get_current_time"
27+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "2024-02-01T12:00:00Z"
28+
29+
30+
def test_bedrock_agent_function_with_error():
31+
# GIVEN a Bedrock Agent Function resolver
32+
app = BedrockAgentFunctionResolver()
33+
34+
@app.tool(description="Function that raises error")
35+
def error_function():
36+
raise ValueError("Something went wrong")
37+
38+
# WHEN calling the event handler with a function that raises an error
39+
raw_event = load_event("bedrockAgentFunctionEvent.json")
40+
raw_event["function"] = "error_function"
41+
result = app.resolve(raw_event, {})
42+
43+
# THEN process the error correctly
44+
assert result["messageVersion"] == "1.0"
45+
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
46+
assert result["response"]["function"] == "error_function"
47+
assert "Error: Something went wrong" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
48+
49+
50+
def test_bedrock_agent_function_not_found():
51+
# GIVEN a Bedrock Agent Function resolver
52+
app = BedrockAgentFunctionResolver()
53+
54+
@app.tool(description="Test function")
55+
def test_function():
56+
return "test"
57+
58+
# WHEN calling the event handler with a non-existent function
59+
raw_event = load_event("bedrockAgentFunctionEvent.json")
60+
raw_event["function"] = "nonexistent_function"
61+
result = app.resolve(raw_event, {})
62+
63+
# THEN return function not found response
64+
assert result["messageVersion"] == "1.0"
65+
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
66+
assert result["response"]["function"] == "nonexistent_function"
67+
assert "Function not found" in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
68+
69+
70+
def test_bedrock_agent_function_missing_description():
71+
# GIVEN a Bedrock Agent Function resolver
72+
app = BedrockAgentFunctionResolver()
73+
74+
# WHEN registering a tool without description
75+
# THEN raise ValueError
76+
with pytest.raises(ValueError, match="Tool description is required"):
77+
@app.tool()
78+
def test_function():
79+
return "test"
80+
81+
82+
def test_bedrock_agent_function_duplicate_registration():
83+
# GIVEN a Bedrock Agent Function resolver
84+
app = BedrockAgentFunctionResolver()
85+
86+
# WHEN registering the same function twice
87+
@app.tool(description="First registration")
88+
def test_function():
89+
return "test"
90+
91+
# THEN raise ValueError on second registration
92+
with pytest.raises(ValueError, match="Tool 'test_function' already registered"):
93+
@app.tool(description="Second registration")
94+
def test_function(): # noqa: F811
95+
return "test"
96+
97+
98+
def test_bedrock_agent_function_invalid_event():
99+
# GIVEN a Bedrock Agent Function resolver
100+
app = BedrockAgentFunctionResolver()
101+
102+
@app.tool(description="Test function")
103+
def test_function():
104+
return "test"
105+
106+
# WHEN calling with invalid event
107+
# THEN raise ValueError
108+
with pytest.raises(ValueError, match="Missing required field"):
109+
app.resolve({}, {})

0 commit comments

Comments
 (0)