
Commit 3db9057

Merge pull request #819 from MervinPraison/claude/issue-818-20250711_091438
fix: enable tool calling for Gemini models
2 parents 261b905 + 0f7dfc9

File tree: 7 files changed, +691 −12 lines

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 54 additions & 11 deletions
@@ -713,26 +713,69 @@ def _apply_guardrail_with_retry(self, response_text, prompt, temperature=0.2, to
             )
 
         return current_response
+
+    def _build_system_prompt(self, tools=None):
+        """Build the system prompt with tool information.
+
+        Args:
+            tools: Optional list of tools to use (defaults to self.tools)
+
+        Returns:
+            str: The system prompt or None if use_system_prompt is False
+        """
+        if not self.use_system_prompt:
+            return None
+
+        system_prompt = f"""{self.backstory}\n
+Your Role: {self.role}\n
+Your Goal: {self.goal}"""
+
+        # Add tool usage instructions if tools are available
+        # Use provided tools or fall back to self.tools
+        tools_to_use = tools if tools is not None else self.tools
+        if tools_to_use:
+            tool_names = []
+            for tool in tools_to_use:
+                try:
+                    if callable(tool) and hasattr(tool, '__name__'):
+                        tool_names.append(tool.__name__)
+                    elif isinstance(tool, dict) and isinstance(tool.get('function'), dict) and 'name' in tool['function']:
+                        tool_names.append(tool['function']['name'])
+                    elif isinstance(tool, str):
+                        tool_names.append(tool)
+                    elif hasattr(tool, "to_openai_tool"):
+                        # Handle MCP tools
+                        openai_tools = tool.to_openai_tool()
+                        if isinstance(openai_tools, list):
+                            for t in openai_tools:
+                                if isinstance(t, dict) and 'function' in t and 'name' in t['function']:
+                                    tool_names.append(t['function']['name'])
+                        elif isinstance(openai_tools, dict) and 'function' in openai_tools:
+                            tool_names.append(openai_tools['function']['name'])
+                except (AttributeError, KeyError, TypeError) as e:
+                    logging.warning(f"Could not extract tool name from {tool}: {e}")
+                    continue
+
+            if tool_names:
+                system_prompt += f"\n\nYou have access to the following tools: {', '.join(tool_names)}. Use these tools when appropriate to help complete your tasks. Always use tools when they can help provide accurate information or perform actions."
+
+        return system_prompt
 
-    def _build_messages(self, prompt, temperature=0.2, output_json=None, output_pydantic=None):
+    def _build_messages(self, prompt, temperature=0.2, output_json=None, output_pydantic=None, tools=None):
         """Build messages list for chat completion.
 
         Args:
             prompt: The user prompt (str or list)
             temperature: Temperature for the chat
             output_json: Optional Pydantic model for JSON output
             output_pydantic: Optional Pydantic model for JSON output (alias)
+            tools: Optional list of tools to use (defaults to self.tools)
 
         Returns:
             tuple: (messages list, original prompt)
         """
-        # Build system prompt if enabled
-        system_prompt = None
-        if self.use_system_prompt:
-            system_prompt = f"""{self.backstory}\n
-Your Role: {self.role}\n
-Your Goal: {self.goal}
-"""
+        # Build system prompt using the helper method
+        system_prompt = self._build_system_prompt(tools)
 
         # Use openai_client's build_messages method if available
         if self._openai_client is not None:
@@ -1176,7 +1219,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
                 # Pass everything to LLM class
                 response_text = self.llm_instance.get_response(
                     prompt=prompt,
-                    system_prompt=f"{self.backstory}\n\nYour Role: {self.role}\n\nYour Goal: {self.goal}" if self.use_system_prompt else None,
+                    system_prompt=self._build_system_prompt(tools),
                     chat_history=self.chat_history,
                     temperature=temperature,
                     tools=tool_param,
@@ -1492,7 +1535,7 @@ async def achat(self, prompt: str, temperature=0.2, tools=None, output_json=None
             try:
                 response_text = await self.llm_instance.get_response_async(
                     prompt=prompt,
-                    system_prompt=f"{self.backstory}\n\nYour Role: {self.role}\n\nYour Goal: {self.goal}" if self.use_system_prompt else None,
+                    system_prompt=self._build_system_prompt(tools),
                     chat_history=self.chat_history,
                     temperature=temperature,
                     tools=tools,
@@ -1506,7 +1549,7 @@ async def achat(self, prompt: str, temperature=0.2, tools=None, output_json=None
                     console=self.console,
                     agent_name=self.name,
                     agent_role=self.role,
-                    agent_tools=[t.__name__ if hasattr(t, '__name__') else str(t) for t in self.tools],
+                    agent_tools=[t.__name__ if hasattr(t, '__name__') else str(t) for t in (tools if tools is not None else self.tools)],
                     execute_tool_fn=self.execute_tool_async,
                     reasoning_steps=reasoning_steps
                 )
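
For illustration, a minimal sketch (not part of the commit) of how the new _build_system_prompt helper resolves tool names. It assumes praisonaiagents is installed and that Agent accepts the constructor arguments used in the test files below; web_lookup and the "calculator" dict are hypothetical examples of the callable and OpenAI-style-dict branches.

from praisonaiagents import Agent

def web_lookup(query: str) -> str:
    """Hypothetical tool; only its __name__ is used in the prompt text."""
    return f"results for {query}"

agent = Agent(
    name="Demo",
    role="Assistant",
    goal="Answer questions",
    backstory="A helpful assistant",
    tools=[web_lookup],
)

# With no argument the helper falls back to agent.tools; an explicit list of
# callables, OpenAI-style dicts, or plain strings overrides it.
print(agent._build_system_prompt())
print(agent._build_system_prompt(tools=[{"type": "function",
                                         "function": {"name": "calculator"}}]))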

src/praisonai-agents/praisonaiagents/llm/llm.py

Lines changed: 17 additions & 1 deletion
@@ -406,7 +406,7 @@ def _supports_streaming_tools(self) -> bool:
         # missing tool calls or making duplicate calls
         return False
 
-    def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_json=None, output_pydantic=None):
+    def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_json=None, output_pydantic=None, tools=None):
         """Build messages list for LLM completion. Works for both sync and async.
 
         Args:
@@ -415,6 +415,7 @@ def _build_messages(self, prompt, system_prompt=None, chat_history=None, output_
             chat_history: Optional list of previous messages
             output_json: Optional Pydantic model for JSON output
             output_pydantic: Optional Pydantic model for JSON output (alias)
+            tools: Optional list of tools available
 
         Returns:
             tuple: (messages list, original prompt)
@@ -1858,6 +1859,21 @@ def _build_completion_params(self, **override_params) -> Dict[str, Any]:
         # Override with any provided parameters
         params.update(override_params)
 
+        # Add tool_choice="auto" when tools are provided (unless already specified)
+        if 'tools' in params and params['tools'] and 'tool_choice' not in params:
+            # For Gemini models, use tool_choice to encourage tool usage
+            # More comprehensive Gemini model detection
+            if any(prefix in self.model.lower() for prefix in ['gemini', 'gemini/', 'google/gemini']):
+                try:
+                    import litellm
+                    # Check if model supports function calling before setting tool_choice
+                    if litellm.supports_function_calling(model=self.model):
+                        params['tool_choice'] = 'auto'
+                except Exception as e:
+                    # If check fails, still set tool_choice for known Gemini models
+                    logging.debug(f"Could not verify function calling support: {e}. Setting tool_choice anyway.")
+                    params['tool_choice'] = 'auto'
+
         return params
 
     def _prepare_response_logging(self, temperature: float, stream: bool, verbose: bool, markdown: bool, **kwargs) -> Optional[Dict[str, Any]]:
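
For reference, a standalone sketch (not part of the commit) of the model-detection predicate used above, runnable with no dependencies. Because 'gemini' is a substring of the other two prefixes, the any() check effectively matches every model identifier that contains 'gemini'; the model names below are only illustrative strings.

def is_gemini_model(model: str) -> bool:
    # Same substring check as in _build_completion_params above
    return any(prefix in model.lower() for prefix in ['gemini', 'gemini/', 'google/gemini'])

for name in ["gemini/gemini-1.5-flash-8b", "google/gemini-pro",
             "vertex_ai/gemini-2.0-flash", "gpt-4o", "claude-3-5-sonnet"]:
    print(f"{name}: {is_gemini_model(name)}")
# Expected: the three identifiers containing 'gemini' print True, the rest False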
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
"""
Test to verify that async agents with Gemini models properly use tools
after the fix for issue #818
"""
import asyncio
import logging
from praisonaiagents import Agent, Task, PraisonAIAgents

# Enable logging to see tool calls
logging.basicConfig(level=logging.INFO)

# Mock search tool
async def mock_search(query: str) -> dict:
    """Mock search tool for testing"""
    print(f"[TOOL CALLED] Searching for: {query}")
    return {
        "query": query,
        "results": [
            {
                "title": f"Result 1 for {query}",
                "snippet": f"This is a mock result about {query}",
                "url": "https://example.com/1"
            },
            {
                "title": f"Result 2 for {query}",
                "snippet": f"Another mock result about {query}",
                "url": "https://example.com/2"
            }
        ],
        "status": "success"
    }

async def test_async_gemini_tools():
    """Test async agents with Gemini models use tools correctly"""

    # Create search agent with Gemini model
    search_agent = Agent(
        name="AsyncSearcher",
        role="Research Assistant",
        goal="Find information using the search tool",
        backstory="You are an expert at finding information online",
        tools=[mock_search],
        llm={"model": "gemini/gemini-1.5-flash-latest"},
        verbose=True
    )

    # Create analysis agent without tools
    analysis_agent = Agent(
        name="Analyzer",
        role="Data Analyst",
        goal="Analyze search results",
        backstory="You excel at analyzing and summarizing information",
        llm={"model": "gemini/gemini-1.5-flash-latest"},
        verbose=True
    )

    # Create tasks
    search_task = Task(
        name="search_task",
        description="Search for information about 'quantum computing breakthroughs 2024'",
        expected_output="Search results with at least 2 relevant findings",
        agent=search_agent,
        async_execution=True
    )

    analysis_task = Task(
        name="analysis_task",
        description="Analyze the search results and provide a summary",
        expected_output="A concise summary of the findings",
        agent=analysis_agent,
        context=[search_task],
        async_execution=False
    )

    # Create workflow
    workflow = PraisonAIAgents(
        agents=[search_agent, analysis_agent],
        tasks=[search_task, analysis_task],
        verbose=True
    )

    # Execute async
    print("\n🚀 Starting async agent test with Gemini models...")
    result = await workflow.astart()

    # Check results
    print("\n✅ Test Results:")
    print("-" * 50)

    # Verify search agent used the tool
    search_result = str(result)
    if "mock result" in search_result.lower() or "tool called" in search_result.lower():
        print("✅ SUCCESS: Search agent properly used the mock_search tool!")
    else:
        print("❌ FAILURE: Search agent did NOT use the tool (claimed no internet access)")

    # Show the actual output
    print("\nFinal output:")
    print(result)

    return result

async def test_multiple_async_agents():
    """Test multiple async agents running in parallel"""

    agents = []
    tasks = []

    # Create 3 search agents
    for i in range(3):
        agent = Agent(
            name=f"AsyncAgent{i}",
            role="Researcher",
            goal="Search for information",
            backstory="Expert researcher",
            tools=[mock_search],
            llm={"model": "gemini/gemini-1.5-flash-latest"}
        )

        task = Task(
            name=f"task_{i}",
            description=f"Search for 'AI advancement #{i+1}'",
            expected_output="Search results",
            agent=agent,
            async_execution=True
        )

        agents.append(agent)
        tasks.append(task)

    # Execute all in parallel
    workflow = PraisonAIAgents(agents=agents, tasks=tasks)

    print("\n🚀 Testing multiple async agents in parallel...")
    results = await workflow.astart()

    # Verify all agents used tools
    success_count = 0
    for i, task in enumerate(tasks):
        if "mock result" in str(results).lower():
            success_count += 1

    print(f"\n{success_count}/{len(tasks)} agents successfully used tools")

    return results

async def main():
    """Run all async tests"""
    try:
        # Test 1: Single async agent
        await test_async_gemini_tools()

        # Test 2: Multiple async agents in parallel
        await test_multiple_async_agents()

        print("\n🎉 All async tests completed!")

    except Exception as e:
        print(f"\n❌ Error during testing: {e}")
        raise

if __name__ == "__main__":
    asyncio.run(main())
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
"""
Simple test to verify that tool_choice='auto' is set for Gemini models
"""
import logging
from praisonaiagents.llm.llm import LLM

# Enable debug logging to see our log message
logging.basicConfig(level=logging.DEBUG, format='%(levelname)s: %(message)s')

# Test different Gemini model formats
test_models = [
    "gemini/gemini-1.5-flash-8b",
    "gemini-1.5-flash-8b",
    "gemini/gemini-pro",
    "gpt-4",  # Non-Gemini model for comparison
]

# Mock tools
mock_tools = [
    {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Search for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }
        }
    }
]

print("Testing tool_choice setting for different models:\n")

for model in test_models:
    print(f"\nTesting model: {model}")
    try:
        llm = LLM(model=model)
        params = llm._build_completion_params(
            messages=[{"role": "user", "content": "test"}],
            tools=mock_tools
        )

        tool_choice = params.get('tool_choice', 'NOT SET')
        print(f" tool_choice: {tool_choice}")

        # Verify behavior
        if model.startswith(('gemini-', 'gemini/')):
            if tool_choice == 'auto':
                print(f" ✅ CORRECT: Gemini model has tool_choice='auto'")
            else:
                print(f" ❌ ERROR: Gemini model should have tool_choice='auto'")
        else:
            if tool_choice == 'NOT SET':
                print(f" ✅ CORRECT: Non-Gemini model doesn't have tool_choice set")
            else:
                print(f" ⚠️ WARNING: Non-Gemini model has tool_choice set to '{tool_choice}'")

    except Exception as e:
        print(f" ❌ ERROR: {e}")

print("\nTest complete!")
