1
- from typing import Any , Optional
1
+ from typing import Any , Dict , Optional
2
2
3
3
import csv
4
4
import glob
7
7
import time
8
8
9
9
from azure .core .credentials_async import AsyncTokenCredential
10
- from azure .search .documents .aio import SearchClient
10
+ from azure .search .documents .aio import AsyncSearchItemPaged , SearchClient
11
11
from azure .search .documents .indexes .aio import SearchIndexClient
12
12
from azure .core .exceptions import HttpResponseError
13
13
from azure .search .documents .indexes .models import (
@@ -50,6 +50,11 @@ class SearchIndexManager:
50
50
MIN_DIFF_CHARACTERS_IN_LINE = 5
51
51
MIN_LINE_LENGTH = 5
52
52
53
+ _SEMANTIC_CONFIG = "semantic_search"
54
+ _EMBEDDING_CONFIG = "embedding_config"
55
+ _VECTORIZER = "search_vectorizer"
56
+
57
+
53
58
def __init__ (
54
59
self ,
55
60
endpoint : str ,
@@ -141,6 +146,32 @@ def _check_dimensions(self, vector_index_dimensions: Optional[int] = None) -> in
141
146
raise ValueError ("vector_index_dimensions is different from dimensions provided to constructor." )
142
147
return vector_index_dimensions
143
148
149
+ async def _format_search_results (self , response : AsyncSearchItemPaged [Dict ]) -> str :
150
+ """
151
+ Format the output of search.
152
+
153
+ :param response: The search results.
154
+ :return: The formatted response string.
155
+ """
156
+ results = [f"{ result ['token' ]} , source: { result ['document_reference' ]} " async for result in response ]
157
+ return "\n ------\n " .join (results )
158
+
159
+ async def semantic_search (self , message : str ) -> str :
160
+ """
161
+ Perform the semantic search on the search resource.
162
+
163
+ :param message: The customer question.
164
+ :return: The context for the question.
165
+ """
166
+ self ._raise_if_no_index ()
167
+ response = await self ._get_client ().search (
168
+ search_text = message ,
169
+ query_type = "semantic" ,
170
+ semantic_configuration_name = SearchIndexManager ._SEMANTIC_CONFIG ,
171
+ )
172
+ return await self ._format_search_results (response )
173
+
174
+
144
175
async def search (self , message : str ) -> str :
145
176
"""
146
177
Search the message in the vector store.
@@ -160,8 +191,7 @@ async def search(self, message: str) -> str:
160
191
)
161
192
# This lag is necessary, despite it is not described in documentation.
162
193
time .sleep (1 )
163
- results = [f"{ result ['token' ]} , source: { result ['document_reference' ]} " async for result in response ]
164
- return "\n ------\n " .join (results )
194
+ return await self ._format_search_results (response )
165
195
166
196
async def create_index (
167
197
self ,
@@ -185,7 +215,7 @@ async def create_index(
185
215
"""
186
216
vector_index_dimensions = self ._check_dimensions (vector_index_dimensions )
187
217
try :
188
- self ._index = await self ._index_create ()
218
+ self ._index = await self ._index_create (vector_index_dimensions )
189
219
return True
190
220
except HttpResponseError :
191
221
if raise_on_error :
@@ -194,33 +224,44 @@ async def create_index(
194
224
self ._index = await ix_client .get_index (self ._index_name )
195
225
return False
196
226
197
- async def _index_create (self ) -> SearchIndex :
198
- """Create the index."""
227
+ async def _index_create (self , vector_index_dimensions : int ) -> SearchIndex :
228
+ """
229
+ Create the index.
230
+
231
+ :param vector_index_dimensions: The number of dimensions in the vector index. This parameter is
232
+ needed if the embedding parameter cannot be set for the given model. It can be
233
+ figured out by loading the embeddings file, generated by build_embeddings_file,
234
+ loading the contents of the first row and 'embedding' column as a JSON and calculating
235
+ the length of the list obtained.
236
+ Also please see the embedding model documentation
237
+ https://platform.openai.com/docs/models#embeddings
238
+ :return: The newly created search index.
239
+ """
199
240
async with SearchIndexClient (endpoint = self ._endpoint , credential = self ._credential ) as ix_client :
200
241
fields = [
201
242
SimpleField (name = "embedId" , type = SearchFieldDataType .String , key = True ),
202
243
SearchField (
203
244
name = "embedding" ,
204
245
type = SearchFieldDataType .Collection (SearchFieldDataType .Single ),
205
- vector_search_dimensions = self . _dimensions ,
246
+ vector_search_dimensions = vector_index_dimensions ,
206
247
searchable = True ,
207
- vector_search_profile_name = "embedding_config"
248
+ vector_search_profile_name = SearchIndexManager . _EMBEDDING_CONFIG
208
249
),
209
250
SearchField (name = "token" , searchable = True , type = SearchFieldDataType .String , hidden = False ),
210
251
SearchField (name = "document_reference" , type = SearchFieldDataType .String , hidden = False ),
211
252
]
212
253
vector_search = VectorSearch (
213
254
profiles = [
214
255
VectorSearchProfile (
215
- name = "embedding_config" ,
256
+ name = SearchIndexManager . _EMBEDDING_CONFIG ,
216
257
algorithm_configuration_name = "embed-algorithms-config" ,
217
- vectorizer_name = "search_vectorizer"
258
+ vectorizer_name = SearchIndexManager . _VECTORIZER
218
259
)
219
260
],
220
261
algorithms = [HnswAlgorithmConfiguration (name = "embed-algorithms-config" )],
221
262
vectorizers = [
222
263
AzureOpenAIVectorizer (
223
- vectorizer_name = "search_vectorizer" ,
264
+ vectorizer_name = SearchIndexManager . _VECTORIZER ,
224
265
parameters = AzureOpenAIVectorizerParameters (
225
266
resource_url = self ._embeddings_endpoint ,
226
267
deployment_name = self ._embedding_deployment ,
@@ -231,10 +272,10 @@ async def _index_create(self) -> SearchIndex:
231
272
]
232
273
)
233
274
semantic_search = SemanticSearch (
234
- default_configuration_name = "index_search" ,
275
+ default_configuration_name = SearchIndexManager . _SEMANTIC_CONFIG ,
235
276
configurations = [
236
277
SemanticConfiguration (
237
- name = "index_search" ,
278
+ name = SearchIndexManager . _SEMANTIC_CONFIG ,
238
279
prioritized_fields = SemanticPrioritizedFields (
239
280
title_field = SemanticField (field_name = "embedId" ),
240
281
content_fields = [SemanticField (field_name = "token" )]
0 commit comments