Skip to content

Commit 6d0d847

Browse files
authored
Adding typing for evaluate api (Azure#33113)
1 parent 7653dd5 commit 6d0d847

File tree

1 file changed

+12
-10
lines changed
  • sdk/ai/azure-ai-generative/azure/ai/generative/evaluate

1 file changed

+12
-10
lines changed

sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/_evaluate.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import time
99
import logging
1010
from pathlib import Path
11+
from typing import Callable, Optional, Dict, List
1112

1213
import mlflow
1314
import pandas as pd
@@ -81,15 +82,15 @@ def _log_metrics(run_id, metrics):
8182
@distributed_trace
8283
@monitor_with_activity(LOGGER, "Evaluate", ActivityType.PUBLICAPI)
8384
def evaluate(
84-
evaluation_name=None,
85-
target=None,
86-
data=None,
87-
task_type=None,
88-
sweep_args=None,
89-
metrics_list=None,
90-
model_config=None,
91-
data_mapping=None,
92-
output_path=None,
85+
*,
86+
evaluation_name: str = None,
87+
target: Optional[Callable] = None,
88+
data: Optional[str] = None,
89+
task_type: str = None,
90+
metrics_list: Optional[List[str]] = None,
91+
model_config: Dict[str, str] = None,
92+
data_mapping: Dict[str, str] = None,
93+
output_path: Optional[str] = None,
9394
**kwargs
9495
):
9596
"""Evaluates target or data with built-in evaluation metrics
@@ -138,6 +139,7 @@ def evaluate(
138139
if data_mapping:
139140
metrics_config.update(data_mapping)
140141

142+
sweep_args = kwargs.pop("sweep_args", None)
141143
if sweep_args:
142144
import itertools
143145
keys, values = zip(*sweep_args.items())
@@ -333,7 +335,6 @@ def _get_instance_table():
333335
return evaluation_result
334336

335337

336-
337338
def log_input(data, data_is_file):
338339
try:
339340
# Mlflow service supports only uri_folder, hence this is need to create a dir to log input data.
@@ -368,6 +369,7 @@ def log_property_and_tag(key, value, logger=LOGGER):
368369
_write_properties_to_run_history({key: value}, logger)
369370
mlflow.set_tag(key, value)
370371

372+
371373
def log_property(key, value, logger=LOGGER):
372374
_write_properties_to_run_history({key: value}, logger)
373375

0 commit comments

Comments
 (0)