diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index 5e40079..8abd920 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -164,26 +164,14 @@ def main(args): messages.append(user_message) else: # Tool or function call - if last_message.recipient.startswith("browser."): - assert args.browser, "Browser tool is not enabled" - tool_name = "Search" - async def run_tool(): - results = [] - async for msg in browser_tool.process(last_message): - results.append(msg) - return results + if last_message.recipient.startswith("browser.") or last_message.recipient.startswith("python"): + tool_map = {"browser": (browser_tool, "Search"), "python": (python_tool, "Python")} + tool_key = last_message.recipient.split('.')[0] + tool, tool_name = tool_map[tool_key] - result = asyncio.run(run_tool()) - messages += result - elif last_message.recipient.startswith("python"): - assert args.python, "Python tool is not enabled" - tool_name = "Python" async def run_tool(): - results = [] - async for msg in python_tool.process(last_message): - results.append(msg) - return results - + return [msg async for msg in tool.process(last_message)] + result = asyncio.run(run_tool()) messages += result elif last_message.recipient == "functions.apply_patch":