Skip to content

Commit 54a7edf

Browse files
committed
params injection
1 parent 39e0d36 commit 54a7edf

File tree

2 files changed

+73
-21
lines changed

2 files changed

+73
-21
lines changed

aws_lambda_powertools/event_handler/bedrock_agent_function.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
import inspect
4+
import warnings
5+
from typing import TYPE_CHECKING, Any, Literal
6+
7+
from aws_lambda_powertools.warnings import PowertoolsUserWarning
48

59
if TYPE_CHECKING:
610
from collections.abc import Callable
@@ -19,7 +23,7 @@ class BedrockFunctionResponse:
1923
Session attributes to include in the response
2024
prompt_session_attributes : dict[str, str] | None
2125
Prompt session attributes to include in the response
22-
response_state : str | None
26+
response_state : Literal["FAILURE", "REPROMPT"] | None
2327
Response state ("FAILURE" or "REPROMPT")
2428
2529
Examples
@@ -41,10 +45,10 @@ def __init__(
4145
session_attributes: dict[str, str] | None = None,
4246
prompt_session_attributes: dict[str, str] | None = None,
4347
knowledge_bases: list[dict[str, Any]] | None = None,
44-
response_state: str | None = None,
48+
response_state: Literal["FAILURE", "REPROMPT"] | None = None,
4549
) -> None:
4650
if response_state is not None and response_state not in ["FAILURE", "REPROMPT"]:
47-
raise ValueError("responseState must be None, 'FAILURE' or 'REPROMPT'")
51+
raise ValueError("responseState must be 'FAILURE' or 'REPROMPT'")
4852

4953
self.body = body
5054
self.session_attributes = session_attributes
@@ -78,6 +82,8 @@ def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
7882
knowledge_bases = None
7983
response_state = None
8084

85+
# Per AWS Bedrock documentation, currently only "TEXT" is supported as the responseBody content type
86+
# https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
8187
response: dict[str, Any] = {
8288
"messageVersion": "1.0",
8389
"response": {
@@ -147,12 +153,13 @@ def tool(
147153
"""
148154

149155
def decorator(func: Callable) -> Callable:
150-
if not description:
151-
raise ValueError("Tool description is required")
152-
153156
function_name = name or func.__name__
154157
if function_name in self._tools:
155-
raise ValueError(f"Tool '{function_name}' already registered")
158+
warnings.warn(
159+
f"Tool '{function_name}' already registered. Overwriting with new definition.",
160+
PowertoolsUserWarning,
161+
stacklevel=2,
162+
)
156163

157164
self._tools[function_name] = {
158165
"function": func,
@@ -178,7 +185,20 @@ def _resolve(self) -> dict[str, Any]:
178185
function_name = self.current_event.function
179186

180187
try:
181-
result = self._tools[function_name]["function"]()
188+
parameters = {}
189+
if hasattr(self.current_event, "parameters"):
190+
for param in self.current_event.parameters:
191+
parameters[param.name] = param.value
192+
193+
func = self._tools[function_name]["function"]
194+
sig = inspect.signature(func)
195+
196+
valid_params = {}
197+
for name, value in parameters.items():
198+
if name in sig.parameters:
199+
valid_params[name] = value
200+
201+
result = func(**valid_params)
182202
return BedrockFunctionsResponseBuilder(result).build(self.current_event)
183203
except Exception as e:
184204
return BedrockFunctionsResponseBuilder(

tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockFunctionResponse
66
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
7+
from aws_lambda_powertools.warnings import PowertoolsUserWarning
78
from tests.functional.utils import load_event
89

910

@@ -59,22 +60,25 @@ def test_bedrock_agent_function_registration():
5960
# GIVEN a Bedrock Agent Function resolver
6061
app = BedrockAgentFunctionResolver()
6162

62-
# WHEN registering without description or with duplicate name
63-
with pytest.raises(ValueError, match="Tool description is required"):
64-
65-
@app.tool()
66-
def test_function():
67-
return "test"
68-
63+
# WHEN registering with duplicate name
6964
@app.tool(name="custom", description="First registration")
7065
def first_function():
71-
return "test"
66+
return "first test"
7267

73-
with pytest.raises(ValueError, match="Tool 'custom' already registered"):
68+
# THEN a warning should be issued when registering a duplicate
69+
with pytest.warns(PowertoolsUserWarning, match="Tool 'custom' already registered"):
7470

7571
@app.tool(name="custom", description="Second registration")
7672
def second_function():
77-
return "test"
73+
return "second test"
74+
75+
# AND the most recent function should be registered
76+
raw_event = load_event("bedrockAgentFunctionEvent.json")
77+
raw_event["function"] = "custom"
78+
result = app.resolve(raw_event, {})
79+
80+
# The second function should be used
81+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "second test"
7882

7983

8084
def test_bedrock_agent_function_with_optional_fields():
@@ -156,7 +160,7 @@ def test_resolve_with_no_registered_function():
156160

157161
def test_bedrock_function_response_state_validation():
158162
# GIVEN invalid and valid response states
159-
valid_states = [None, "FAILURE", "REPROMPT"]
163+
valid_states = ["FAILURE", "REPROMPT"]
160164
invalid_state = "INVALID"
161165

162166
# WHEN creating responses with valid states
@@ -172,4 +176,32 @@ def test_bedrock_function_response_state_validation():
172176
with pytest.raises(ValueError) as exc_info:
173177
BedrockFunctionResponse(body="test", response_state=invalid_state)
174178

175-
assert str(exc_info.value) == "responseState must be None, 'FAILURE' or 'REPROMPT'"
179+
assert str(exc_info.value) == "responseState must be 'FAILURE' or 'REPROMPT'"
180+
181+
182+
def test_bedrock_agent_function_with_parameters():
183+
# GIVEN a Bedrock Agent Function resolver
184+
app = BedrockAgentFunctionResolver()
185+
186+
# Track received parameters
187+
received_params = {}
188+
189+
@app.tool(description="Function that accepts parameters")
190+
def vacation_request(startDate, endDate):
191+
# Store received parameters for assertion
192+
received_params["startDate"] = startDate
193+
received_params["endDate"] = endDate
194+
return f"Vacation request from {startDate} to {endDate} submitted"
195+
196+
# WHEN calling the event handler with parameters
197+
raw_event = load_event("bedrockAgentFunctionEvent.json")
198+
raw_event["function"] = "vacation_request"
199+
result = app.resolve(raw_event, {})
200+
201+
# THEN parameters should be correctly passed to the function
202+
assert received_params["startDate"] == "2024-03-15"
203+
assert received_params["endDate"] == "2024-03-20"
204+
assert (
205+
"Vacation request from 2024-03-15 to 2024-03-20 submitted"
206+
in result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
207+
)

0 commit comments

Comments
 (0)