@@ -156,17 +156,24 @@ def forward(self, question: str):
156156
def needs_math(q: str) -> bool:
    """Heuristically decide whether a question calls for arithmetic.

    Returns True when any of the following holds:
      * a digit is followed, anywhere later on the same line, by one of
        ``+ - * / %``;
      * two numbers are joined by a worded operator ("3 times 4",
        "10 divided by 2", "7 over 2");
      * the text contains a digit and a standalone math keyword
        ("add", "sum", "total", "calculate", ...).

    Keywords are matched as whole words (common inflections included)
    rather than as raw substrings, so words such as "summary", "address",
    "totally" or "powerful" no longer flag a question as math merely
    because a number appears somewhere in it.
    """
    ql = q.lower()
    number = r"\d+(?:\.\d+)?"

    # Symbolic operator somewhere after a digit (permissive on purpose of
    # recall: the operator does not have to sit between two numbers).
    if re.search(r"[0-9].*[+\-*/%]", q):
        return True

    # "<number> times <number>" style multiplication.
    if re.search(rf"\b{number}\s*(?:x|times|multiplied by)\s*{number}\b", ql):
        return True

    # Other worded binary operators between two numbers.
    if re.search(
        rf"\b{number}\s*(?:plus|minus|add|added to|subtract|subtracted by|divide|divided by|over)\s*{number}\b",
        ql,
    ):
        return True

    # Keyword check only applies when at least one digit is present.
    if not re.search(r"\d", ql):
        return False

    keyword_forms = (
        r"add(?:s|ed|ing|ition)?",
        r"sum(?:s|med|ming)?",
        r"plus",
        r"minus",
        r"subtract(?:s|ed|ing|ion)?",
        r"multipl(?:y|ies|ied|ying|ication)",
        r"divi(?:de|des|ded|ding|sion)",
        r"total(?:s|ed|ing)?",
        r"powers?",
        r"factorials?",
        r"comput(?:e|es|ed|ing)",
        r"calculat(?:e|es|ed|ing|ion)",
    )
    keyword_pattern = r"\b(?:" + "|".join(keyword_forms) + r")\b"
    return re.search(keyword_pattern, ql) is not None
164171
def needs_time(q: str) -> bool:
    """Report whether the question appears to ask about the time or date.

    A question qualifies when its lowercased text contains the phrase
    "current time" or "current date", or contains any of the time-related
    keywords as a whole word.
    """
    lowered = q.lower()
    for phrase in ("current time", "current date"):
        if phrase in lowered:
            return True
    keyword_pattern = r"\b(time|date|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b"
    match = re.search(keyword_pattern, lowered)
    return match is not None
170177
def used_tool(state, name: str) -> bool:
    """Return True if any step recorded in ``state`` has ``"tool"`` equal to ``name``."""
    for entry in state:
        if entry.get("tool") == name:
            return True
    return False
