Skip to content

Commit a4540a8

Browse files
nick863singankit
andauthored
Fixes for API (#35528)
* Adding tests to capture groundedness with expected values * Evalutors API * Fix API * Fix * Fix linter * Fix --------- Co-authored-by: Ankit Singhal <[email protected]> Co-authored-by: Ankit Singhal <[email protected]>
1 parent 36ce8e8 commit a4540a8

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

sdk/ml/azure-ai-ml/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/ml/azure-ai-ml",
5-
"Tag": "python/ml/azure-ai-ml_bcde27db64"
5+
"Tag": "python/ml/azure-ai-ml_ae30eb5b40"
66
}

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_evaluator_operations.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class EvaluatorOperations(_ScopeDependentOperations):
6363
:type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations
6464
:param all_operations: All operations classes of an MLClient object.
6565
:type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
66+
:param kwargs: A dictionary of additional configuration parameters.
67+
:type kwargs: dict
6668
"""
6769

6870
# pylint: disable=unused-argument
@@ -92,7 +94,7 @@ def __init__(
9294

9395
@monitor_with_activity(ops_logger, "Evaluator.CreateOrUpdate", ActivityType.PUBLICAPI)
9496
def create_or_update( # type: ignore
95-
self, model: Union[Model, WorkspaceAssetReference]
97+
self, model: Union[Model, WorkspaceAssetReference], **kwargs: Any
9698
) -> Model: # TODO: Are we going to implement job_name?
9799
"""Returns created or updated model asset.
98100
@@ -125,15 +127,15 @@ def _raise_if_not_evaluator(self, properties: Optional[Dict[str, Any]], message:
125127
)
126128

127129
@monitor_with_activity(ops_logger, "Evaluator.Get", ActivityType.PUBLICAPI)
128-
def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Model:
130+
def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Model:
129131
"""Returns information about the specified model asset.
130132
131133
:param name: Name of the model.
132134
:type name: str
133-
:param version: Version of the model.
134-
:type version: str
135-
:param label: Label of the model. (mutually exclusive with version)
136-
:type label: str
135+
:keyword version: Version of the model.
136+
:paramtype version: str
137+
:keyword label: Label of the model. (mutually exclusive with version)
138+
:paramtype label: str
137139
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated.
138140
Details will be provided in the error message.
139141
:return: Model asset object.
@@ -150,7 +152,7 @@ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = N
150152
return model
151153

152154
@monitor_with_activity(ops_logger, "Evaluator.Download", ActivityType.PUBLICAPI)
153-
def download(self, name: str, version: str, download_path: Union[PathLike, str] = ".") -> None:
155+
def download(self, name: str, version: str, download_path: Union[PathLike, str] = ".", **kwargs: Any) -> None:
154156
"""Download files related to a model.
155157
156158
:param name: Name of the model.
@@ -171,6 +173,7 @@ def list(
171173
stage: Optional[str] = None,
172174
*,
173175
list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
176+
**kwargs: Any,
174177
) -> Iterable[Model]:
175178
"""List all model assets in workspace.
176179

sdk/ml/azure-ai-ml/tests/evaluator/e2etests/test_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_crud_file(self, client: MLClient, randstr: Callable[[], str]) -> None:
6363
assert "is-evaluator" in model.properties and model.properties["is-evaluator"] == "true"
6464
assert re.match(LONG_URI_REGEX_FORMAT, model.path)
6565

66-
model = client.evaluators.get(model.name, "3")
66+
model = client.evaluators.get(name=model.name, version="3")
6767
assert model.name == model_name
6868
assert model.version == "3"
6969
assert model.description == "This is evaluator."
@@ -93,7 +93,7 @@ def test_crud_evaluator_with_stage(self, client: MLClient, randstr: Callable[[],
9393
assert model.stage == "Production"
9494
assert re.match(LONG_URI_REGEX_FORMAT, model.path)
9595

96-
model = client.evaluators.get(model.name, "3")
96+
model = client.evaluators.get(name=model.name, version="3")
9797
assert model.name == model_name
9898
assert model.version == "3"
9999
assert model.description == "This is evaluator."
@@ -108,7 +108,7 @@ def test_evaluators_get_latest_label(self, client: MLClient, randstr: Callable[[
108108
for version in ["1", "2", "3", "4"]:
109109
model = _load_flow(model_name, version=version)
110110
client.evaluators.create_or_update(model)
111-
assert client.evaluators.get(model_name, label="latest").version == version
111+
assert client.evaluators.get(name=model_name, label="latest").version == version
112112

113113
@pytest.mark.skip(
114114
"Skipping test for archive and restore as we have removed it from interface. "

0 commit comments

Comments
 (0)