2
2
3
3
from typing import TYPE_CHECKING , Any
4
4
5
- from typing_extensions import override
6
-
7
- from aws_lambda_powertools .event_handler .api_gateway import Response , ResponseBuilder
8
-
9
5
if TYPE_CHECKING :
10
6
from collections .abc import Callable
11
7
@@ -19,7 +15,49 @@ class ResponseState(Enum):
19
15
REPROMPT = "REPROMPT"
20
16
21
17
22
- class BedrockFunctionsResponseBuilder (ResponseBuilder ):
18
+ class BedrockResponse :
19
+ """Response class for Bedrock Agent Functions
20
+
21
+ Parameters
22
+ ----------
23
+ body : Any, optional
24
+ Response body
25
+ session_attributes : dict[str, str] | None
26
+ Session attributes to include in the response
27
+ prompt_session_attributes : dict[str, str] | None
28
+ Prompt session attributes to include in the response
29
+ status_code : int
30
+ Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE)
31
+
32
+ Examples
33
+ --------
34
+ ```python
35
+ @app.tool(description="Function that uses session attributes")
36
+ def test_function():
37
+ return BedrockResponse(
38
+ body="Hello",
39
+ session_attributes={"userId": "123"},
40
+ prompt_session_attributes={"lastAction": "login"}
41
+ )
42
+ ```
43
+ """
44
+
45
+ def __init__ (
46
+ self ,
47
+ body : Any = None ,
48
+ session_attributes : dict [str , str ] | None = None ,
49
+ prompt_session_attributes : dict [str , str ] | None = None ,
50
+ knowledge_bases : list [dict [str , Any ]] | None = None ,
51
+ status_code : int = 200 ,
52
+ ) -> None :
53
+ self .body = body
54
+ self .session_attributes = session_attributes
55
+ self .prompt_session_attributes = prompt_session_attributes
56
+ self .knowledge_bases = knowledge_bases
57
+ self .status_code = status_code
58
+
59
+
60
+ class BedrockFunctionsResponseBuilder :
23
61
"""
24
62
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
25
63
when using Bedrock Agent Functions.
@@ -28,30 +66,50 @@ class BedrockFunctionsResponseBuilder(ResponseBuilder):
28
66
we override the build method.
29
67
"""
30
68
31
- @override
32
- def build (self , event : BedrockAgentFunctionEvent , * args ) -> dict [str , Any ]:
33
- """Build the full response dict to be returned by the lambda"""
34
- self ._route (event , None )
69
+ def __init__ (self , result : BedrockResponse | Any , status_code : int = 200 ) -> None :
70
+ self .result = result
71
+ self .status_code = status_code if not isinstance (result , BedrockResponse ) else result .status_code
35
72
36
- body = self .response .body
37
- if self .response .is_json () and not isinstance (self .response .body , str ):
38
- body = self .serializer (body )
73
+ def build (self , event : BedrockAgentFunctionEvent ) -> dict [str , Any ]:
74
+ """Build the full response dict to be returned by the lambda"""
75
+ if isinstance (self .result , BedrockResponse ):
76
+ body = self .result .body
77
+ session_attributes = self .result .session_attributes
78
+ prompt_session_attributes = self .result .prompt_session_attributes
79
+ knowledge_bases = self .result .knowledge_bases
80
+ else :
81
+ body = self .result
82
+ session_attributes = None
83
+ prompt_session_attributes = None
84
+ knowledge_bases = None
39
85
40
86
response : dict [str , Any ] = {
41
87
"messageVersion" : "1.0" ,
42
88
"response" : {
43
89
"actionGroup" : event .action_group ,
44
90
"function" : event .function ,
45
- "functionResponse" : {"responseBody" : {"TEXT" : {"body" : str (body )}}},
91
+ "functionResponse" : {"responseBody" : {"TEXT" : {"body" : str (body if body is not None else "" )}}},
46
92
},
47
93
}
48
94
49
95
# Add responseState if it's an error
50
- if self .response . status_code >= 400 :
96
+ if self .status_code >= 400 :
51
97
response ["response" ]["functionResponse" ]["responseState" ] = (
52
- ResponseState .REPROMPT .value if self .response . status_code == 400 else ResponseState .FAILURE .value
98
+ ResponseState .REPROMPT .value if self .status_code == 400 else ResponseState .FAILURE .value
53
99
)
54
100
101
+ # Add session attributes if provided in response or maintain from input
102
+ response .update (
103
+ {
104
+ "sessionAttributes" : session_attributes or event .session_attributes or {},
105
+ "promptSessionAttributes" : prompt_session_attributes or event .prompt_session_attributes or {},
106
+ },
107
+ )
108
+
109
+ # Add knowledge bases configuration if provided
110
+ if knowledge_bases :
111
+ response ["knowledgeBasesConfiguration" ] = knowledge_bases
112
+
55
113
return response
56
114
57
115
@@ -127,26 +185,20 @@ def _resolve(self) -> dict[str, Any]:
127
185
function_name = self .current_event .function
128
186
129
187
if function_name not in self ._tools :
130
- return self ._response_builder_class (
131
- Response (
132
- status_code = 400 , # Using 400 to trigger REPROMPT
188
+ return BedrockFunctionsResponseBuilder (
189
+ BedrockResponse (
133
190
body = f"Function not found: { function_name } " ,
191
+ status_code = 400 , # Using 400 to trigger REPROMPT
134
192
),
135
193
).build (self .current_event )
136
194
137
195
try :
138
196
result = self ._tools [function_name ]["function" ]()
139
- # Always wrap the result in a Response object
140
- if not isinstance (result , Response ):
141
- result = Response (
142
- status_code = 200 , # Success
143
- body = result ,
144
- )
145
- return self ._response_builder_class (result ).build (self .current_event )
197
+ return BedrockFunctionsResponseBuilder (result ).build (self .current_event )
146
198
except Exception as e :
147
- return self ._response_builder_class (
148
- Response (
149
- status_code = 500 , # Using 500 to trigger FAILURE
199
+ return BedrockFunctionsResponseBuilder (
200
+ BedrockResponse (
150
201
body = f"Error: { str (e )} " ,
202
+ status_code = 500 , # Using 500 to trigger FAILURE
151
203
),
152
204
).build (self .current_event )
0 commit comments