Skip to content

Commit bb072e6

Browse files
tot0 and Lucas Pickup authored
Handle azure-search-documents b11 when creating VectorSearch config. (Azure#33001)
* Handle azure-search-documents b11 when creating VectorSearch config.
* Adjust parameter casing

---------

Co-authored-by: Lucas Pickup <[email protected]>
1 parent d93bb78 commit bb072e6

File tree

1 file changed

+49
-16
lines changed
  • sdk/ai/azure-ai-generative/azure/ai/generative/index/_tasks

1 file changed

+49
-16
lines changed

sdk/ai/azure-ai-generative/azure/ai/generative/index/_tasks/update_acs.py

Lines changed: 49 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,21 @@
1111
from typing import List, Optional
1212

1313
import yaml
14+
from packaging import version as pkg_version
15+
16+
from azure.ai.generative.index._embeddings import EmbeddingsContainer, ReferenceEmbeddedDocument
17+
from azure.ai.generative.index._mlindex import MLIndex
18+
from azure.ai.generative.index._utils.connections import get_connection_credential
19+
from azure.ai.generative.index._utils.logging import (
20+
_logger_factory,
21+
enable_appinsights_logging,
22+
enable_stdout_logging,
23+
get_logger,
24+
packages_versions_for_compatibility,
25+
safe_mlflow_start_run,
26+
track_activity,
27+
version,
28+
)
1429
from azure.core.credentials import TokenCredential
1530
from azure.identity import DefaultAzureCredential
1631
from azure.search.documents import SearchClient
@@ -26,20 +41,6 @@
2641
SemanticSettings,
2742
SimpleField,
2843
)
29-
from azure.ai.generative.index._embeddings import EmbeddingsContainer, ReferenceEmbeddedDocument
30-
from azure.ai.generative.index._mlindex import MLIndex
31-
from azure.ai.generative.index._utils.connections import get_connection_credential
32-
from azure.ai.generative.index._utils.logging import (
33-
_logger_factory,
34-
enable_appinsights_logging,
35-
enable_stdout_logging,
36-
get_logger,
37-
packages_versions_for_compatibility,
38-
safe_mlflow_start_run,
39-
track_activity,
40-
version,
41-
)
42-
from packaging import version as pkg_version
4344

4445
logger = get_logger("update_acs")
4546

@@ -131,7 +132,38 @@ def create_search_index_sdk(acs_config: dict, credential, embeddings: Optional[E
131132
and embeddings.kind != "none"
132133
and "embedding" in acs_config[MLIndex.INDEX_FIELD_MAPPING_KEY]
133134
):
134-
if current_version >= pkg_version.parse("11.4.0b8"):
135+
if current_version >= pkg_version.parse("11.4.0b11"):
136+
from azure.search.documents.indexes.models import (
137+
HnswParameters,
138+
HnswVectorSearchAlgorithmConfiguration,
139+
VectorSearch,
140+
VectorSearchAlgorithmKind,
141+
VectorSearchProfile,
142+
)
143+
144+
vector_config_name = f"{acs_config[MLIndex.INDEX_FIELD_MAPPING_KEY]['embedding']}_config"
145+
hnsw_name = "azureml_default_hnsw_config"
146+
vector_search_args["vector_search"] = VectorSearch(
147+
algorithms=[
148+
HnswVectorSearchAlgorithmConfiguration(
149+
name=hnsw_name,
150+
kind=VectorSearchAlgorithmKind.HNSW,
151+
parameters=HnswParameters(
152+
m=4,
153+
ef_construction=400,
154+
ef_search=500,
155+
metric="cosine",
156+
),
157+
)
158+
],
159+
profiles=[
160+
VectorSearchProfile(
161+
name=vector_config_name,
162+
algorithm=hnsw_name,
163+
),
164+
],
165+
)
166+
elif current_version >= pkg_version.parse("11.4.0b8"):
135167
from azure.search.documents.indexes.models import HnswVectorSearchAlgorithmConfiguration, VectorSearch
136168

137169
vector_search_args["vector_search"] = VectorSearch(
@@ -463,6 +495,7 @@ def main(args, logger, activity_logger):
463495
elif "endpoint_key_name" in acs_config:
464496
connection_args["connection_type"] = "workspace_keyvault"
465497
from azureml.core import Run
498+
466499
run = Run.get_context()
467500
ws = run.experiment.workspace
468501
connection_args["connection"] = {
@@ -476,7 +509,7 @@ def main(args, logger, activity_logger):
476509
raw_embeddings_uri = args.embeddings
477510
logger.info(f"got embeddings uri as input: {raw_embeddings_uri}")
478511
splits = raw_embeddings_uri.split("/")
479-
embeddings_dir_name = splits.pop(len(splits)-2)
512+
embeddings_dir_name = splits.pop(len(splits) - 2)
480513
logger.info(f"extracted embeddings directory name: {embeddings_dir_name}")
481514
parent = "/".join(splits)
482515
logger.info(f"extracted embeddings container path: {parent}")

0 commit comments

Comments
 (0)