@@ -170,31 +170,82 @@ async def _do_eval(self, eval_input: Any) -> DoEvalResult[T_EvalValue]:
170170
171171 # ~~~ METHODS THAT MIGHT NEED TO BE OVERRIDDEN BY CHILDREN~~~
172172
173- def _derive_singleton_inputs (self ) -> List [str ]:
173+ def _derive_singleton_inputs (self ) -> List [List [ str ] ]:
174174 """Inspect the evaluator's __call__ function to determine what singleton inputs are expected
175175 when the evaluator is being used in a non-conversation context.
176176 By default, it's assumed that any input that is NOT kwargs or a conversation are singleton inputs.
177177 Thankfully this works the way you'd hope, with the call_signature being based on the child
178178 function's signature, not the parent's.
179179
180- :return: A list of strings representing the names of singleton inputs.
181- :rtype: List[str]
180+ :return: A list of lists, where each inner list represents the singleton inputs for each overload .
181+ :rtype: List[List[ str] ]
182182 """
183183
184184 overloads = get_overloads (self .__call__ )
185185 if not overloads :
186186 call_signatures = [inspect .signature (self .__call__ )]
187187 else :
188188 call_signatures = [inspect .signature (overload ) for overload in overloads ]
189- call_signature = inspect . signature ( self . __call__ )
190- singletons = []
189+
190+ overload_inputs = []
191191 for call_signature in call_signatures :
192192 params = call_signature .parameters
193193 if any (not_singleton_input in params for not_singleton_input in self ._not_singleton_inputs ):
194194 continue
195195 # exclude self since it is not a singleton input
196- singletons .extend ([p for p in params if p != "self" ])
197- return singletons
196+ overload_inputs .append ([p for p in params if p != "self" ])
197+ return overload_inputs
198+
199+ def _get_matching_overload_inputs (self , ** kwargs ) -> List [str ]:
200+ """Find the overload that matches the provided kwargs and return its input parameters.
201+
202+ :keyword kwargs: The keyword arguments to match against overloads.
203+ :type kwargs: Dict
204+ :return: List of input parameter names for the matching overload.
205+ :rtype: List[str]
206+ """
207+ overload_inputs = self ._singleton_inputs
208+ provided_keys = set (key for key , value in kwargs .items () if value is not None )
209+
210+ # Find the overload that best matches the provided parameters
211+ best_match = None
212+ best_score = - 1
213+
214+ for inputs in overload_inputs :
215+ input_set = set (inputs )
216+
217+ # Calculate match score: how many of the overload's params are provided
218+ if input_set .issubset (provided_keys ):
219+ score = len (input_set )
220+ if score > best_score :
221+ best_score = score
222+ best_match = inputs
223+
224+ # If exact match found, return it
225+ if best_match is not None :
226+ return best_match
227+
228+ # If no exact match, find the overload with the most overlap
229+ for inputs in overload_inputs :
230+ input_set = set (inputs )
231+ overlap = len (input_set .intersection (provided_keys ))
232+ if overlap > best_score :
233+ best_score = overlap
234+ best_match = inputs
235+
236+ # Return the best match or the first overload as fallback
237+ return best_match if best_match is not None else (overload_inputs [0 ] if overload_inputs else [])
238+
239+ def _get_all_singleton_inputs (self ) -> List [str ]:
240+ """Get a flattened list of all possible singleton inputs across all overloads.
241+
242+ :return: Flattened list of all singleton input names.
243+ :rtype: List[str]
244+ """
245+ all_inputs = set ()
246+ for inputs in self ._singleton_inputs :
247+ all_inputs .update (inputs )
248+ return list (all_inputs )
198249
199250 def _derive_conversation_converter (
200251 self ,
@@ -206,10 +257,11 @@ def _derive_conversation_converter(
206257 :return: The function that will be used to convert conversations to evaluable inputs.
207258 :rtype: Callable
208259 """
209- include_context = "context" in self ._singleton_inputs
210- include_query = "query" in self ._singleton_inputs
211- include_response = "response" in self ._singleton_inputs
212- include_ground_truth = "ground_truth" in self ._singleton_inputs
260+ all_singleton_inputs = self ._get_all_singleton_inputs ()
261+ include_context = "context" in all_singleton_inputs
262+ include_query = "query" in all_singleton_inputs
263+ include_response = "response" in all_singleton_inputs
264+ include_ground_truth = "ground_truth" in all_singleton_inputs
213265
214266 def converter (conversation : Dict ) -> List [DerivedEvalInput ]:
215267 messages = cast (List [Dict [str , Any ]], conversation ["messages" ])
@@ -319,9 +371,9 @@ def _convert_kwargs_to_eval_input(self, **kwargs) -> Union[List[Dict], List[Deri
319371 (like a query and response), or they receive conversation that iss a list of dictionary
320372 values.
321373
322- The self._singleton_inputs list assigned during initialization is used to find and extract
323- singleton keywords, and self._allow_conversation_input is used to determine if a conversation
324- is a valid input .
374+ The self._singleton_inputs list (containing overload signatures) assigned during initialization
375+ is used to find and extract singleton keywords, and determine which overload matches the
376+ provided arguments .
325377
326378 If both conversations and singletons are allowed, the function will raise an exception if both
327379 are inputted.
@@ -339,7 +391,10 @@ def _convert_kwargs_to_eval_input(self, **kwargs) -> Union[List[Dict], List[Deri
339391 conversation = kwargs .get ("conversation" , None )
340392 singletons = {}
341393 if len (self ._singleton_inputs ) > 0 :
342- singletons = {key : kwargs .get (key , None ) for key in self ._singleton_inputs }
394+ # Get all possible singleton inputs and check what's provided
395+ all_singleton_inputs = self ._get_all_singleton_inputs ()
396+ singletons = {key : kwargs .get (key , None ) for key in all_singleton_inputs }
397+
343398 # Check that both conversation and other inputs aren't set
344399 if conversation is not None and any (singletons .values ()):
345400 msg = f"{ type (self ).__name__ } : Cannot provide both 'conversation' and individual inputs at the same time."
@@ -354,10 +409,16 @@ def _convert_kwargs_to_eval_input(self, **kwargs) -> Union[List[Dict], List[Deri
354409 if self ._is_multi_modal_conversation (conversation ):
355410 return self ._derive_multi_modal_conversation_converter ()(conversation )
356411 return self ._derive_conversation_converter ()(conversation )
357- # Handle Singletons
358- required_singletons = remove_optional_singletons (self , singletons )
359- if all (value is not None for value in required_singletons .values ()):
360- return [singletons ]
412+
413+ # Handle Singletons - find matching overload
414+ matching_inputs = self ._get_matching_overload_inputs (** kwargs )
415+ if matching_inputs :
416+ # Check if all required inputs for this overload are provided
417+ required_singletons = {key : kwargs .get (key , None ) for key in matching_inputs }
418+ required_singletons = remove_optional_singletons (self , required_singletons )
419+ if all (value is not None for value in required_singletons .values ()):
420+ return [singletons ]
421+
361422 # Missing input
362423 msg = f"{ type (self ).__name__ } : Either 'conversation' or individual inputs must be provided."
363424 raise EvaluationException (
@@ -416,6 +477,39 @@ def _aggregate_results(self, per_turn_results: List[DoEvalResult[T_EvalValue]])
416477 aggregated ["evaluation_per_turn" ] = evaluation_per_turn
417478 return aggregated
418479
480+ def _parse_tools_from_response (self , response ):
481+ """Parse the response to extract tool calls and results.
482+ :param response: The response to parse.
483+ :type response: Union[str, List[dict]]
484+ :return: List of tool calls extracted from the response.
485+ :rtype: List[dict]
486+ """
487+ tool_calls = []
488+ tool_results_map = {}
489+ if isinstance (response , list ):
490+ for message in response :
491+ # Extract tool calls from assistant messages
492+ if message .get ("role" ) == "assistant" and isinstance (message .get ("content" ), list ):
493+ for content_item in message .get ("content" ):
494+ if isinstance (content_item , dict ) and content_item .get ("type" ) == "tool_call" :
495+ tool_calls .append (content_item )
496+
497+ # Extract tool results from tool messages
498+ elif message .get ("role" ) == "tool" and message .get ("tool_call_id" ):
499+ tool_call_id = message .get ("tool_call_id" )
500+ if isinstance (message .get ("content" ), list ) and len (message .get ("content" )) > 0 :
501+ result_content = message .get ("content" )[0 ]
502+ if isinstance (result_content , dict ) and result_content .get ("type" ) == "tool_result" :
503+ tool_results_map [tool_call_id ] = result_content
504+
505+ # Attach results to their corresponding calls
506+ for tool_call in tool_calls :
507+ tool_call_id = tool_call .get ("tool_call_id" )
508+ if tool_call_id in tool_results_map :
509+ tool_call ["tool_result" ] = tool_results_map [tool_call_id ]["tool_result" ]
510+
511+ return tool_calls
512+
419513 async def _real_call (self , ** kwargs ) -> Union [DoEvalResult [T_EvalValue ], AggregateResult [T_EvalValue ]]:
420514 """The asynchronous call where real end-to-end evaluation logic is performed.
421515
0 commit comments