1+ # pylint: disable=import-error, wrong-import-position, missing-module-docstring 
2+ import  os 
3+ import  sys 
4+ from  unittest .mock  import  MagicMock , AsyncMock , patch 
5+ import  pytest 
6+ from  contextlib  import  contextmanager 
7+ 
8+ # Mocking necessary modules and environment variables 
9+ sys .modules ["azure.monitor.events.extension" ] =  MagicMock ()
10+ 
11+ # Mocking environment variables 
12+ os .environ ["COSMOSDB_ENDPOINT" ] =  "https://mock-endpoint" 
13+ os .environ ["COSMOSDB_KEY" ] =  "mock-key" 
14+ os .environ ["COSMOSDB_DATABASE" ] =  "mock-database" 
15+ os .environ ["COSMOSDB_CONTAINER" ] =  "mock-container" 
16+ os .environ ["AZURE_OPENAI_DEPLOYMENT_NAME" ] =  "mock-deployment-name" 
17+ os .environ ["AZURE_OPENAI_API_VERSION" ] =  "2023-01-01" 
18+ os .environ ["AZURE_OPENAI_ENDPOINT" ] =  "https://mock-openai-endpoint" 
19+ 
20+ # Importing the module to test 
21+ from  src .backend .agents .base_agent  import  BaseAgent 
22+ from  src .backend .models .messages  import  ActionRequest , Step , StepStatus , ActionResponse , AgentMessage 
23+ from  autogen_core .base  import  AgentId 
24+ from  autogen_core .components .models  import  AssistantMessage , UserMessage 
25+ 
26+ # Context manager for setting up mocks 
27+ @contextmanager  
28+ def  mock_context ():
29+     mock_runtime  =  MagicMock ()
30+     with  patch ("autogen_core.base._agent_instantiation.AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR" ) as  mock_context_var :
31+         mock_context_instance  =  MagicMock ()
32+         mock_context_var .get .return_value  =  mock_context_instance 
33+         mock_context_instance .set .return_value  =  None 
34+         yield  mock_runtime 
35+ 
36+ @pytest .fixture  
37+ def  mock_dependencies ():
38+     model_client  =  MagicMock ()
39+     model_context  =  MagicMock ()
40+     tools  =  [MagicMock (schema = "tool_schema" )]
41+     tool_agent_id  =  MagicMock ()
42+     return  {
43+         "model_client" : model_client ,
44+         "model_context" : model_context ,
45+         "tools" : tools ,
46+         "tool_agent_id" : tool_agent_id ,
47+     }
48+ 
49+ @pytest .fixture  
50+ def  base_agent (mock_dependencies ):
51+     with  mock_context ():
52+         return  BaseAgent (
53+             agent_name = "test_agent" ,
54+             model_client = mock_dependencies ["model_client" ],
55+             session_id = "test_session" ,
56+             user_id = "test_user" ,
57+             model_context = mock_dependencies ["model_context" ],
58+             tools = mock_dependencies ["tools" ],
59+             tool_agent_id = mock_dependencies ["tool_agent_id" ],
60+             system_message = "This is a system message." ,
61+         )
62+ 
63+ def  test_save_state (base_agent , mock_dependencies ):
64+     mock_dependencies ["model_context" ].save_state  =  MagicMock (return_value = {"state_key" : "state_value" })
65+     state  =  base_agent .save_state ()
66+     assert  state  ==  {"memory" : {"state_key" : "state_value" }}
67+ 
68+ def  test_load_state (base_agent , mock_dependencies ):
69+     mock_dependencies ["model_context" ].load_state  =  MagicMock ()
70+     state  =  {"memory" : {"state_key" : "state_value" }}
71+     base_agent .load_state (state )
72+     mock_dependencies ["model_context" ].load_state .assert_called_once_with ({"state_key" : "state_value" })
73+ 
74+ @pytest .mark .asyncio  
75+ async  def  test_handle_action_request_error (base_agent , mock_dependencies ):
76+     """Test handle_action_request when tool_agent_caller_loop raises an error.""" 
77+     # Mocking a Step object 
78+     step  =  Step (
79+         id = "step_1" ,
80+         status = StepStatus .approved ,
81+         human_feedback = "feedback" ,
82+         agent_reply = "" ,
83+         plan_id = "plan_id" ,
84+         action = "action" ,
85+         agent = "HumanAgent" ,
86+         session_id = "session_id" ,
87+         user_id = "user_id" ,
88+     )
89+ 
90+     # Mocking the model context methods 
91+     mock_dependencies ["model_context" ].get_step  =  AsyncMock (return_value = step )
92+     mock_dependencies ["model_context" ].add_item  =  AsyncMock ()
93+ 
94+     # Mock tool_agent_caller_loop to raise an exception 
95+     with  patch ("src.backend.agents.base_agent.tool_agent_caller_loop" , AsyncMock (side_effect = Exception ("Mock error" ))):
96+         # Define the ActionRequest message 
97+         message  =  ActionRequest (
98+             step_id = "step_1" ,
99+             session_id = "test_session" ,
100+             action = "test_action" ,
101+             plan_id = "plan_id" ,
102+             agent = "HumanAgent" ,
103+         )
104+         ctx  =  MagicMock ()
105+ 
106+         # Call handle_action_request and capture exception 
107+         with  pytest .raises (ValueError ) as  excinfo :
108+             await  base_agent .handle_action_request (message , ctx )
109+ 
110+         # Assert that the exception matches the expected ValueError 
111+         assert  "Return type <class 'NoneType'> not in return types"  in  str (excinfo .value ), (
112+             "Expected ValueError due to NoneType return, but got a different exception." 
113+         )
114+ 
115+ @pytest .mark .asyncio  
116+ async  def  test_handle_action_request_success (base_agent , mock_dependencies ):
117+     """Test handle_action_request with a successful tool_agent_caller_loop.""" 
118+     # Update Step with a valid agent enum value 
119+     step  =  Step (
120+         id = "step_1" ,
121+         status = StepStatus .approved ,
122+         human_feedback = "feedback" ,
123+         agent_reply = "" ,
124+         plan_id = "plan_id" ,
125+         action = "action" ,
126+         agent = "HumanAgent" ,
127+         session_id = "session_id" ,
128+         user_id = "user_id" 
129+     )
130+     mock_dependencies ["model_context" ].get_step  =  AsyncMock (return_value = step )
131+     mock_dependencies ["model_context" ].update_step  =  AsyncMock ()
132+     mock_dependencies ["model_context" ].add_item  =  AsyncMock ()
133+ 
134+     # Mock the tool_agent_caller_loop to return a result 
135+     with  patch ("src.backend.agents.base_agent.tool_agent_caller_loop" , new = AsyncMock (return_value = [MagicMock (content = "result" )])):
136+         # Mock the publish_message method to be awaitable 
137+         base_agent ._runtime .publish_message  =  AsyncMock ()
138+ 
139+         message  =  ActionRequest (
140+             step_id = "step_1" ,
141+             session_id = "test_session" ,
142+             action = "test_action" ,
143+             plan_id = "plan_id" ,
144+             agent = "HumanAgent" 
145+         )
146+         ctx  =  MagicMock ()
147+ 
148+         # Call the method being tested 
149+         response  =  await  base_agent .handle_action_request (message , ctx )
150+ 
151+         # Assertions to ensure the response is correct 
152+         assert  response .status  ==  StepStatus .completed 
153+         assert  response .result  ==  "result" 
154+         assert  response .plan_id  ==  "plan_id"   # Validate plan_id 
155+         assert  response .session_id  ==  "test_session"   # Validate session_id 
156+ 
157+         # Ensure publish_message was called 
158+         base_agent ._runtime .publish_message .assert_awaited_once_with (
159+             response ,
160+             AgentId (type = "group_chat_manager" , key = "test_session" ),
161+             sender = base_agent .id ,
162+             cancellation_token = None 
163+         )
164+ 
165+         # Ensure the step was updated 
166+         mock_dependencies ["model_context" ].update_step .assert_called_once_with (step )
0 commit comments