Skip to content

Commit c6383aa

Browse files
Adarsh-RamanathanUbuntuMilesHolland
authored
address API review comments (#36058)
* address comments * reintroduce index * retrigger checks * lint fixes * retrigger checks * remove experimental tag * investigating build issues * revert experimental changes * make dataclass kw-only * commit docstring * make index config experimental * remove unused experimental import * correct build_index docstring/typehint * kevin's suggestion - remove redundant experimental warning * make all init args keyword only * re-order annotations * run black --------- Co-authored-by: Ubuntu <azureuser@adramadev0.1u2n51k150xetd4yrig4jsoeod.xx.internal.cloudapp.net> Co-authored-by: Miles Holland <[email protected]>
1 parent 892881a commit c6383aa

File tree

4 files changed

+58
-29
lines changed

4 files changed

+58
-29
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_indexes/input/_ai_search_config.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@
88
# Defines stuff related to the resulting created index, like the index type.
99

1010
from typing import Optional
11+
from azure.ai.ml._utils._experimental import experimental
1112

1213

14+
@experimental
1315
class AzureAISearchConfig:
1416
"""Config class for creating an Azure AI Search index.
1517
16-
:param ai_search_index_name: The name of the Azure AI Search index.
17-
:type ai_search_index_name: Optional[str]
18-
:param ai_search_connection_id: The Azure AI Search connection ID.
19-
:type ai_search_connection_id: Optional[str]
18+
:param index_name: The name of the Azure AI Search index.
19+
:type index_name: Optional[str]
20+
:param connection_id: The Azure AI Search connection ID.
21+
:type connection_id: Optional[str]
2022
"""
2123

2224
def __init__(
2325
self,
2426
*,
25-
ai_search_index_name: Optional[str] = None,
26-
ai_search_connection_id: Optional[str] = None,
27+
index_name: Optional[str] = None,
28+
connection_id: Optional[str] = None,
2729
) -> None:
28-
self.ai_search_index_name = ai_search_index_name
29-
self.ai_search_connection_id = ai_search_connection_id
30+
self.index_name = index_name
31+
self.connection_id = connection_id

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_indexes/input/_index_data_source.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# ---------------------------------------------------------
44
from typing import Union
55

6+
from azure.ai.ml._utils._experimental import experimental
67
from azure.ai.ml.entities._inputs_outputs import Input
78
from azure.ai.ml.constants._common import IndexInputType
89

@@ -12,6 +13,7 @@
1213

1314

1415
# Defines stuff related to supplying inputs for an index AKA the base data.
16+
@experimental
1517
class IndexDataSource:
1618
"""Base class for configs that define data that will be processed into an ML index.
1719
This class should not be instantiated directly. Use one of its child classes instead.
@@ -28,24 +30,26 @@ def __init__(self, *, input_type: Union[str, IndexInputType]):
2830
# Field bundle for creating an index from files located in a Git repo.
2931
# TODO Does git_url need to specifically be an SSH or HTTPS style link?
3032
# TODO What is git connection id?
33+
@experimental
3134
class GitSource(IndexDataSource):
3235
"""Config class for creating an ML index from files located in a git repository.
3336
34-
:param git_url: A link to the repository to use.
35-
:type git_url: str
36-
:param git_branch_name: The name of the branch to use from the target repository.
37-
:type git_branch_name: str
38-
:param git_connection_id: The connection ID for GitHub
39-
:type git_connection_id: str
37+
:param url: A link to the repository to use.
38+
:type url: str
39+
:param branch_name: The name of the branch to use from the target repository.
40+
:type branch_name: str
41+
:param connection_id: The connection ID for GitHub
42+
:type connection_id: str
4043
"""
4144

42-
def __init__(self, *, git_url: str, git_branch_name: str, git_connection_id: str):
43-
self.git_url = git_url
44-
self.git_branch_name = git_branch_name
45-
self.git_connection_id = git_connection_id
45+
def __init__(self, *, url: str, branch_name: str, connection_id: str):
46+
self.url = url
47+
self.branch_name = branch_name
48+
self.connection_id = connection_id
4649
super().__init__(input_type=IndexInputType.GIT)
4750

4851

52+
@experimental
4953
class LocalSource(IndexDataSource):
5054
"""Config class for creating an ML index from a collection of local files.
5155

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_indexes/model_config.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# ---------------------------------------------------------
44
from dataclasses import dataclass
55
from typing import Any, Dict, Optional
6+
from azure.ai.ml._utils._experimental import experimental
67
from azure.ai.ml._utils.utils import camel_to_snake
78
from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
89
from azure.ai.ml.entities._workspace.connections.connection_subtypes import (
@@ -11,6 +12,7 @@
1112
)
1213

1314

15+
@experimental
1416
@dataclass
1517
class ModelConfiguration:
1618
"""Configuration for a embedding model.
@@ -42,12 +44,33 @@ class ModelConfiguration:
4244
deployment_name: Optional[str]
4345
model_kwargs: Dict[str, Any]
4446

47+
def __init__(
48+
self,
49+
*,
50+
api_base: Optional[str],
51+
api_key: Optional[str],
52+
api_version: Optional[str],
53+
connection_name: Optional[str],
54+
connection_type: Optional[str],
55+
model_name: Optional[str],
56+
deployment_name: Optional[str],
57+
model_kwargs: Dict[str, Any]
58+
):
59+
self.api_base = api_base
60+
self.api_key = api_key
61+
self.api_version = api_version
62+
self.connection_name = connection_name
63+
self.connection_type = connection_type
64+
self.model_name = model_name
65+
self.deployment_name = deployment_name
66+
self.model_kwargs = model_kwargs
67+
4568
@staticmethod
4669
def from_connection(
4770
connection: WorkspaceConnection,
4871
model_name: Optional[str] = None,
4972
deployment_name: Optional[str] = None,
50-
**model_kwargs
73+
**kwargs
5174
) -> "ModelConfiguration":
5275
"""Create an model configuration from a Connection.
5376
@@ -57,8 +80,8 @@ def from_connection(
5780
:type model_name: Optional[str]
5881
:param deployment_name: The name of the deployment.
5982
:type deployment_name: Optional[str]
60-
:keyword model_kwargs: Additional keyword arguments for the model.
61-
:paramtype model_kwargs: Dict[str, Any]
83+
:keyword kwargs: Additional keyword arguments for the model.
84+
:paramtype kwargs: Dict[str, Any]
6285
:return: The model configuration.
6386
:rtype: ~azure.ai.ml.entities._indexes.entities.ModelConfiguration
6487
:raises TypeError: If the connection is not an AzureOpenAIConnection.
@@ -97,5 +120,5 @@ def from_connection(
97120
connection_type=connection_type,
98121
model_name=model_name,
99122
deployment_name=deployment_name,
100-
model_kwargs=model_kwargs,
123+
model_kwargs=kwargs,
101124
)

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_index_operations.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def build_index(
270270
######## data source info ########
271271
input_source: Union[IndexDataSource, str],
272272
input_source_credential: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
273-
) -> Union["MLIndex", "Job"]: # type: ignore[name-defined]
273+
) -> Union["Index", "Job"]: # type: ignore[name-defined]
274274
"""Builds an index on the cloud using the Azure AI Resources service.
275275
276276
:keyword name: The name of the index to be created.
@@ -295,7 +295,7 @@ def build_index(
295295
:paramtype input_source_credential: Optional[Union[~azure.ai.ml.entities.ManagedIdentityConfiguration,
296296
~azure.ai.ml.entities.UserIdentityConfiguration]]
297297
:return: If the `source_input` is a GitSource, returns a created DataIndex Job object.
298-
:rtype: Union[~azure.ai.ml.entities._indexes.MLIndex, ~azure.ai.ml.entities.Job]
298+
:rtype: Union[~azure.ai.ml.entities.Index, ~azure.ai.ml.entities.Job]
299299
:raises ValueError: If the `source_input` is not type ~typing.Str or
300300
~azure.ai.ml.entities._indexes.LocalSource.
301301
"""
@@ -333,8 +333,8 @@ def build_index(
333333
index=(
334334
IndexStore(
335335
type="acs",
336-
connection=build_connection_id(index_config.ai_search_connection_id, self._operation_scope),
337-
name=index_config.ai_search_index_name,
336+
connection=build_connection_id(index_config.connection_id, self._operation_scope),
337+
name=index_config.index_name,
338338
)
339339
if index_config is not None
340340
else IndexStore(type="faiss")
@@ -390,9 +390,9 @@ def git_to_index(
390390
return index_job.outputs
391391

392392
git_index_job = git_to_index(
393-
git_url=input_source.git_url,
394-
branch_name=input_source.git_branch_name,
395-
git_connection_id=input_source.git_connection_id,
393+
git_url=input_source.url,
394+
branch_name=input_source.branch_name,
395+
git_connection_id=input_source.connection_id,
396396
)
397397
# Ensure repo cloned each run to get latest, comment out to have first clone reused.
398398
git_index_job.settings.force_rerun = True

0 commit comments

Comments
 (0)