1
+ from __future__ import annotations
2
+
3
+ import pytest
4
+ from aws_lambda_powertools .event_handler import BedrockAgentFunctionResolver
5
+ from aws_lambda_powertools .utilities .data_classes import BedrockAgentFunctionEvent
6
+ from tests .functional .utils import load_event
7
+
8
+
9
+ def test_bedrock_agent_function ():
10
+ # GIVEN a Bedrock Agent Function resolver
11
+ app = BedrockAgentFunctionResolver ()
12
+
13
+ @app .tool (description = "Gets the current time" )
14
+ def get_current_time ():
15
+ assert isinstance (app .current_event , BedrockAgentFunctionEvent )
16
+ return "2024-02-01T12:00:00Z"
17
+
18
+ # WHEN calling the event handler
19
+ raw_event = load_event ("bedrockAgentFunctionEvent.json" )
20
+ raw_event ["function" ] = "get_current_time" # ensure function name matches
21
+ result = app .resolve (raw_event , {})
22
+
23
+ # THEN process event correctly
24
+ assert result ["messageVersion" ] == "1.0"
25
+ assert result ["response" ]["actionGroup" ] == raw_event ["actionGroup" ]
26
+ assert result ["response" ]["function" ] == "get_current_time"
27
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == "2024-02-01T12:00:00Z"
28
+
29
+
30
+ def test_bedrock_agent_function_with_error ():
31
+ # GIVEN a Bedrock Agent Function resolver
32
+ app = BedrockAgentFunctionResolver ()
33
+
34
+ @app .tool (description = "Function that raises error" )
35
+ def error_function ():
36
+ raise ValueError ("Something went wrong" )
37
+
38
+ # WHEN calling the event handler with a function that raises an error
39
+ raw_event = load_event ("bedrockAgentFunctionEvent.json" )
40
+ raw_event ["function" ] = "error_function"
41
+ result = app .resolve (raw_event , {})
42
+
43
+ # THEN process the error correctly
44
+ assert result ["messageVersion" ] == "1.0"
45
+ assert result ["response" ]["actionGroup" ] == raw_event ["actionGroup" ]
46
+ assert result ["response" ]["function" ] == "error_function"
47
+ assert "Error: Something went wrong" in result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ]
48
+
49
+
50
+ def test_bedrock_agent_function_not_found ():
51
+ # GIVEN a Bedrock Agent Function resolver
52
+ app = BedrockAgentFunctionResolver ()
53
+
54
+ @app .tool (description = "Test function" )
55
+ def test_function ():
56
+ return "test"
57
+
58
+ # WHEN calling the event handler with a non-existent function
59
+ raw_event = load_event ("bedrockAgentFunctionEvent.json" )
60
+ raw_event ["function" ] = "nonexistent_function"
61
+ result = app .resolve (raw_event , {})
62
+
63
+ # THEN return function not found response
64
+ assert result ["messageVersion" ] == "1.0"
65
+ assert result ["response" ]["actionGroup" ] == raw_event ["actionGroup" ]
66
+ assert result ["response" ]["function" ] == "nonexistent_function"
67
+ assert "Function not found" in result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ]
68
+
69
+
70
+ def test_bedrock_agent_function_missing_description ():
71
+ # GIVEN a Bedrock Agent Function resolver
72
+ app = BedrockAgentFunctionResolver ()
73
+
74
+ # WHEN registering a tool without description
75
+ # THEN raise ValueError
76
+ with pytest .raises (ValueError , match = "Tool description is required" ):
77
+ @app .tool ()
78
+ def test_function ():
79
+ return "test"
80
+
81
+
82
+ def test_bedrock_agent_function_duplicate_registration ():
83
+ # GIVEN a Bedrock Agent Function resolver
84
+ app = BedrockAgentFunctionResolver ()
85
+
86
+ # WHEN registering the same function twice
87
+ @app .tool (description = "First registration" )
88
+ def test_function ():
89
+ return "test"
90
+
91
+ # THEN raise ValueError on second registration
92
+ with pytest .raises (ValueError , match = "Tool 'test_function' already registered" ):
93
+ @app .tool (description = "Second registration" )
94
+ def test_function (): # noqa: F811
95
+ return "test"
96
+
97
+
98
+ def test_bedrock_agent_function_invalid_event ():
99
+ # GIVEN a Bedrock Agent Function resolver
100
+ app = BedrockAgentFunctionResolver ()
101
+
102
+ @app .tool (description = "Test function" )
103
+ def test_function ():
104
+ return "test"
105
+
106
+ # WHEN calling with invalid event
107
+ # THEN raise ValueError
108
+ with pytest .raises (ValueError , match = "Missing required field" ):
109
+ app .resolve ({}, {})
0 commit comments