Commit d6a956c

Custom metrics in evaluate (#33339)
* Custom metrics in evaluate
* Custom evaluate metric with system prompt
* Updated to allow manageable response for custom metrics
* Adding reason for custom metrics
* Code-based custom metrics
* Adding TODO item
* Adding comment for prompt-based metric
* Prompt-based metrics
* Custom metric, code- and prompt-based
* Adding docstring for custom metrics
* Fixing spell checks
* Fixing spell check errors
* Fixing file path for spell check
* Adding aggregators
* Adding user-agent to AOAI calls
* Review comments
1 parent c13ec18 commit d6a956c

File tree

17 files changed: +796 −63 lines changed


.vscode/cspell.json

Lines changed: 7 additions & 0 deletions
@@ -1246,6 +1246,13 @@
         "smirnov"
       ]
     },
+    {
+      "filename": "sdk/ai/azure-ai-generative/**",
+      "words": [
+        "tqdm",
+        "genai"
+      ]
+    },
     {
       "filename": "sdk/attestation/azure-security-attestation/tests/conftest.py",
       "words":[

sdk/ai/azure-ai-generative/MANIFEST.in

Lines changed: 1 addition & 0 deletions
@@ -5,5 +5,6 @@ include azure/__init__.py
 include azure/ai/__init__.py
 include azure/ai/generative/py.typed
 include azure/ai/generative/index/_utils/encodings/*
+include azure/ai/generative/evaluate/metrics/templates/*
 recursive-include azure/ai/generative/synthetic/templates *.txt
 recursive-include azure/ai/generative/synthetic/simulator/templates *.md

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_client/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import asyncio
+import logging
+
+from openai import AsyncAzureOpenAI
+from openai.types.chat.chat_completion import ChatCompletion
+
+from azure.ai.generative._user_agent import USER_AGENT
+from azure.ai.generative.constants._common import USER_AGENT_HEADER_KEY
+
+semaphore = asyncio.Semaphore(10)
+
+LOGGER = logging.getLogger(__name__)
+
+
+class AzureOpenAIClient:
+
+    def __init__(self, openai_params):
+        self._azure_endpoint = openai_params.get("azure_endpoint", None) if openai_params.get("azure_endpoint", None) \
+            else openai_params.get("api_base", None)
+        self._api_key = openai_params.get("api_key", None)
+        self._api_version = openai_params.get("api_version", None)
+        self._azure_deployment = openai_params.get("azure_deployment", None) \
+            if openai_params.get("azure_deployment", None) else openai_params.get("deployment_id", None)
+
+        self._client = AsyncAzureOpenAI(
+            azure_endpoint=self._azure_endpoint,
+            api_version=self._api_version,
+            api_key=self._api_key,
+            default_headers={
+                USER_AGENT_HEADER_KEY: USER_AGENT,
+                "client_operation_source": "evaluate"
+            },
+        )
+
+    async def bounded_chat_completion(self, messages):
+        async with semaphore:
+            try:
+                result = await self._client.with_options(max_retries=5).chat.completions.create(
+                    model=self._azure_deployment,
+                    messages=messages,
+                    temperature=0,
+                    seed=0,
+                )
+                return result
+            except Exception as ex:
+                LOGGER.debug(f"Failed to call llm with exception : {str(ex)}")
+                return ex
+
+    @staticmethod
+    def get_chat_completion_content_from_response(response):
+        if isinstance(response, ChatCompletion):
+            return response.choices[0].message.content
+        return None
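
For orientation, the following is a minimal sketch of how the client above might be driven on its own, outside the evaluate pipeline. The openai_params keys mirror the ones read in __init__; the endpoint, key, API version, deployment name, the import location, and the asyncio.run driver are illustrative assumptions rather than part of this commit. Note that the module-level semaphore caps concurrent completions at 10 across all uses of the client.

import asyncio

# AzureOpenAIClient is the class added above; import it from wherever the package
# exposes the new _client module (the exact module name is not shown on this page).

# Hypothetical parameter dict; the keys match what AzureOpenAIClient.__init__ reads.
openai_params = {
    "azure_endpoint": "https://<resource>.openai.azure.com/",  # assumption: your AOAI endpoint
    "api_key": "<api-key>",                                    # assumption: your AOAI key
    "api_version": "2023-07-01-preview",                       # assumption: any supported API version
    "azure_deployment": "<chat-deployment>",                   # assumption: your chat deployment name
}

async def main():
    client = AzureOpenAIClient(openai_params)
    # Returns a ChatCompletion on success, or the exception object on failure (it is not re-raised).
    response = await client.bounded_chat_completion(
        [{"role": "user", "content": "Say hello."}]
    )
    # Yields None when the response is not a ChatCompletion (e.g. the call above failed).
    print(AzureOpenAIClient.get_chat_completion_content_from_response(response))

asyncio.run(main())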

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_constants.py

Lines changed: 5 additions & 0 deletions
@@ -75,3 +75,8 @@ class ChatMetrics
     "qa": QaMetrics,
     "rag-evaluation": ChatMetrics
 }
+
+SUPPORTED_TASK_TYPE_TO_METRICS_MAPPING = {
+    QA: QaMetrics,
+    CHAT: ChatMetrics
+}

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_evaluate.py

Lines changed: 144 additions & 38 deletions
@@ -1,16 +1,19 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
+import copy
 import json
 import os
 import shutil
 import tempfile
 import time
 import logging
+from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, Optional, Dict, List, Mapping

 import mlflow
+import numpy as np
 import pandas as pd
 from azure.core.tracing.decorator import distributed_trace
 from azure.ai.generative._telemetry import ActivityType, monitor_with_activity, monitor_with_telemetry_mixin, ActivityLogger
@@ -20,12 +23,16 @@
 from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE

 from azure.ai.generative.evaluate._metric_handler import MetricHandler
+from azure.ai.generative.evaluate._metrics_handler._code_metric_handler import CodeMetricHandler
 from azure.ai.generative.evaluate._utils import _is_flow, load_jsonl, _get_artifact_dir_path, _copy_artifact
 from azure.ai.generative.evaluate._mlflow_log_collector import RedirectUserOutputStreams
-from azure.ai.generative.evaluate._constants import SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING, SUPPORTED_TASK_TYPE, CHAT
+from azure.ai.generative.evaluate._constants import SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING, SUPPORTED_TASK_TYPE, CHAT, \
+    TYPE_TO_KWARGS_MAPPING, SUPPORTED_TASK_TYPE_TO_METRICS_MAPPING
 from azure.ai.generative.evaluate._evaluation_result import EvaluationResult
+from ._metrics_handler._prompt_metric_handler import PromptMetricHandler

 from ._utils import _write_properties_to_run_history
+from .metrics._custom_metric import CodeMetric, PromptMetric, Metric as GenAIMetric

 LOGGER = logging.getLogger(__name__)

@@ -47,6 +54,19 @@ def _get_handler_class(
     return handler


+def _get_metric_handler_class(
+        asset,
+):
+    if _is_flow(asset):
+        from azure.ai.generative.evaluate._local_flow_handler import LocalFlowHandler
+        handler = LocalFlowHandler
+    else:
+        from azure.ai.generative.evaluate._local_code_handler import LocalCodeHandler
+        handler = LocalCodeHandler
+
+    return handler
+
+
 def _validate_data(data, prediction_data, truth_data):
     errors = []
     prediction_data_column = ""
@@ -83,6 +103,28 @@ def _log_metrics(run_id, metrics):
     )


+def _validate_metrics(metrics, task_type):
+    genai_metrics = []
+    builtin_metrics =[]
+    unknown_metrics = []
+
+    for metric in metrics:
+        if isinstance(metric, GenAIMetric):
+            genai_metrics.append(metric.name)
+        elif isinstance(metric, str) and metric in SUPPORTED_TASK_TYPE_TO_METRICS_MAPPING[task_type].SUPPORTED_LIST:
+            builtin_metrics.append(metric)
+        else:
+            unknown_metrics.append(metric)
+
+    if len(unknown_metrics) > 0:
+        raise Exception("Unsupported metric found in the list")
+
+    # if len(set(genai_metrics) & set(builtin_metrics)) > 0:
+    if len(genai_metrics) != len(set(genai_metrics)) or len(builtin_metrics) != len(set(builtin_metrics))\
+            or (len(set(genai_metrics) & set(builtin_metrics)) > 0):
+        raise Exception("Duplicate metric name found. Metric names should be unique")
+
+
 @distributed_trace
 @monitor_with_activity(package_logger, "Evaluate", ActivityType.PUBLICAPI)
 def evaluate(
@@ -223,7 +265,7 @@ def _evaluate(
     metrics_config.update({"openai_params": model_config})

     if data_mapping:
-        metrics_config.update(data_mapping)
+        metrics_config.update({"data_mapping": data_mapping})

     with mlflow.start_run(nested=True if mlflow.active_run() else False, run_name=evaluation_name) as run, \
             RedirectUserOutputStreams(logger=LOGGER) as _:
@@ -246,43 +288,81 @@ def _evaluate(
             **kwargs
         )

-        metrics_handler = MetricHandler(
-            task_type=SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING[task_type],
-            metrics=metrics,
-            prediction_data=asset_handler.prediction_data,
-            truth_data=asset_handler.ground_truth,
-            test_data=asset_handler.test_data,
-            metrics_mapping=metrics_config,
-            prediction_data_column_name=prediction_data if isinstance(prediction_data, str) else None,
-            ground_truth_column_name=truth_data if isinstance(truth_data, str) else None,
-        )
+        metrics_results = {"artifacts": {}, "metrics": {}}

-        metrics = metrics_handler.calculate_metrics()
+        if metrics is None:
+            metrics = SUPPORTED_TASK_TYPE_TO_METRICS_MAPPING[task_type].DEFAULT_LIST

-        def _get_instance_table():
-            metrics.get("artifacts").pop("bertscore", None)
-            if task_type == CHAT:
-                instance_level_metrics_table = _get_chat_instance_table(metrics.get("artifacts"))
-            else:
-                instance_level_metrics_table = pd.DataFrame(metrics.get("artifacts"))
-
-            prediction_data = asset_handler.prediction_data
-            for column in asset_handler.prediction_data.columns.values:
-                if column in asset_handler.test_data.columns.values:
-                    prediction_data.drop(column, axis=1, inplace=True)
-
-            combined_table = pd.concat(
-                [asset_handler.test_data,
-                 prediction_data,
-                 asset_handler.ground_truth,
-                 instance_level_metrics_table
-                 ],
-                axis=1,
-                verify_integrity=True
+        _validate_metrics(metrics, task_type)
+
+        inbuilt_metrics = [metric for metric in metrics if not isinstance(metric, GenAIMetric)]
+        custom_prompt_metrics = [metric for metric in metrics if isinstance(metric, PromptMetric)]
+        code_metrics = [metric for metric in metrics if isinstance(metric, CodeMetric)]
+
+        # TODO : Once PF is used for inbuilt metrics parallelize submission of metrics calculation of different kind
+
+        if custom_prompt_metrics:
+            for metric in custom_prompt_metrics:
+                metrics_config.setdefault(metric.name, {param: param for param in metric.parameters})
+
+            prompt_metric_handler = PromptMetricHandler(
+                task_type="custom-prompt-metric",
+                metrics=custom_prompt_metrics,
+                prediction_data=asset_handler.prediction_data,
+                truth_data=asset_handler.ground_truth,
+                test_data=asset_handler.test_data,
+                metrics_mapping=metrics_config,
+                prediction_data_column_name=prediction_data if isinstance(prediction_data, str) else None,
+                ground_truth_column_name=truth_data if isinstance(truth_data, str) else None,
+                type_to_kwargs="custom-prompt-metric"
             )
-        return combined_table

-        _log_metrics(run_id=run.info.run_id, metrics=metrics.get("metrics"))
+            prompt_metric_results = prompt_metric_handler.calculate_metrics()
+
+            if prompt_metric_results is not None:
+                for k, v in metrics_results.items():
+                    v.update(prompt_metric_results[k])
+
+        if code_metrics:
+            code_metric_handler = CodeMetricHandler(
+                task_type="custom-code-metric",
+                metrics=code_metrics,
+                prediction_data=asset_handler.prediction_data,
+                truth_data=asset_handler.ground_truth,
+                test_data=asset_handler.test_data,
+                metrics_mapping=metrics_config,
+                prediction_data_column_name=prediction_data if isinstance(prediction_data, str) else None,
+                ground_truth_column_name=truth_data if isinstance(truth_data, str) else None,
+                type_to_kwargs="code-prompt-metric"
+            )
+
+            code_metric_results = code_metric_handler.calculate_metrics()
+
+            if code_metric_results is not None:
+                for k, v in metrics_results.items():
+                    v.update(code_metric_results[k])
+
+        if inbuilt_metrics:
+            inbuilt_metrics_handler = MetricHandler(
+                task_type=SUPPORTED_TO_METRICS_TASK_TYPE_MAPPING[task_type],
+                metrics=inbuilt_metrics,
+                prediction_data=asset_handler.prediction_data,
+                truth_data=asset_handler.ground_truth,
+                test_data=asset_handler.test_data,
+                metrics_mapping=metrics_config,
+                prediction_data_column_name=prediction_data if isinstance(prediction_data, str) else None,
+                ground_truth_column_name=truth_data if isinstance(truth_data, str) else None,
+                type_to_kwargs=TYPE_TO_KWARGS_MAPPING[task_type]
+            )
+
+            inbuilt_metrics_results = inbuilt_metrics_handler.calculate_metrics()
+
+            if inbuilt_metrics_results is not None:
+                for k, v in metrics_results.items():
+                    v.update(inbuilt_metrics_results[k])
+
+        if metrics_results.get("metrics"):
+            _log_metrics(run_id=run.info.run_id, metrics=metrics_results.get("metrics"))

         with tempfile.TemporaryDirectory() as tmpdir:
             for param_name, param_value in kwargs.get("params_dict", {}).items():
@@ -310,7 +390,9 @@ def _get_instance_table():
             else:
                 raise ex

-        eval_artifact_df = _get_instance_table().to_json(orient="records", lines=True, force_ascii=False)
+        eval_artifact_df = _get_instance_table(metrics_results, task_type, asset_handler).to_json(orient="records",
+                                                                                                  lines=True,
+                                                                                                  force_ascii=False)
         tmp_path = os.path.join(tmpdir, "eval_results.jsonl")

         with open(tmp_path, "w", encoding="utf-8") as f:
@@ -322,13 +404,12 @@ def _get_instance_table():
         mlflow.log_param("task_type", task_type)
         if task_type == CHAT:
             log_property("_azureml.chat_history_column", data_mapping.get("y_pred"))
-            # log_param_and_tag("_azureml.evaluate_metric_mapping", json.dumps(metrics_handler._metrics_mapping_to_log))

         if output_path:
             _copy_artifact(tmp_path, output_path)

         evaluation_result = EvaluationResult(
-            metrics_summary=metrics.get("metrics"),
+            metrics_summary=metrics_results.get("metrics"),
             artifacts={
                 "eval_results.jsonl": f"runs:/{run.info.run_id}/eval_results.jsonl"
             },
@@ -396,3 +477,28 @@ def _get_chat_instance_table(metrics):

     instance_level_metrics_table = pd.DataFrame(instance_table_metrics_dict)
     return instance_level_metrics_table
+
+
+def _get_instance_table(metrics, task_type, asset_handler):
+    if metrics.get("artifacts"):
+        metrics.get("artifacts").pop("bertscore", None)
+        if task_type == CHAT:
+            instance_level_metrics_table = _get_chat_instance_table(metrics.get("artifacts"))
+        else:
+            instance_level_metrics_table = pd.DataFrame(metrics.get("artifacts"))
+
+    prediction_data = asset_handler.prediction_data
+    for column in asset_handler.prediction_data.columns.values:
+        if column in asset_handler.test_data.columns.values:
+            prediction_data.drop(column, axis=1, inplace=True)
+
+    combined_table = pd.concat(
+        [asset_handler.test_data,
+         prediction_data,
+         asset_handler.ground_truth,
+         instance_level_metrics_table
+         ],
+        axis=1,
+        verify_integrity=True
+    )
+    return combined_table
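
To see the new control flow in one place, here is a self-contained sketch of the partition-and-validate step added above. The Metric/PromptMetric/CodeMetric classes below are stand-ins for the ones imported from .metrics._custom_metric (their real constructors are not shown in this diff), the supported-metric lookup is simplified to a plain set, and the duplicate check is compressed into a single combined-name test that rejects the same cases as the three-way check in _validate_metrics.

# Stand-in classes: the real ones live in azure.ai.generative.evaluate.metrics._custom_metric.
class Metric:
    def __init__(self, name):
        self.name = name

class PromptMetric(Metric):
    pass

class CodeMetric(Metric):
    pass

# Simplified stand-in for SUPPORTED_TASK_TYPE_TO_METRICS_MAPPING[task_type].SUPPORTED_LIST.
SUPPORTED_BUILTIN = {"gpt_groundedness", "gpt_relevance", "f1_score"}

def split_and_validate(metrics):
    # Partition the mixed list the way _evaluate does: GenAI metric objects vs. built-in names.
    genai = [m for m in metrics if isinstance(m, Metric)]
    builtin = [m for m in metrics if isinstance(m, str) and m in SUPPORTED_BUILTIN]
    unknown = [m for m in metrics if not isinstance(m, Metric) and m not in builtin]
    if unknown:
        raise Exception("Unsupported metric found in the list")

    # Names must be unique across custom and built-in metrics.
    names = [m.name for m in genai] + builtin
    if len(names) != len(set(names)):
        raise Exception("Duplicate metric name found. Metric names should be unique")

    # Each bucket is then routed to its own handler:
    # PromptMetricHandler, CodeMetricHandler, and MetricHandler respectively.
    return (
        [m for m in metrics if isinstance(m, PromptMetric)],
        [m for m in metrics if isinstance(m, CodeMetric)],
        builtin,
    )

prompt_metrics, code_metrics, inbuilt = split_and_validate(
    ["gpt_relevance", PromptMetric("politeness"), CodeMetric("answer_length")]
)
print([m.name for m in prompt_metrics], [m.name for m in code_metrics], inbuilt)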
