Skip to content

Commit 20215ed

Browse files
Small refactor + documentation
1 parent 4211b72 commit 20215ed

File tree

9 files changed

+331
-285
lines changed

9 files changed

+331
-285
lines changed

aws_lambda_powertools/event_handler/bedrock_agent_function.py

Lines changed: 77 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,45 @@
22

33
import inspect
44
import warnings
5-
from typing import TYPE_CHECKING, Any, Literal
5+
from collections.abc import Callable
6+
from typing import Any, Literal, TypeVar
67

8+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
79
from aws_lambda_powertools.warnings import PowertoolsUserWarning
810

9-
if TYPE_CHECKING:
10-
from collections.abc import Callable
11-
12-
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
11+
# Define a generic type for the function
12+
T = TypeVar("T", bound=Callable[..., Any])
1313

1414

1515
class BedrockFunctionResponse:
16-
"""Response class for Bedrock Agent Functions
16+
"""Response class for Bedrock Agent Functions.
1717
1818
Parameters
1919
----------
2020
body : Any, optional
21-
Response body
22-
session_attributes : dict[str, str] | None
23-
Session attributes to include in the response
24-
prompt_session_attributes : dict[str, str] | None
25-
Prompt session attributes to include in the response
26-
response_state : Literal["FAILURE", "REPROMPT"] | None
27-
Response state ("FAILURE" or "REPROMPT")
21+
Response body to be returned to the caller.
22+
session_attributes : dict[str, str] or None, optional
23+
Session attributes to include in the response for maintaining state.
24+
prompt_session_attributes : dict[str, str] or None, optional
25+
Prompt session attributes to include in the response.
26+
knowledge_bases : list[dict[str, Any]] or None, optional
27+
Knowledge bases to include in the response.
28+
response_state : {"FAILURE", "REPROMPT"} or None, optional
29+
Response state indicating if the function failed or needs reprompting.
2830
2931
Examples
3032
--------
31-
```python
32-
@app.tool(description="Function that uses session attributes")
33-
def test_function():
34-
return BedrockFunctionResponse(
35-
body="Hello",
36-
session_attributes={"userId": "123"},
37-
prompt_session_attributes={"lastAction": "login"}
38-
)
39-
```
33+
>>> @app.tool(description="Function that uses session attributes")
34+
>>> def test_function():
35+
... return BedrockFunctionResponse(
36+
... body="Hello",
37+
... session_attributes={"userId": "123"},
38+
... prompt_session_attributes={"lastAction": "login"}
39+
... )
40+
41+
Notes
42+
-----
43+
The `response_state` parameter can only be set to "FAILURE" or "REPROMPT".
4044
"""
4145

4246
def __init__(
@@ -47,7 +51,7 @@ def __init__(
4751
knowledge_bases: list[dict[str, Any]] | None = None,
4852
response_state: Literal["FAILURE", "REPROMPT"] | None = None,
4953
) -> None:
50-
if response_state is not None and response_state not in ["FAILURE", "REPROMPT"]:
54+
if response_state and response_state not in ["FAILURE", "REPROMPT"]:
5155
raise ValueError("responseState must be 'FAILURE' or 'REPROMPT'")
5256

5357
self.body = body
@@ -67,45 +71,35 @@ def __init__(self, result: BedrockFunctionResponse | Any) -> None:
6771
self.result = result
6872

6973
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
70-
"""Build the full response dict to be returned by the lambda"""
71-
if isinstance(self.result, BedrockFunctionResponse):
72-
body = self.result.body
73-
session_attributes = self.result.session_attributes
74-
prompt_session_attributes = self.result.prompt_session_attributes
75-
knowledge_bases = self.result.knowledge_bases
76-
response_state = self.result.response_state
77-
78-
else:
79-
body = self.result
80-
session_attributes = None
81-
prompt_session_attributes = None
82-
knowledge_bases = None
83-
response_state = None
74+
result_obj = self.result
8475

76+
# Extract attributes from BedrockFunctionResponse or use defaults
77+
body = getattr(result_obj, "body", result_obj)
78+
session_attributes = getattr(result_obj, "session_attributes", None)
79+
prompt_session_attributes = getattr(result_obj, "prompt_session_attributes", None)
80+
knowledge_bases = getattr(result_obj, "knowledge_bases", None)
81+
response_state = getattr(result_obj, "response_state", None)
82+
83+
# Build base response structure
8584
# Per AWS Bedrock documentation, currently only "TEXT" is supported as the responseBody content type
8685
# https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
8786
response: dict[str, Any] = {
8887
"messageVersion": "1.0",
8988
"response": {
9089
"actionGroup": event.action_group,
9190
"function": event.function,
92-
"functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}},
91+
"functionResponse": {
92+
"responseBody": {"TEXT": {"body": str(body if body is not None else "")}},
93+
},
9394
},
95+
"sessionAttributes": session_attributes or event.session_attributes or {},
96+
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {},
9497
}
9598

96-
# Add responseState if provided
99+
# Add optional fields when present
97100
if response_state:
98101
response["response"]["functionResponse"]["responseState"] = response_state
99102

100-
# Add session attributes if provided in response or maintain from input
101-
response.update(
102-
{
103-
"sessionAttributes": session_attributes or event.session_attributes or {},
104-
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {},
105-
},
106-
)
107-
108-
# Add knowledge bases configuration if provided
109103
if knowledge_bases:
110104
response["knowledgeBasesConfiguration"] = knowledge_bases
111105

@@ -132,27 +126,35 @@ def lambda_handler(event, context):
132126
```
133127
"""
134128

