@@ -170,31 +170,82 @@ async def _do_eval(self, eval_input: Any) -> DoEvalResult[T_EvalValue]:
170
170
171
171
# ~~~ METHODS THAT MIGHT NEED TO BE OVERRIDDEN BY CHILDREN~~~
172
172
173
- def _derive_singleton_inputs (self ) -> List [str ]:
173
+ def _derive_singleton_inputs (self ) -> List [List [ str ] ]:
174
174
"""Inspect the evaluator's __call__ function to determine what singleton inputs are expected
175
175
when the evaluator is being used in a non-conversation context.
176
176
By default, it's assumed that any input that is NOT kwargs or a conversation are singleton inputs.
177
177
Thankfully this works the way you'd hope, with the call_signature being based on the child
178
178
function's signature, not the parent's.
179
179
180
- :return: A list of strings representing the names of singleton inputs.
181
- :rtype: List[str]
180
+ :return: A list of lists, where each inner list represents the singleton inputs for each overload .
181
+ :rtype: List[List[ str] ]
182
182
"""
183
183
184
184
overloads = get_overloads (self .__call__ )
185
185
if not overloads :
186
186
call_signatures = [inspect .signature (self .__call__ )]
187
187
else :
188
188
call_signatures = [inspect .signature (overload ) for overload in overloads ]
189
- call_signature = inspect . signature ( self . __call__ )
190
- singletons = []
189
+
190
+ overload_inputs = []
191
191
for call_signature in call_signatures :
192
192
params = call_signature .parameters
193
193
if any (not_singleton_input in params for not_singleton_input in self ._not_singleton_inputs ):
194
194
continue
195
195
# exclude self since it is not a singleton input
196
- singletons .extend ([p for p in params if p != "self" ])
197
- return singletons
196
+ overload_inputs .append ([p for p in params if p != "self" ])
197
+ return overload_inputs
198
+
199
+ def _get_matching_overload_inputs (self , ** kwargs ) -> List [str ]:
200
+ """Find the overload that matches the provided kwargs and return its input parameters.
201
+
202
+ :keyword kwargs: The keyword arguments to match against overloads.
203
+ :type kwargs: Dict
204
+ :return: List of input parameter names for the matching overload.
205
+ :rtype: List[str]
206
+ """
207
+ overload_inputs = self ._singleton_inputs
208
+ provided_keys = set (key for key , value in kwargs .items () if value is not None )
209
+
210
+ # Find the overload that best matches the provided parameters
211
+ best_match = None
212
+ best_score = - 1
213
+
214
+ for inputs in overload_inputs :
215
+ input_set = set (inputs )
216
+
217
+ # Calculate match score: how many of the overload's params are provided
218
+ if input_set .issubset (provided_keys ):
219
+ score = len (input_set )
220
+ if score > best_score :
221
+ best_score = score
222
+ best_match = inputs
223
+
224
+ # If exact match found, return it
225
+ if best_match is not None :
226
+ return best_match
227
+
228
+ # If no exact match, find the overload with the most overlap
229
+ for inputs in overload_inputs :
230
+ input_set = set (inputs )
231
+ overlap = len (input_set .intersection (provided_keys ))
232
+ if overlap > best_score :
233
+ best_score = overlap
234
+ best_match = inputs
235
+
236
+ # Return the best match or the first overload as fallback
237
+ return best_match if best_match is not None else (overload_inputs [0 ] if overload_inputs else [])
238
+
239
+ def _get_all_singleton_inputs (self ) -> List [str ]:
240
+ """Get a flattened list of all possible singleton inputs across all overloads.
241
+
242
+ :return: Flattened list of all singleton input names.
243
+ :rtype: List[str]
244
+ """
245
+ all_inputs = set ()
246
+ for inputs in self ._singleton_inputs :
247
+ all_inputs .update (inputs )
248
+ return list (all_inputs )
198
249
199
250
def _derive_conversation_converter (
200
251
self ,
@@ -206,10 +257,11 @@ def _derive_conversation_converter(
206
257
:return: The function that will be used to convert conversations to evaluable inputs.
207
258
:rtype: Callable
208
259
"""
209
- include_context = "context" in self ._singleton_inputs
210
- include_query = "query" in self ._singleton_inputs
211
- include_response = "response" in self ._singleton_inputs
212
- include_ground_truth = "ground_truth" in self ._singleton_inputs
260
+ all_singleton_inputs = self ._get_all_singleton_inputs ()
261
+ include_context = "context" in all_singleton_inputs
262
+ include_query = "query" in all_singleton_inputs
263
+ include_response = "response" in all_singleton_inputs
264
+ include_ground_truth = "ground_truth" in all_singleton_inputs
213
265
214
266
def converter (conversation : Dict ) -> List [DerivedEvalInput ]:
215
267
messages = cast (List [Dict [str , Any ]], conversation ["messages" ])
@@ -319,9 +371,9 @@ def _convert_kwargs_to_eval_input(self, **kwargs) -> Union[List[Dict], List[Deri
319
371
(like a query and response), or they receive a conversation that is a list of dictionary
320
372
values.
321
373
322
- The self._singleton_inputs list assigned during initialization is used to find and extract
323
- singleton keywords, and self._allow_conversation_input is used to determine if a conversation
324
- is a valid input .
374
+ The self._singleton_inputs list (containing overload signatures) assigned during initialization
375
+ is used to find and extract singleton keywords, and determine which overload matches the
376
+ provided arguments .
325
377
326
378
If both conversations and singletons are allowed, the function will raise an exception if both
327
379
are inputted.
@@ -339,7 +391,10 @@ def _convert_kwargs_to_eval_input(self, **kwargs) -> Union[List[Dict], List[Deri
339
391
conversation = kwargs .get ("conversation" , None )
340
392
singletons = {}
341
393
if len (self ._singleton_inputs ) > 0 :
342
- singletons = {key : kwargs .get (key , None ) for key in self ._singleton_inputs }
394
+ # Get all possible singleton inputs and check what's provided
395
+ all_singleton_inputs = self ._get_all_singleton_inputs ()
396
+ singletons = {key : kwargs .get (key , None ) for key in all_singleton_inputs }
397
+
343
398
# Check that both conversation and other inputs aren't set
344
399
if conversation is not None and any (singletons .values ()):
345
400
msg = f"{ type (self ).__name__ } : Cannot provide both 'conversation' and individual inputs at the same time."
@@ -354,10 +409,16 @@ def _convert_kwargs_to_eval_input(self, **kwargs) -> Union[List[Dict], List[Deri
354
409
if self ._is_multi_modal_conversation (conversation ):
355
410
return self ._derive_multi_modal_conversation_converter ()(conversation )
356
411
return self ._derive_conversation_converter ()(conversation )
357
- # Handle Singletons
358
- required_singletons = remove_optional_singletons (self , singletons )
359
- if all (value is not None for value in required_singletons .values ()):
360
- return [singletons ]
412
+
413
+ # Handle Singletons - find matching overload
414
+ matching_inputs = self ._get_matching_overload_inputs (** kwargs )
415
+ if matching_inputs :
416
+ # Check if all required inputs for this overload are provided
417
+ required_singletons = {key : kwargs .get (key , None ) for key in matching_inputs }
418
+ required_singletons = remove_optional_singletons (self , required_singletons )
419
+ if all (value is not None for value in required_singletons .values ()):
420
+ return [singletons ]
421
+
361
422
# Missing input
362
423
msg = f"{ type (self ).__name__ } : Either 'conversation' or individual inputs must be provided."
363
424
raise EvaluationException (
@@ -416,6 +477,39 @@ def _aggregate_results(self, per_turn_results: List[DoEvalResult[T_EvalValue]])
416
477
aggregated ["evaluation_per_turn" ] = evaluation_per_turn
417
478
return aggregated
418
479
480
+ def _parse_tools_from_response (self , response ):
481
+ """Parse the response to extract tool calls and results.
482
+ :param response: The response to parse.
483
+ :type response: Union[str, List[dict]]
484
+ :return: List of tool calls extracted from the response.
485
+ :rtype: List[dict]
486
+ """
487
+ tool_calls = []
488
+ tool_results_map = {}
489
+ if isinstance (response , list ):
490
+ for message in response :
491
+ # Extract tool calls from assistant messages
492
+ if message .get ("role" ) == "assistant" and isinstance (message .get ("content" ), list ):
493
+ for content_item in message .get ("content" ):
494
+ if isinstance (content_item , dict ) and content_item .get ("type" ) == "tool_call" :
495
+ tool_calls .append (content_item )
496
+
497
+ # Extract tool results from tool messages
498
+ elif message .get ("role" ) == "tool" and message .get ("tool_call_id" ):
499
+ tool_call_id = message .get ("tool_call_id" )
500
+ if isinstance (message .get ("content" ), list ) and len (message .get ("content" )) > 0 :
501
+ result_content = message .get ("content" )[0 ]
502
+ if isinstance (result_content , dict ) and result_content .get ("type" ) == "tool_result" :
503
+ tool_results_map [tool_call_id ] = result_content
504
+
505
+ # Attach results to their corresponding calls
506
+ for tool_call in tool_calls :
507
+ tool_call_id = tool_call .get ("tool_call_id" )
508
+ if tool_call_id in tool_results_map :
509
+ tool_call ["tool_result" ] = tool_results_map [tool_call_id ]["tool_result" ]
510
+
511
+ return tool_calls
512
+
419
513
async def _real_call (self , ** kwargs ) -> Union [DoEvalResult [T_EvalValue ], AggregateResult [T_EvalValue ]]:
420
514
"""The asynchronous call where real end-to-end evaluation logic is performed.
421
515
0 commit comments