|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import math
import os
import logging
from typing import Dict, Union, List, Optional
from typing_extensions import overload, override
from azure.ai.evaluation._exceptions import EvaluationException, ErrorBlame, ErrorCategory, ErrorTarget
from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
from azure.ai.evaluation._common._experimental import experimental


logger = logging.getLogger(__name__)


@experimental
class ToolSuccessEvaluator(PromptyEvaluatorBase[Union[str, float]]):
    """The Tool Success evaluator determines whether the tool calls made by an AI agent include any failures.

    This evaluator focuses solely on tool call results and tool definitions, disregarding the user's query
    to the agent, the conversation history, and the agent's final response. Although tool definitions are
    optional, providing them can help the evaluator better understand the context of the tool calls made
    by the agent. Note that this evaluator validates tool calls for potential technical failures such as
    errors, exceptions, timeouts, and empty results (only in cases where an empty result could indicate a
    failure). It does not assess the correctness of the tool result itself, such as mathematical errors or
    unrealistic field values like name="668656".

    Scoring is binary:
    - TRUE: All tool calls were successful
    - FALSE: At least one tool call failed

    :param model_config: Configuration for the Azure OpenAI model.
    :type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration]

    .. admonition:: Example:

        .. literalinclude:: ../samples/evaluation_samples_evaluate.py
            :start-after: [START tool_success_evaluator]
            :end-before: [END tool_success_evaluator]
            :language: python
            :dedent: 8
            :caption: Initialize and call a ToolSuccessEvaluator with tool definitions and a response.

    .. admonition:: Example using Azure AI Project URL:

        .. literalinclude:: ../samples/evaluation_samples_evaluate_fdp.py
            :start-after: [START tool_success_evaluator]
            :end-before: [END tool_success_evaluator]
            :language: python
            :dedent: 8
            :caption: Initialize and call ToolSuccessEvaluator using an Azure AI Project URL in the following
                format: https://{resource_name}.services.ai.azure.com/api/projects/{project_name}

    """

    _PROMPTY_FILE = "tool_success.prompty"
    _RESULT_KEY = "tool_success"
    _OPTIONAL_PARAMS = ["tool_definitions"]

    id = "azureai://built-in/evaluators/tool_success"
    """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

    @override
    def __init__(self, model_config, *, credential=None, **kwargs):
        """Initialize the Tool Success evaluator."""
        current_dir = os.path.dirname(__file__)
        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
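        # Binary metric: the threshold of 1 is reported alongside the score, and _do_eval maps a fully
        # successful run to 1.0 ("pass") and any detected tool failure to 0.0 ("fail").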
        super().__init__(
            model_config=model_config,
            prompty_file=prompty_path,
            result_key=self._RESULT_KEY,
            threshold=1,
            credential=credential,
            _higher_is_better=True,
            **kwargs,
        )

    @overload
    def __call__(
        self,
        *,
        response: Union[str, List[dict]],
        tool_definitions: Optional[Union[dict, List[dict]]] = None,
    ) -> Dict[str, Union[str, float]]:
        """Evaluate tool call success for a given response and, optionally, tool definitions.

        Example with a list of messages:
            evaluator = ToolSuccessEvaluator(model_config)
            response = [{'createdAt': 1700000070, 'run_id': '0', 'role': 'assistant',
                'content': [{'type': 'text', 'text': '**Day 1:** Morning: Visit Louvre Museum (9 AM - 12 PM)...'}]}]

            result = evaluator(response=response)
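
        Example with tool calls and results (an illustrative sketch; the tool name, arguments, and result
        values are hypothetical, using the flat message shape this evaluator's helpers understand):
            response = [
                {'role': 'assistant', 'content': [{'type': 'tool_call', 'tool_call_id': 'call_1',
                    'name': 'get_weather', 'arguments': {'city': 'Paris'}}]},
                {'role': 'tool', 'tool_call_id': 'call_1',
                    'content': [{'type': 'tool_result', 'tool_result': '18 C, clear skies'}]},
            ]
            result = evaluator(response=response)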

        :keyword response: The response being evaluated, either a string or a list of messages (the full
            agent response, potentially including tool calls).
        :paramtype response: Union[str, List[dict]]
        :keyword tool_definitions: Optional tool definitions to use for evaluation.
        :paramtype tool_definitions: Optional[Union[dict, List[dict]]]
        :return: A dictionary with the tool success evaluation results.
        :rtype: Dict[str, Union[str, float]]
        """

    @override
    def __call__(  # pylint: disable=docstring-missing-param
        self,
        *args,
        **kwargs,
    ):
        """
        Invoke the instance using the overloaded __call__ signature.

        For detailed parameter types and return value documentation, see the overloaded __call__ definition.
        """
        return super().__call__(*args, **kwargs)

    @override
    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[str, float]]:  # type: ignore[override]
        """Do Tool Success evaluation.

        :param eval_input: The input to the evaluator. Expected to contain whatever inputs are
            needed for the _flow method.
        :type eval_input: Dict
        :return: The evaluation result.
        :rtype: Dict
        """
        if "response" not in eval_input:
            raise EvaluationException(
                message="response is a required input to the Tool Success evaluator.",
                internal_message="response is a required input to the Tool Success evaluator.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.MISSING_FIELD,
                target=ErrorTarget.TOOL_SUCCESS_EVALUATOR,
            )
        if eval_input["response"] is None or eval_input["response"] == []:
            raise EvaluationException(
                message="response cannot be None or empty for the Tool Success evaluator.",
                internal_message="response cannot be None or empty for the Tool Success evaluator.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.INVALID_VALUE,
                target=ErrorTarget.TOOL_SUCCESS_EVALUATOR,
            )

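        # Flatten the agent response into a "[TOOL_CALL] ..." / "[TOOL_RESULT] ..." transcript so the
        # prompt only has to reason over tool activity, not the surrounding conversation.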
        eval_input["tool_calls"] = _reformat_tool_calls_results(eval_input["response"], logger)

        if "tool_definitions" in eval_input:
            tool_definitions = eval_input["tool_definitions"]
            filtered_tool_definitions = _filter_to_used_tools(
                tool_definitions=tool_definitions,
                msgs_list=eval_input["response"],
                logger=logger,
            )
            eval_input["tool_definitions"] = _reformat_tool_definitions(filtered_tool_definitions, logger)

        prompty_output_dict = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)
        llm_output = prompty_output_dict.get("llm_output", "")

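        # The prompty output is expected to be a dict with "success" (a bool or a "TRUE"/"FALSE" string),
        # an "explanation", and optional "details"; any other shape falls through to the NaN branch below.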
        if isinstance(llm_output, dict):
            success = llm_output.get("success", False)
            if isinstance(success, str):
                success = success.upper() == "TRUE"

            success_result = "pass" if success else "fail"
            reason = llm_output.get("explanation", "")
            return {
                f"{self._result_key}": success * 1.0,
                f"{self._result_key}_result": success_result,
                f"{self._result_key}_threshold": self._threshold,
                f"{self._result_key}_reason": f"{reason} {llm_output.get('details', '')}",
                f"{self._result_key}_prompt_tokens": prompty_output_dict.get("input_token_count", 0),
                f"{self._result_key}_completion_tokens": prompty_output_dict.get("output_token_count", 0),
                f"{self._result_key}_total_tokens": prompty_output_dict.get("total_token_count", 0),
                f"{self._result_key}_finish_reason": prompty_output_dict.get("finish_reason", ""),
                f"{self._result_key}_model": prompty_output_dict.get("model_id", ""),
                f"{self._result_key}_sample_input": prompty_output_dict.get("sample_input", ""),
                f"{self._result_key}_sample_output": prompty_output_dict.get("sample_output", ""),
            }
        if logger:
            logger.warning("LLM output is not a dictionary, returning NaN for the score.")

        score = math.nan
        binary_result = self._get_binary_result(score)
        return {
            self._result_key: float(score),
            f"{self._result_key}_result": binary_result,
            f"{self._result_key}_threshold": self._threshold,
        }


def _filter_to_used_tools(tool_definitions, msgs_list, logger=None):
    """Filter the tool definitions to only those that were actually used in the message list."""
    try:
        used_tool_names = set()
        any_tools_used = False

        for msg in msgs_list:
            if msg.get("role") == "assistant" and "content" in msg:
                for content in msg.get("content", []):
                    if content.get("type") == "tool_call":
                        any_tools_used = True
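                        # Tool calls can arrive in a nested shape ({"tool_call": {"function": {...}}})
                        # or a flat shape ({"name": ..., "arguments": {...}}); handle both.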
                        if "tool_call" in content and "function" in content["tool_call"]:
                            # The nested shape stores a function dict, so record its name here;
                            # the dict itself would never match tool.get("name") in the filter below.
                            used_tool_names.add(content["tool_call"]["function"].get("name", ""))
                        elif "name" in content:
                            used_tool_names.add(content["name"])

        filtered_tools = [tool for tool in tool_definitions if tool.get("name") in used_tool_names]
        if any_tools_used and not filtered_tools:
            if logger:
                logger.warning("No tool definitions matched the tools used in the messages. Returning original list.")
            filtered_tools = tool_definitions

        return filtered_tools
    except Exception as e:
        if logger:
            logger.warning(f"Failed to filter tool definitions, returning original list. Error: {e}")
        return tool_definitions


def _get_tool_calls_results(agent_response_msgs):
    """Extract formatted agent tool calls and results from response."""
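    # Returns a flat list of display lines, for example (tool name and values are hypothetical):
    #   [TOOL_CALL] get_weather(city="Paris")
    #   [TOOL_RESULT] 18 C, clear skies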
    agent_response_text = []
    tool_results = {}

    # First pass: collect tool results
    for msg in agent_response_msgs:
        if msg.get("role") == "tool" and "tool_call_id" in msg:
            for content in msg.get("content", []):
                if content.get("type") == "tool_result":
                    result = content.get("tool_result")
                    tool_results[msg["tool_call_id"]] = f"[TOOL_RESULT] {result}"

    # Second pass: parse assistant messages and tool calls
    for msg in agent_response_msgs:
        if msg.get("role") == "assistant" and "content" in msg:
            for content in msg.get("content", []):
                if content.get("type") == "tool_call":
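                    # Support both the nested shape ({"tool_call": {"function": {...}, "id": ...}})
                    # and the flat shape ({"tool_call_id": ..., "name": ..., "arguments": {...}}).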
                    if "tool_call" in content and "function" in content.get("tool_call", {}):
                        tc = content.get("tool_call", {})
                        func_name = tc.get("function", {}).get("name", "")
                        args = tc.get("function", {}).get("arguments", {})
                        tool_call_id = tc.get("id")
                    else:
                        tool_call_id = content.get("tool_call_id")
                        func_name = content.get("name", "")
                        args = content.get("arguments", {})
                    args_str = ", ".join(f'{k}="{v}"' for k, v in args.items())
                    call_line = f"[TOOL_CALL] {func_name}({args_str})"
                    agent_response_text.append(call_line)
                    if tool_call_id in tool_results:
                        agent_response_text.append(tool_results[tool_call_id])

    return agent_response_text


def _reformat_tool_calls_results(response, logger=None):
    try:
        if response is None or response == []:
            return ""
        agent_response = _get_tool_calls_results(response)
        if agent_response == []:
            # If no message could be extracted, the input format has likely changed;
            # fall back to the original response in that case.
            if logger:
                logger.warning(
                    f"Empty agent response extracted, likely due to input schema change. "
                    f"Falling back to using the original response: {response}"
                )
            return response
        return "\n".join(agent_response)
    except Exception:
        # If the agent response cannot be parsed for whatever reason (e.g. the converter format
        # changed), the original response is returned. This is a fallback to ensure that the
        # evaluation can still proceed. See comments on reformat_conversation_history for more details.
        if logger:
            logger.warning(f"Agent response could not be parsed, falling back to original response: {response}")
        return response


def _reformat_tool_definitions(tool_definitions, logger=None):
    try:
        output_lines = ["TOOL_DEFINITIONS:"]
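        # Each definition is rendered as one summary line, for example (hypothetical tool):
        #   - get_weather: Returns the current weather for a city (inputs: city)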
        for tool in tool_definitions:
            name = tool.get("name", "unnamed_tool")
            desc = tool.get("description", "").strip()
            params = tool.get("parameters", {}).get("properties", {})
            param_names = ", ".join(params.keys()) if params else "no parameters"
            output_lines.append(f"- {name}: {desc} (inputs: {param_names})")
        return "\n".join(output_lines)
    except Exception:
        # If the tool definitions cannot be parsed for whatever reason, the original tool definitions are returned.
        # This is a fallback to ensure that the evaluation can still proceed.
        # See comments on reformat_conversation_history for more details.
        if logger:
            logger.warning(
                f"Tool definitions could not be parsed, falling back to original definitions: {tool_definitions}"
            )
        return tool_definitions