129+
context: dict
130+
135131
def __init__(self) -> None:
136132
self._tools: dict[str, dict[str, Any]] = {}
137133
self.current_event: BedrockAgentFunctionEvent | None = None
134+
self.context = {}
138135
self._response_builder_class = BedrockFunctionsResponseBuilder
139136

140137
def tool(
141138
self,
142-
description: str | None = None,
143139
name: str | None = None,
144-
) -> Callable:
140+
description: str | None = None,
141+
) -> Callable[[T], T]:
145142
"""Decorator to register a tool function
146143
147144
Parameters
148145
----------
149-
description : str | None
150-
Description of what the tool does
151146
name : str | None
152147
Custom name for the tool. If not provided, uses the function name
148+
description : str | None
149+
Description of what the tool does
150+
151+
Returns
152+
-------
153+
Callable
154+
Decorator function that registers and returns the original function
153155
"""
154156

155-
def decorator(func: Callable) -> Callable:
157+
def decorator(func: T) -> T:
156158
function_name = name or func.__name__
157159
if function_name in self._tools:
158160
warnings.warn(
@@ -175,7 +177,7 @@ def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]:
175177
self.current_event = BedrockAgentFunctionEvent(event)
176178
return self._resolve()
177179
except KeyError as e:
178-
raise ValueError(f"Missing required field: {str(e)}")
180+
raise ValueError(f"Missing required field: {str(e)}") from e
179181

180182
def _resolve(self) -> dict[str, Any]:
181183
"""Internal resolution logic"""
@@ -185,24 +187,30 @@ def _resolve(self) -> dict[str, Any]:
185187
function_name = self.current_event.function
186188

187189
try:
188-
parameters = {}
189-
if hasattr(self.current_event, "parameters"):
190-
for param in self.current_event.parameters:
191-
parameters[param.name] = param.value
190+
# Extract parameters from the event
191+
parameters = {param.name: param.value for param in getattr(self.current_event, "parameters", [])}
192192

193193
func = self._tools[function_name]["function"]
194+
# Filter parameters to only include those expected by the function
194195
sig = inspect.signature(func)
196+
valid_params = {name: value for name, value in parameters.items() if name in sig.parameters}
195197

196-
valid_params = {}
197-
for name, value in parameters.items():
198-
if name in sig.parameters:
199-
valid_params[name] = value
200-
198+
# Call the function with the filtered parameters
201199
result = func(**valid_params)
200+
201+
self.clear_context()
202+
203+
# Build and return the response
202204
return BedrockFunctionsResponseBuilder(result).build(self.current_event)
203-
except Exception as e:
204-
return BedrockFunctionsResponseBuilder(
205-
BedrockFunctionResponse(
206-
body=f"Error: {str(e)}",
207-
),
208-
).build(self.current_event)
205+
except Exception as error:
206+
# Return a formatted error response
207+
error_response = BedrockFunctionResponse(body=f"Error: {str(error)}", response_state="FAILURE")
208+
return BedrockFunctionsResponseBuilder(error_response).build(self.current_event)
209+
210+
def append_context(self, **additional_context):
211+
"""Append key=value data as routing context"""
212+
self.context.update(**additional_context)
213+
214+
def clear_context(self):
215+
"""Resets routing context"""
216+
self.context.clear()

0 commit comments

Comments
 (0)