@@ -178,17 +185,25 @@ def used_tool(state, name: str) -> bool:
178185
179186 def _accumulate_usage (input_text : str = "" , output_text : str = "" ):
180187 # Pull new usage entries from dspy.settings.trace
188+ in_tok = 0
189+ out_tok = 0
190+ cost = 0.0
181191 try :
182192 for _ , _ , out in dspy .settings .trace [- 1 :]:
183193 usage = getattr (out , "usage" , None ) or {}
184194 nonlocal total_cost , total_in_tokens , total_out_tokens
185195 c = getattr (out , "cost" , None )
186196 if c is not None :
187- total_cost += float (c or 0 )
188- total_in_tokens += int (usage .get ("input_tokens" , 0 ) or 0 )
189- total_out_tokens += int (usage .get ("output_tokens" , 0 ) or 0 )
197+ cost += float (c or 0 )
198+ in_tok += int (usage .get ("input_tokens" , 0 ) or 0 )
199+ out_tok += int (usage .get ("output_tokens" , 0 ) or 0 )
190200 except Exception :
191201 pass
202+ if in_tok or out_tok or cost :
203+ total_cost += cost
204+ total_in_tokens += in_tok
205+ total_out_tokens += out_tok
206+ return
192207 # Heuristic fallback: estimate tokens from input/output texts and compute cost via env prices
193208 try :
194209 if input_text :
@@ -206,6 +221,38 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
206221 except Exception :
207222 pass
208223
224+ def _infer_expression (q : str ) -> str :
225+ ql = q .lower ()
226+ # Handle "divide X by Y" and "subtract X from Y"
227+ m = re .search (r"\bdivide\s+(\d+(?:\.\d+)?)\s+by\s+(\d+(?:\.\d+)?)\b" , ql )
228+ if m :
229+ return f"{ m .group (1 )} /{ m .group (2 )} "
230+ m = re .search (r"\bsubtract\s+(\d+(?:\.\d+)?)\s+from\s+(\d+(?:\.\d+)?)\b" , ql )
231+ if m :
232+ return f"{ m .group (2 )} -{ m .group (1 )} "
233+ # Binary worded ops
234+ m = re .search (r"\b(\d+(?:\.\d+)?)\s*(?:x|times|multiplied by)\s*(\d+(?:\.\d+)?)\b" , ql )
235+ if m :
236+ return f"{ m .group (1 )} *{ m .group (2 )} "
237+ m = re .search (r"\b(\d+(?:\.\d+)?)\s*(?:plus|add|added to)\s*(\d+(?:\.\d+)?)\b" , ql )
238+ if m :
239+ return f"{ m .group (1 )} +{ m .group (2 )} "
240+ m = re .search (r"\b(\d+(?:\.\d+)?)\s*(?:minus|subtract|subtracted by)\s*(\d+(?:\.\d+)?)\b" , ql )
241+ if m :
242+ return f"{ m .group (1 )} -{ m .group (2 )} "
243+ m = re .search (r"\b(\d+(?:\.\d+)?)\s*(?:divide|divided by|over)\s*(\d+(?:\.\d+)?)\b" , ql )
244+ if m :
245+ return f"{ m .group (1 )} /{ m .group (2 )} "
246+ # Multi-number add/sum
247+ if "add" in ql or "sum" in ql :
248+ nums = [n for n in re .findall (r"\b\d+\b" , q )]
249+ if len (nums ) >= 2 :
250+ return "+" .join (nums )
251+ # Fallback: longest math-like substring
252+ candidates = re .findall (r"[0-9\+\-\*/%\(\)\.!\^\s]+" , q )
253+ candidates = [c .strip () for c in candidates if any (op in c for op in ["+" ,"-" ,"*" ,"/" ,"%" ,"^" ,"(" ,")" ,"!" ])]
254+ return max (candidates , key = len ) if candidates else ""
255+
209256 # Path A: OpenAI-native tool calling using DSPy signatures/adapters.
210257 if self ._use_tool_calls :
211258 dspy_tools = to_dspy_tools ()
@@ -275,6 +322,13 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
275322 # Check finalization.
276323 final = getattr (pred , 'final' , None )
277324 if final :
325+ if executed_any :
326+ state .append ({
327+ "tool" : "⛔️policy_violation" ,
328+ "args" : {"reason" : "tool_and_final" },
329+ "observation" : "Model returned a final answer alongside a tool call." ,
330+ })
331+ continue
278332 if had_policy_violation or had_validation_error :
279333 continue
280334 if must_math and not used_tool (state , "calculator" ):
@@ -319,29 +373,15 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
319373 if calculators :
320374 parts .append (str (calculators [0 ]["observation" ].get ("result" )))
321375 elif must_math :
322- # Last-chance math: infer a simple expression from the question.
323- import re as _re
324- ql = question .lower ()
325- if "add" in ql or "sum" in ql :
326- nums = [int (n ) for n in _re .findall (r"\b\d+\b" , question )]
327- if len (nums ) >= 2 :
328- res = sum (nums )
376+ expr = _infer_expression (question )
377+ if expr :
378+ try :
379+ res = safe_eval_math (expr )
329380 parts .append (str (res ))
330- # also record as a calculator step for trace parity
331- state .append ({"tool" : "calculator" , "args" : {"expression" : "+" .join (map (str , nums ))}, "observation" : {"result" : res }})
381+ state .append ({"tool" : "calculator" , "args" : {"expression" : expr }, "observation" : {"result" : res }})
332382 tool_calls += 1
333- if not parts :
334- candidates = _re .findall (r"[0-9\+\-\*/%\(\)\.!\^\s]+" , question )
335- candidates = [c .strip () for c in candidates if any (op in c for op in ["+" ,"-" ,"*" ,"/" ,"%" ,"^" ,"(" ,")" ,"!" ])]
336- expr = max (candidates , key = len ) if candidates else ""
337- if expr :
338- try :
339- res = safe_eval_math (expr )
340- parts .append (str (res ))
341- state .append ({"tool" : "calculator" , "args" : {"expression" : expr }, "observation" : {"result" : res }})
342- tool_calls += 1
343- except Exception :
344- pass
383+ except Exception :
384+ pass
345385 if nows :
346386 iso = nows [- 1 ]["observation" ].get ("iso" )
347387 if iso :
@@ -413,6 +453,13 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
413453 continue
414454
415455 if "final" in decision :
456+ if "tool" in decision :
457+ state .append ({
458+ "tool" : "⛔️policy_violation" ,
459+ "args" : {"reason" : "tool_and_final" },
460+ "observation" : "Decision contained both tool and final." ,
461+ })
462+ continue
416463 # Enforce tool usage policy: if required tools not yet used, keep planning.
417464 if must_math and not used_tool (state , "calculator" ):
418465 state .append ({"tool" : "⛔️policy_violation" , "args" : {}, "observation" : "Finalize attempted before calculator." })
@@ -430,7 +477,14 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
430477 iso = nows [- 1 ]["observation" ].get ("iso" )
431478 if iso :
432479 composed_parts .append (f"UTC: { iso } " )
433- final_text = " | " .join (composed_parts ) if composed_parts else decision ["final" ].get ("answer" , "" )
480+ if composed_parts :
481+ final_text = " | " .join (composed_parts )
482+ else :
483+ final_payload = decision .get ("final" )
484+ if isinstance (final_payload , dict ):
485+ final_text = final_payload .get ("answer" , "" )
486+ else :
487+ final_text = str (final_payload ) if final_payload is not None else ""
434488 p = dspy .Prediction (answer = final_text , trace = state )
435489 p .usage = {
436490 "lm_calls" : lm_calls ,
@@ -478,24 +532,12 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
478532 if calc_results :
479533 parts .append (str (calc_results [0 ]))
480534 if must_math and not parts :
481- # Last-chance math: try to infer a simple expression from the question.
482- ql = question .lower ()
483- # If looks like 'add X and Y', sum integers.
484- import re
485- if "add" in ql or "sum" in ql :
486- nums = [int (n ) for n in re .findall (r"\b\d+\b" , question )]
487- if len (nums ) >= 2 :
488- parts .append (str (sum (nums )))
489- if not parts :
490- # Extract longest math-like substring and evaluate.
491- candidates = re .findall (r"[0-9\+\-\*/%\(\)\.!\^\s]+" , question )
492- candidates = [c .strip () for c in candidates if any (op in c for op in ["+" ,"-" ,"*" ,"/" ,"%" ,"^" ,"(" ,")" ,"!" ])]
493- expr = max (candidates , key = len ) if candidates else ""
494- if expr :
495- try :
496- parts .append (str (safe_eval_math (expr )))
497- except Exception :
498- pass
535+ expr = _infer_expression (question )
536+ if expr :
537+ try :
538+ parts .append (str (safe_eval_math (expr )))
539+ except Exception :
540+ pass
499541 if nows :
500542 iso = nows [- 1 ]["observation" ].get ("iso" )
501543 if iso :
0 commit comments