@@ -19,11 +19,7 @@ def __init__(self, max_steps: int = 6, use_tool_calls: bool | None = None):
1919 self .finalize = None # fallback finalize handled via LM prompt
2020 self ._tool_list = [t .spec () for t in TOOLS .values ()]
2121 self .max_steps = max_steps
22- self ._provider = None
23- try :
24- self ._provider = (self .lm .model .split ("/" , 1 )[0 ] if getattr (self .lm , "model" , None ) else None )
25- except Exception :
26- self ._provider = None
22+ self ._provider = self ._infer_provider (self .lm )
2723 # Determine function-calls mode
2824 env_override = os .getenv ("USE_TOOL_CALLS" )
2925 if isinstance (use_tool_calls , bool ):
@@ -37,10 +33,36 @@ def __init__(self, max_steps: int = 6, use_tool_calls: bool | None = None):
3733 try :
3834 from dspy .adapters import JSONAdapter
3935 dspy .settings .configure (adapter = JSONAdapter ())
36+ if to_dspy_tools ():
37+ self .planner = dspy .Predict (PlanWithTools )
38+ self ._load_compiled_demos ()
39+ else :
40+ self ._use_tool_calls = False
4041 except Exception :
41- pass
42- self .planner = dspy .Predict (PlanWithTools )
43- self ._load_compiled_demos ()
42+ self ._use_tool_calls = False
43+
44+ def _infer_provider (self , lm ) -> str | None :
45+ try :
46+ prov = getattr (lm , "provider" , None ) or getattr (lm , "_provider" , None )
47+ if isinstance (prov , str ) and prov .strip ():
48+ return prov .strip ().lower ()
49+ except Exception :
50+ pass
51+ try :
52+ cls_name = lm .__class__ .__name__ .lower ()
53+ if "openai" in cls_name :
54+ return "openai"
55+ if "ollama" in cls_name :
56+ return "ollama"
57+ except Exception :
58+ pass
59+ try :
60+ model = getattr (lm , "model" , None )
61+ if isinstance (model , str ) and "/" in model :
62+ return model .split ("/" , 1 )[0 ].lower ()
63+ except Exception :
64+ pass
65+ return None
4466
4567 def _load_compiled_demos (self ):
4668 import json as _json
@@ -142,7 +164,9 @@ def needs_math(q: str) -> bool:
142164
143165 def needs_time (q : str ) -> bool :
144166 ql = q .lower ()
145- return any (w in ql for w in ["time" , "date" , "utc" , "current time" , "now" ])
167+ if "current time" in ql or "current date" in ql :
168+ return True
169+ return re .search (r"\b(time|times|date|dates|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b" , ql ) is not None
146170
147171 def used_tool (state , name : str ) -> bool :
148172 return any (step .get ("tool" ) == name for step in state )
@@ -216,8 +240,19 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
216240 # If tool calls are proposed, execute them.
217241 calls = getattr (pred , 'tool_calls' , None )
218242 executed_any = False
243+ had_validation_error = False
244+ had_policy_violation = False
219245 if calls and getattr (calls , 'tool_calls' , None ):
220- for call in calls .tool_calls :
246+ call_list = list (calls .tool_calls )
247+ if len (call_list ) > 1 :
248+ had_policy_violation = True
249+ state .append ({
250+ "tool" : "⛔️policy_violation" ,
251+ "args" : {"reason" : "multiple_tool_calls" , "count" : len (call_list )},
252+ "observation" : "Model returned multiple tool calls in one step; executing only the first." ,
253+ })
254+ call_list = call_list [:1 ]
255+ for call in call_list :
221256 try :
222257 name = getattr (call , 'name' )
223258 args = getattr (call , 'args' ) or {}
@@ -226,6 +261,7 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
226261 # Validate/execute; on validation error, record and continue planning
227262 obs = run_tool (name , args )
228263 if isinstance (obs , dict ) and "error" in obs and "validation" in obs .get ("error" , "" ):
264+ had_validation_error = True
229265 state .append ({
230266 "tool" : "⛔️validation_error" ,
231267 "args" : {"name" : name , "args" : args },
@@ -239,6 +275,8 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
239275 # Check finalization.
240276 final = getattr (pred , 'final' , None )
241277 if final :
278+ if had_policy_violation or had_validation_error :
279+ continue
242280 if must_math and not used_tool (state , "calculator" ):
243281 state .append ({"tool" : "⛔️policy_violation" , "args" : {}, "observation" : "Finalize before calculator (OpenAI path)." })
244282 # If tools were suggested and executed this step, iterate; else force tool suggestion by continuing.
@@ -399,6 +437,9 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
399437 "tool_calls" : tool_calls ,
400438 "provider" : self ._provider ,
401439 "model" : getattr (self .lm , "model" , None ),
440+ "cost" : total_cost ,
441+ "input_tokens" : total_in_tokens ,
442+ "output_tokens" : total_out_tokens ,
402443 }
403444 return p
404445
0 commit comments