Skip to content

Commit 7337d71

Browse files
committed
Added more tests
1 parent 98c7c84 commit 7337d71

File tree

10 files changed

+1581
-12
lines changed

10 files changed

+1581
-12
lines changed

src/agent.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
load_dotenv()
33

44
import json
5-
from typing import Dict, Any
5+
from typing import Dict, Any, List, Generator
66

77
from utils import chatloop
88

@@ -21,10 +21,35 @@
2121
"web_fetch": WebFetch()
2222
}
2323

24-
def process_tool_calls(response: Dict[str, Any]):
24+
def process_tool_calls(response: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
25+
"""Process tool calls from the LLM response and return results.
26+
27+
Args:
28+
response: The response from the LLM containing tool calls.
29+
30+
Yields:
31+
Dict with tool response information.
32+
"""
33+
# Handle case where tool_calls is None or not present
34+
if not response or not response.get("tool_calls") or not isinstance(response.get("tool_calls"), list):
35+
return
36+
2537
for tool_call in response.get("tool_calls", []):
26-
tool_name = tool_call.get("function", {}).get("name")
27-
arguments = tool_call.get("function", {}).get("arguments", "{}")
38+
if not isinstance(tool_call, dict):
39+
continue
40+
41+
tool_id = tool_call.get("id", "unknown_tool")
42+
43+
# Extract function data, handling possible missing keys
44+
function_data = tool_call.get("function", {})
45+
if not isinstance(function_data, dict):
46+
continue
47+
48+
tool_name = function_data.get("name")
49+
if not tool_name:
50+
continue
51+
52+
arguments = function_data.get("arguments", "{}")
2853

2954
print(f"<Tool: {tool_name}>")
3055

@@ -39,11 +64,16 @@ def process_tool_calls(response: Dict[str, Any]):
3964

4065
if tool_name in tool_map:
4166
tool_instance = tool_map[tool_name]
42-
tool_result = tool_instance.run(**args)
67+
try:
68+
tool_result = tool_instance.run(**args)
69+
except Exception as e:
70+
tool_result = {
71+
"error": f"Error running tool '{tool_name}': {str(e)}"
72+
}
4373

4474
yield {
4575
"role": "tool",
46-
"tool_call_id": tool_call.get("id"),
76+
"tool_call_id": tool_id,
4777
"content": json.dumps(tool_result)
4878
}
4979

@@ -76,15 +106,30 @@ async def run_conversation(user_prompt):
76106

77107
messages.append({"role": "user", "content": user_prompt})
78108
response = await chat.send_messages(messages)
79-
80-
assistant_message = response.get("choices", [{}])[0].get("message", {})
109+
110+
# Handle possible None response
111+
if not response:
112+
return ""
113+
114+
# Handle missing or empty choices
115+
choices = response.get("choices", [])
116+
if not choices:
117+
return ""
118+
119+
assistant_message = choices[0].get("message", {})
81120
messages.append(assistant_message)
82121

83-
while assistant_message.get("tool_calls", False):
122+
# Handle the case where tool_calls might be missing or not a list
123+
while assistant_message.get("tool_calls"):
84124
for result in process_tool_calls(assistant_message):
85125
messages.append(result)
86126

87127
response = await chat.send_messages(messages)
128+
129+
# Handle possible None response or missing choices
130+
if not response or not response.get("choices"):
131+
break
132+
88133
assistant_message = response.get("choices", [{}])[0].get("message", {})
89134
messages.append(assistant_message)
90135

src/chat.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dotenv import load_dotenv
22
load_dotenv()
33

4+
from typing import Dict, Any, Optional
5+
46
from utils import chatloop
57
from utils.azureopenai.chat import Chat
68

@@ -19,14 +21,33 @@
1921
messages = [{"role": "system", "content": system_role}]
2022

2123
@chatloop("Chat")
22-
async def run_conversation(user_prompt):
24+
async def run_conversation(user_prompt: str) -> str:
25+
"""Run a conversation with the user.
26+
27+
Args:
28+
user_prompt: The user's input prompt.
29+
30+
Returns:
31+
The assistant's response as a string.
32+
"""
2333
messages.append({"role": "user", "content": user_prompt})
2434
response = await chat.send_messages(messages)
2535

36+
# Extract content from response, handling possible errors and edge cases
37+
content = ""
38+
if response:
39+
if isinstance(response, dict) and "choices" in response:
40+
choices = response.get("choices", [])
41+
if choices and len(choices) > 0:
42+
message = choices[0].get("message", {})
43+
content = message.get("content", "")
44+
2645
# Print final response
2746
hr = "\n" + "-" * 50 + "\n"
2847
print(hr, "Response:", hr)
2948
print(response, hr)
49+
50+
return content
3051

3152
if __name__ == "__main__":
3253
run_conversation()

0 commit comments

Comments
 (0)