22
33import inspect
44import warnings
5- from typing import TYPE_CHECKING , Any , Literal
5+ from collections .abc import Callable
6+ from typing import Any , Literal , TypeVar
67
8+ from aws_lambda_powertools .utilities .data_classes import BedrockAgentFunctionEvent
79from aws_lambda_powertools .warnings import PowertoolsUserWarning
810
9- if TYPE_CHECKING :
10- from collections .abc import Callable
11-
12- from aws_lambda_powertools .utilities .data_classes import BedrockAgentFunctionEvent
11+ # Define a generic type for the function
12+ T = TypeVar ("T" , bound = Callable [..., Any ])
1313
1414
1515class BedrockFunctionResponse :
16- """Response class for Bedrock Agent Functions
16+ """Response class for Bedrock Agent Functions.
1717
1818 Parameters
1919 ----------
2020 body : Any, optional
21- Response body
22- session_attributes : dict[str, str] | None
23- Session attributes to include in the response
24- prompt_session_attributes : dict[str, str] | None
25- Prompt session attributes to include in the response
26- response_state : Literal["FAILURE", "REPROMPT"] | None
27- Response state ("FAILURE" or "REPROMPT")
21+ Response body to be returned to the caller.
22+ session_attributes : dict[str, str] or None, optional
23+ Session attributes to include in the response for maintaining state.
24+ prompt_session_attributes : dict[str, str] or None, optional
25+ Prompt session attributes to include in the response.
26+ knowledge_bases : list[dict[str, Any]] or None, optional
27+ Knowledge bases to include in the response.
28+ response_state : {"FAILURE", "REPROMPT"} or None, optional
29+ Response state indicating if the function failed or needs reprompting.
2830
2931 Examples
3032 --------
31- ```python
32- @app.tool(description="Function that uses session attributes")
33- def test_function():
34- return BedrockFunctionResponse(
35- body="Hello",
36- session_attributes={"userId": "123"},
37- prompt_session_attributes={"lastAction": "login"}
38- )
39- ```
33+ >>> @app.tool(description="Function that uses session attributes")
34+ >>> def test_function():
35+ ... return BedrockFunctionResponse(
36+ ... body="Hello",
37+ ... session_attributes={"userId": "123"},
38+ ... prompt_session_attributes={"lastAction": "login"}
39+ ... )
40+
41+ Notes
42+ -----
43+ The `response_state` parameter can only be set to "FAILURE" or "REPROMPT".
4044 """
4145
4246 def __init__ (
@@ -47,7 +51,7 @@ def __init__(
4751 knowledge_bases : list [dict [str , Any ]] | None = None ,
4852 response_state : Literal ["FAILURE" , "REPROMPT" ] | None = None ,
4953 ) -> None :
50- if response_state is not None and response_state not in ["FAILURE" , "REPROMPT" ]:
54+ if response_state and response_state not in ["FAILURE" , "REPROMPT" ]:
5155 raise ValueError ("responseState must be 'FAILURE' or 'REPROMPT'" )
5256
5357 self .body = body
@@ -67,45 +71,35 @@ def __init__(self, result: BedrockFunctionResponse | Any) -> None:
6771 self .result = result
6872
6973 def build (self , event : BedrockAgentFunctionEvent ) -> dict [str , Any ]:
70- """Build the full response dict to be returned by the lambda"""
71- if isinstance (self .result , BedrockFunctionResponse ):
72- body = self .result .body
73- session_attributes = self .result .session_attributes
74- prompt_session_attributes = self .result .prompt_session_attributes
75- knowledge_bases = self .result .knowledge_bases
76- response_state = self .result .response_state
77-
78- else :
79- body = self .result
80- session_attributes = None
81- prompt_session_attributes = None
82- knowledge_bases = None
83- response_state = None
74+ result_obj = self .result
8475
76+ # Extract attributes from BedrockFunctionResponse or use defaults
77+ body = getattr (result_obj , "body" , result_obj )
78+ session_attributes = getattr (result_obj , "session_attributes" , None )
79+ prompt_session_attributes = getattr (result_obj , "prompt_session_attributes" , None )
80+ knowledge_bases = getattr (result_obj , "knowledge_bases" , None )
81+ response_state = getattr (result_obj , "response_state" , None )
82+
83+ # Build base response structure
8584 # Per AWS Bedrock documentation, currently only "TEXT" is supported as the responseBody content type
8685 # https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
8786 response : dict [str , Any ] = {
8887 "messageVersion" : "1.0" ,
8988 "response" : {
9089 "actionGroup" : event .action_group ,
9190 "function" : event .function ,
92- "functionResponse" : {"responseBody" : {"TEXT" : {"body" : str (body if body is not None else "" )}}},
91+ "functionResponse" : {
92+ "responseBody" : {"TEXT" : {"body" : str (body if body is not None else "" )}},
93+ },
9394 },
95+ "sessionAttributes" : session_attributes or event .session_attributes or {},
96+ "promptSessionAttributes" : prompt_session_attributes or event .prompt_session_attributes or {},
9497 }
9598
96- # Add responseState if provided
99+ # Add optional fields when present
97100 if response_state :
98101 response ["response" ]["functionResponse" ]["responseState" ] = response_state
99102
100- # Add session attributes if provided in response or maintain from input
101- response .update (
102- {
103- "sessionAttributes" : session_attributes or event .session_attributes or {},
104- "promptSessionAttributes" : prompt_session_attributes or event .prompt_session_attributes or {},
105- },
106- )
107-
108- # Add knowledge bases configuration if provided
109103 if knowledge_bases :
110104 response ["knowledgeBasesConfiguration" ] = knowledge_bases
111105
@@ -132,27 +126,35 @@ def lambda_handler(event, context):
132126 ```
133127 """
134128
129+ context : dict
130+
135131 def __init__ (self ) -> None :
136132 self ._tools : dict [str , dict [str , Any ]] = {}
137133 self .current_event : BedrockAgentFunctionEvent | None = None
134+ self .context = {}
138135 self ._response_builder_class = BedrockFunctionsResponseBuilder
139136
140137 def tool (
141138 self ,
142- description : str | None = None ,
143139 name : str | None = None ,
144- ) -> Callable :
140+ description : str | None = None ,
141+ ) -> Callable [[T ], T ]:
145142 """Decorator to register a tool function
146143
147144 Parameters
148145 ----------
149- description : str | None
150- Description of what the tool does
151146 name : str | None
152147 Custom name for the tool. If not provided, uses the function name
148+ description : str | None
149+ Description of what the tool does
150+
151+ Returns
152+ -------
153+ Callable
154+ Decorator function that registers and returns the original function
153155 """
154156
155- def decorator (func : Callable ) -> Callable :
157+ def decorator (func : T ) -> T :
156158 function_name = name or func .__name__
157159 if function_name in self ._tools :
158160 warnings .warn (
@@ -175,7 +177,7 @@ def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]:
175177 self .current_event = BedrockAgentFunctionEvent (event )
176178 return self ._resolve ()
177179 except KeyError as e :
178- raise ValueError (f"Missing required field: { str (e )} " )
180+ raise ValueError (f"Missing required field: { str (e )} " ) from e
179181
180182 def _resolve (self ) -> dict [str , Any ]:
181183 """Internal resolution logic"""
@@ -185,24 +187,30 @@ def _resolve(self) -> dict[str, Any]:
185187 function_name = self .current_event .function
186188
187189 try :
188- parameters = {}
189- if hasattr (self .current_event , "parameters" ):
190- for param in self .current_event .parameters :
191- parameters [param .name ] = param .value
190+ # Extract parameters from the event
191+ parameters = {param .name : param .value for param in getattr (self .current_event , "parameters" , [])}
192192
193193 func = self ._tools [function_name ]["function" ]
194+ # Filter parameters to only include those expected by the function
194195 sig = inspect .signature (func )
196+ valid_params = {name : value for name , value in parameters .items () if name in sig .parameters }
195197
196- valid_params = {}
197- for name , value in parameters .items ():
198- if name in sig .parameters :
199- valid_params [name ] = value
200-
198+ # Call the function with the filtered parameters
201199 result = func (** valid_params )
200+
201+ self .clear_context ()
202+
203+ # Build and return the response
202204 return BedrockFunctionsResponseBuilder (result ).build (self .current_event )
203- except Exception as e :
204- return BedrockFunctionsResponseBuilder (
205- BedrockFunctionResponse (
206- body = f"Error: { str (e )} " ,
207- ),
208- ).build (self .current_event )
205+ except Exception as error :
206+ # Return a formatted error response
207+ error_response = BedrockFunctionResponse (body = f"Error: { str (error )} " , response_state = "FAILURE" )
208+ return BedrockFunctionsResponseBuilder (error_response ).build (self .current_event )
209+
210+ def append_context (self , ** additional_context ):
211+ """Append key=value data as routing context"""
212+ self .context .update (** additional_context )
213+
214+ def clear_context (self ):
215+ """Resets routing context"""
216+ self .context .clear ()
0 commit comments