
Commit a18bf70

Fix manager and build new embeddings
1 parent 7d0728e commit a18bf70

2 files changed: 987 additions & 22 deletions

src/api/search_index_manager.py

Lines changed: 33 additions & 22 deletions
@@ -1,29 +1,30 @@
-from typing import Optional
+from typing import Any, Optional
 
 import glob
 import csv
 import json
+import os
 
 from azure.core.credentials_async import AsyncTokenCredential
 from azure.search.documents.aio import SearchClient
 from azure.search.documents.indexes.aio import SearchIndexClient
-from azure.search.documents.models import VectorizedQuery
+from azure.core.exceptions import HttpResponseError
 from azure.search.documents.indexes.models import (
+    AzureOpenAIVectorizer,
+    AzureOpenAIVectorizerParameters,
+    HnswAlgorithmConfiguration,
     SearchField,
-    SearchFieldDataType,
-    SimpleField,
+    SearchFieldDataType,
     SearchIndex,
-    VectorSearch,
-    VectorSearchProfile,
-    HnswAlgorithmConfiguration)
-from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
-from azure.search.documents.indexes.models import (
+    SearchIndexerDataUserAssignedIdentity,
     SemanticSearch,
     SemanticConfiguration,
     SemanticPrioritizedFields,
     SemanticField,
-    AzureOpenAIVectorizer,
-    AzureOpenAIVectorizerParameters
+    SimpleField,
+    #VectorizedQuery,
+    VectorSearch,
+    VectorSearchProfile,
 )

@@ -40,7 +41,9 @@ class SearchIndexManager:
         must be the same as one use to build the file with embeddings.
     :param deployment_name: The name of the embedding deployment.
     :param embeddings_endpoint: The the endpoint used for embedding.
-    :param auth_identity: the managed identity used to access the embedding deployment.
+    :param auth_identity: The managed identity used to access the embedding deployment.
+    :param embedding_client: The embedding client, used to build the embeddings. Needed only
+        to create the embeddings file; not used at inference time.
     """
 
     MIN_DIFF_CHARACTERS_IN_LINE = 5
@@ -55,7 +58,8 @@ def __init__(
         model: str,
         deployment_name: str,
         embedding_endpoint: str,
-        auth_identity: str
+        auth_identity: str,
+        embedding_client: Optional[Any] = None
     ) -> None:
         """Constructor."""
         self._dimensions = dimensions
@@ -68,6 +72,7 @@ def __init__(
         self._embedding_deployment = deployment_name
         self._auth_identity = auth_identity
         self._client = None
+        self._embedding_client = embedding_client
 
     def _get_client(self):
         """Get search client if it is absent."""
@@ -184,7 +189,9 @@ async def _index_create(self) -> SearchIndex:
                 parameters=AzureOpenAIVectorizerParameters(
                     resource_url=self._embeddings_endpoint,
                     deployment_name=self._embedding_deployment,
-                    auth_identity=self._auth_identity,
+                    auth_identity=SearchIndexerDataUserAssignedIdentity(
+                        resource_id=self._auth_identity
+                    ),
                     model_name=self._embedding_model
                 )
             )
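The vectorizer parameters expect a typed `SearchIndexerDataIdentity` for `auth_identity` rather than a raw resource-ID string, hence the new wrapper. Assembled from the hunk above as a standalone sketch; the vectorizer name and all values are placeholders:

    vectorizer = AzureOpenAIVectorizer(
        vectorizer_name="embedding_vectorizer",  # assumed; the real name is outside this hunk
        parameters=AzureOpenAIVectorizerParameters(
            resource_url="https://my-openai.openai.azure.com",
            deployment_name="text-embedding-3-small",
            auth_identity=SearchIndexerDataUserAssignedIdentity(
                resource_id="/subscriptions/.../my-identity"
            ),
            model_name="text-embedding-3-small",
        ),
    )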
@@ -194,9 +201,9 @@ async def _index_create(self) -> SearchIndex:
                 default_configuration_name="index_search",
                 configurations=[
                     SemanticConfiguration(
-                        name="search_contents",
+                        name="index_search",
                         prioritized_fields=SemanticPrioritizedFields(
-                            title_field="embedId",
+                            title_field=SemanticField(field_name="embedId"),
                             content_fields=[SemanticField(field_name="token")]
                         )
                     )
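Two fixes land in this hunk: the configuration's `name` now matches `default_configuration_name`, and `title_field` is a `SemanticField` object rather than a bare field name. Put together for readability (the variable name is illustrative):

    semantic_search = SemanticSearch(
        default_configuration_name="index_search",
        configurations=[
            SemanticConfiguration(
                name="index_search",  # must match default_configuration_name
                prioritized_fields=SemanticPrioritizedFields(
                    title_field=SemanticField(field_name="embedId"),
                    content_fields=[SemanticField(field_name="token")],
                ),
            )
        ],
    )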
@@ -215,7 +222,7 @@ async def build_embeddings_file(
         self,
         input_directory: str,
         output_file: str,
-        sentences_per_embedding: int=4
+        sentences_per_embedding: int=4,
     ) -> None:
         """
         In this method we do lazy loading of nltk and download the needed data set to split
@@ -230,14 +237,14 @@ async def build_embeddings_file(
         :param embeddings_client: The embedding client, used to create embeddings.
             Must be the same as the one used for SearchIndexManager creation.
         :param sentences_per_embedding: The number of sentences used to build embedding.
-        :param model: The embedding model to be used.
         """
         import nltk
         nltk.download('punkt')
 
         from nltk.tokenize import sent_tokenize
         # Split the data to sentence tokens.
         sentence_tokens = []
+        references = []
         globs = glob.glob(input_directory + '/*.md', recursive=True)
         index = 0
         for fle in globs:
@@ -250,6 +257,7 @@ async def build_embeddings_file(
                 for sentence in sent_tokenize(line):
                     if index % sentences_per_embedding == 0:
                         sentence_tokens.append(sentence)
+                        references.append(os.path.split(fle)[-1])
                     else:
                         sentence_tokens[-1] += ' '
                         sentence_tokens[-1] += sentence
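The new `references` list grows in lockstep with `sentence_tokens`: one entry per chunk, holding the base name of the markdown file the chunk came from, so each embedding can later cite its source document. A small standalone illustration of the grouping, assuming two sentences per chunk and a hypothetical notes.md:

    import os

    sentences = ["First.", "Second.", "Third.", "Fourth."]
    per_chunk = 2
    chunks, refs = [], []
    for i, s in enumerate(sentences):
        if i % per_chunk == 0:
            chunks.append(s)
            refs.append(os.path.split("/docs/notes.md")[-1])  # -> "notes.md"
        else:
            chunks[-1] += ' ' + s
    print(chunks)  # ['First. Second.', 'Third. Fourth.']
    print(refs)    # ['notes.md', 'notes.md']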
@@ -259,16 +267,19 @@ async def build_embeddings_file(
         # For each token build the embedding, which will be used in the search.
         batch_size = 2000
         with open(output_file, 'w') as fp:
-            writer = csv.DictWriter(fp, fieldnames=['token', 'embedding'])
+            writer = csv.DictWriter(fp, fieldnames=['token', 'embedding', 'document_reference'])
             writer.writeheader()
             for i in range(0, len(sentence_tokens), batch_size):
-                emedding = (await self._embeddings_client.embed(
+                emedding = (await self._embedding_client.embed(
                     input=sentence_tokens[i:i+min(batch_size, len(sentence_tokens))],
                     dimensions=self._dimensions,
                     model=self._embedding_model
                 ))["data"]
-                for token, float_data in zip(sentence_tokens, emedding):
-                    writer.writerow({'token': token, 'embedding': json.dumps(float_data['embedding'])})
+                for token, float_data, reference in zip(sentence_tokens, emedding, references):
+                    writer.writerow({
+                        'token': token,
+                        'embedding': json.dumps(float_data['embedding']),
+                        'document_reference': reference})
 
     async def close(self):
         """Close the closeable resources, associated with SearchIndexManager."""

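For reference, the file produced by `build_embeddings_file` is a CSV with `token`, `embedding`, and `document_reference` columns, the vector being JSON-encoded. A minimal reader sketch (the file name is a placeholder):

    import csv
    import json

    with open('embeddings.csv') as fp:
        for row in csv.DictReader(fp):
            vector = json.loads(row['embedding'])  # list of floats, length == dimensions
            print(row['document_reference'], row['token'][:40], len(vector))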