11
11
from typing import List , Optional
12
12
13
13
import yaml
14
+ from packaging import version as pkg_version
15
+
16
+ from azure .ai .generative .index ._embeddings import EmbeddingsContainer , ReferenceEmbeddedDocument
17
+ from azure .ai .generative .index ._mlindex import MLIndex
18
+ from azure .ai .generative .index ._utils .connections import get_connection_credential
19
+ from azure .ai .generative .index ._utils .logging import (
20
+ _logger_factory ,
21
+ enable_appinsights_logging ,
22
+ enable_stdout_logging ,
23
+ get_logger ,
24
+ packages_versions_for_compatibility ,
25
+ safe_mlflow_start_run ,
26
+ track_activity ,
27
+ version ,
28
+ )
14
29
from azure .core .credentials import TokenCredential
15
30
from azure .identity import DefaultAzureCredential
16
31
from azure .search .documents import SearchClient
26
41
SemanticSettings ,
27
42
SimpleField ,
28
43
)
29
- from azure .ai .generative .index ._embeddings import EmbeddingsContainer , ReferenceEmbeddedDocument
30
- from azure .ai .generative .index ._mlindex import MLIndex
31
- from azure .ai .generative .index ._utils .connections import get_connection_credential
32
- from azure .ai .generative .index ._utils .logging import (
33
- _logger_factory ,
34
- enable_appinsights_logging ,
35
- enable_stdout_logging ,
36
- get_logger ,
37
- packages_versions_for_compatibility ,
38
- safe_mlflow_start_run ,
39
- track_activity ,
40
- version ,
41
- )
42
- from packaging import version as pkg_version
43
44
44
45
logger = get_logger ("update_acs" )
45
46
@@ -131,7 +132,38 @@ def create_search_index_sdk(acs_config: dict, credential, embeddings: Optional[E
131
132
and embeddings .kind != "none"
132
133
and "embedding" in acs_config [MLIndex .INDEX_FIELD_MAPPING_KEY ]
133
134
):
134
- if current_version >= pkg_version .parse ("11.4.0b8" ):
135
+ if current_version >= pkg_version .parse ("11.4.0b11" ):
136
+ from azure .search .documents .indexes .models import (
137
+ HnswParameters ,
138
+ HnswVectorSearchAlgorithmConfiguration ,
139
+ VectorSearch ,
140
+ VectorSearchAlgorithmKind ,
141
+ VectorSearchProfile ,
142
+ )
143
+
144
+ vector_config_name = f"{ acs_config [MLIndex .INDEX_FIELD_MAPPING_KEY ]['embedding' ]} _config"
145
+ hnsw_name = "azureml_default_hnsw_config"
146
+ vector_search_args ["vector_search" ] = VectorSearch (
147
+ algorithms = [
148
+ HnswVectorSearchAlgorithmConfiguration (
149
+ name = hnsw_name ,
150
+ kind = VectorSearchAlgorithmKind .HNSW ,
151
+ parameters = HnswParameters (
152
+ m = 4 ,
153
+ ef_construction = 400 ,
154
+ ef_search = 500 ,
155
+ metric = "cosine" ,
156
+ ),
157
+ )
158
+ ],
159
+ profiles = [
160
+ VectorSearchProfile (
161
+ name = vector_config_name ,
162
+ algorithm = hnsw_name ,
163
+ ),
164
+ ],
165
+ )
166
+ elif current_version >= pkg_version .parse ("11.4.0b8" ):
135
167
from azure .search .documents .indexes .models import HnswVectorSearchAlgorithmConfiguration , VectorSearch
136
168
137
169
vector_search_args ["vector_search" ] = VectorSearch (
@@ -463,6 +495,7 @@ def main(args, logger, activity_logger):
463
495
elif "endpoint_key_name" in acs_config :
464
496
connection_args ["connection_type" ] = "workspace_keyvault"
465
497
from azureml .core import Run
498
+
466
499
run = Run .get_context ()
467
500
ws = run .experiment .workspace
468
501
connection_args ["connection" ] = {
@@ -476,7 +509,7 @@ def main(args, logger, activity_logger):
476
509
raw_embeddings_uri = args .embeddings
477
510
logger .info (f"got embeddings uri as input: { raw_embeddings_uri } " )
478
511
splits = raw_embeddings_uri .split ("/" )
479
- embeddings_dir_name = splits .pop (len (splits )- 2 )
512
+ embeddings_dir_name = splits .pop (len (splits ) - 2 )
480
513
logger .info (f"extracted embeddings directory name: { embeddings_dir_name } " )
481
514
parent = "/" .join (splits )
482
515
logger .info (f"extracted embeddings container path: { parent } " )
0 commit comments