@@ -10,7 +10,11 @@
 
 from azure.ai.evaluation._exceptions import EvaluationException, ErrorBlame, ErrorCategory, ErrorTarget
 from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
-from ..._common.utils import reformat_conversation_history, reformat_agent_response, reformat_tool_definitions
+from ..._common.utils import (
+    reformat_conversation_history,
+    reformat_agent_response,
+    reformat_tool_definitions,
+)
 from azure.ai.evaluation._model_configurations import Message
 from azure.ai.evaluation._common._experimental import experimental
 
@@ -73,12 +77,14 @@ class TaskAdherenceEvaluator(PromptyEvaluatorBase[Union[str, float]]): |
     def __init__(self, model_config, *, threshold=_DEFAULT_TASK_ADHERENCE_SCORE, credential=None, **kwargs):
         current_dir = os.path.dirname(__file__)
         prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
-        self.threshold = threshold
+        self.threshold = threshold  # to be removed in favor of _threshold
         super().__init__(
             model_config=model_config,
             prompty_file=prompty_path,
             result_key=self._RESULT_KEY,
+            threshold=threshold,
             credential=credential,
+            _higher_is_better=True,
             **kwargs,
         )
 
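For context, here is a minimal usage sketch showing how the forwarded `threshold` comes into play from the caller's side. The endpoint, deployment, and key values are placeholders, and the `query`/`response` call signature is assumed from the evaluator's existing public surface:

```python
from azure.ai.evaluation import AzureOpenAIModelConfiguration, TaskAdherenceEvaluator

# Placeholder Azure OpenAI settings; substitute a real endpoint, deployment, and key.
model_config = AzureOpenAIModelConfiguration(
    azure_endpoint="https://<resource>.openai.azure.com",
    azure_deployment="<deployment-name>",
    api_key="<api-key>",
)

# `threshold` is now forwarded to PromptyEvaluatorBase (see the hunk above), which
# stores it as `_threshold` and treats higher scores as better.
task_adherence = TaskAdherenceEvaluator(model_config=model_config, threshold=3)

result = task_adherence(
    query="Book a table for two at an Italian restaurant tonight.",
    response="I found a table for two at 7pm at Trattoria Rosa and booked it.",
)
```
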
@@ -154,19 +160,38 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # t |
         eval_input["response"] = reformat_agent_response(eval_input["response"], logger, include_tool_messages=True)
         if "tool_definitions" in eval_input and eval_input["tool_definitions"] is not None:
             eval_input["tool_definitions"] = reformat_tool_definitions(eval_input["tool_definitions"], logger)
-        llm_output = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)
+
+        prompty_output_dict = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)
+        llm_output = prompty_output_dict["llm_output"]
+
+        score = math.nan
         if isinstance(llm_output, dict):
             score = float(llm_output.get("score", math.nan))
-            score_result = "pass" if score >= self.threshold else "fail"
+            score_result = "pass" if score >= self._threshold else "fail"
             reason = llm_output.get("explanation", "")
             return {
                 f"{self._result_key}": score,
+                f"gpt_{self._result_key}": score,
                 f"{self._result_key}_result": score_result,
-                f"{self._result_key}_threshold": self.threshold,
+                f"{self._result_key}_threshold": self._threshold,
                 f"{self._result_key}_reason": reason,
                 # Uncomment the following line in the next iteration after UI contracts are validated.
                 # f"{self._result_key}_additional_details": llm_output
+                f"{self._result_key}_prompt_tokens": prompty_output_dict.get("input_token_count", 0),
+                f"{self._result_key}_completion_tokens": prompty_output_dict.get("output_token_count", 0),
+                f"{self._result_key}_total_tokens": prompty_output_dict.get("total_token_count", 0),
+                f"{self._result_key}_finish_reason": prompty_output_dict.get("finish_reason", ""),
+                f"{self._result_key}_model": prompty_output_dict.get("model_id", ""),
+                f"{self._result_key}_sample_input": prompty_output_dict.get("sample_input", ""),
+                f"{self._result_key}_sample_output": prompty_output_dict.get("sample_output", ""),
             }
         if logger:
             logger.warning("LLM output is not a dictionary, returning NaN for the score.")
-        return {self._result_key: math.nan}
+
+        binary_result = self._get_binary_result(score)
+        return {
+            self._result_key: float(score),
+            f"gpt_{self._result_key}": float(score),
+            f"{self._result_key}_result": binary_result,
+            f"{self._result_key}_threshold": self._threshold,
+        }
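
For reference, a sketch of the per-row output this version of `_do_eval` returns when the prompty flow yields a valid dict. The key prefix assumes `_RESULT_KEY` is `task_adherence`, and all values are illustrative:

```python
# Illustrative only: values are made up and the key prefix assumes
# _RESULT_KEY == "task_adherence".
example_row = {
    "task_adherence": 4.0,
    "gpt_task_adherence": 4.0,                # legacy-prefixed duplicate of the score
    "task_adherence_result": "pass",          # "pass" when score >= self._threshold
    "task_adherence_threshold": 3,
    "task_adherence_reason": "The response fulfils the stated task ...",
    "task_adherence_prompt_tokens": 512,      # prompty_output_dict["input_token_count"]
    "task_adherence_completion_tokens": 98,   # prompty_output_dict["output_token_count"]
    "task_adherence_total_tokens": 610,
    "task_adherence_finish_reason": "stop",
    "task_adherence_model": "gpt-4o",
    "task_adherence_sample_input": "...",
    "task_adherence_sample_output": "...",
}
```

If the flow output is not a dict, the fallback now returns NaN for the score together with the binary result and threshold, instead of the bare `{result_key: nan}` of the previous version.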