Skip to content

Commit 1b3d100

Browse files
committed
More embedding related changes
1 parent 3e6d743 commit 1b3d100

File tree

9 files changed

+109
-80
lines changed

9 files changed

+109
-80
lines changed

app/backend/app.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ async def setup_clients():
456456
AZURE_SEARCH_QUERY_SPELLER = os.getenv("AZURE_SEARCH_QUERY_SPELLER") or "lexicon"
457457
AZURE_SEARCH_SEMANTIC_RANKER = os.getenv("AZURE_SEARCH_SEMANTIC_RANKER", "free").lower()
458458
AZURE_SEARCH_QUERY_REWRITING = os.getenv("AZURE_SEARCH_QUERY_REWRITING", "false").lower()
459+
# This defaults to the previous field name "embedding", for backwards compatibility
460+
AZURE_SEARCH_FIELD_NAME_EMBEDDING = os.getenv("AZURE_SEARCH_FIELD_NAME_EMBEDDING", "embedding")
459461

460462
AZURE_SPEECH_SERVICE_ID = os.getenv("AZURE_SPEECH_SERVICE_ID")
461463
AZURE_SPEECH_SERVICE_LOCATION = os.getenv("AZURE_SPEECH_SERVICE_LOCATION")
@@ -662,6 +664,7 @@ async def setup_clients():
662664
embedding_model=OPENAI_EMB_MODEL,
663665
embedding_deployment=AZURE_OPENAI_EMB_DEPLOYMENT,
664666
embedding_dimensions=OPENAI_EMB_DIMENSIONS,
667+
embedding_field=AZURE_SEARCH_FIELD_NAME_EMBEDDING,
665668
sourcepage_field=KB_FIELDS_SOURCEPAGE,
666669
content_field=KB_FIELDS_CONTENT,
667670
query_language=AZURE_SEARCH_QUERY_LANGUAGE,
@@ -679,6 +682,7 @@ async def setup_clients():
679682
embedding_model=OPENAI_EMB_MODEL,
680683
embedding_deployment=AZURE_OPENAI_EMB_DEPLOYMENT,
681684
embedding_dimensions=OPENAI_EMB_DIMENSIONS,
685+
embedding_field=AZURE_SEARCH_FIELD_NAME_EMBEDDING,
682686
sourcepage_field=KB_FIELDS_SOURCEPAGE,
683687
content_field=KB_FIELDS_CONTENT,
684688
query_language=AZURE_SEARCH_QUERY_LANGUAGE,
@@ -704,6 +708,7 @@ async def setup_clients():
704708
embedding_model=OPENAI_EMB_MODEL,
705709
embedding_deployment=AZURE_OPENAI_EMB_DEPLOYMENT,
706710
embedding_dimensions=OPENAI_EMB_DIMENSIONS,
711+
embedding_field=AZURE_SEARCH_FIELD_NAME_EMBEDDING,
707712
sourcepage_field=KB_FIELDS_SOURCEPAGE,
708713
content_field=KB_FIELDS_CONTENT,
709714
query_language=AZURE_SEARCH_QUERY_LANGUAGE,
@@ -725,6 +730,7 @@ async def setup_clients():
725730
embedding_model=OPENAI_EMB_MODEL,
726731
embedding_deployment=AZURE_OPENAI_EMB_DEPLOYMENT,
727732
embedding_dimensions=OPENAI_EMB_DIMENSIONS,
733+
embedding_field=AZURE_SEARCH_FIELD_NAME_EMBEDDING,
728734
sourcepage_field=KB_FIELDS_SOURCEPAGE,
729735
content_field=KB_FIELDS_CONTENT,
730736
query_language=AZURE_SEARCH_QUERY_LANGUAGE,

app/backend/approaches/approach.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def serialize_for_results(self) -> dict[str, Any]:
4747
result_dict = {
4848
"id": self.id,
4949
"content": self.content,
50+
# Should we rename to its actual field name in the index?
51+
"embedding": Document.trim_embedding(self.embedding),
5052
"imageEmbedding": Document.trim_embedding(self.image_embedding),
5153
"category": self.category,
5254
"sourcepage": self.sourcepage,
@@ -68,7 +70,6 @@ def serialize_for_results(self) -> dict[str, Any]:
6870
"score": self.score,
6971
"reranker_score": self.reranker_score,
7072
}
71-
result_dict[self.embedding_field] = Document.trim_embedding(self.embedding)
7273
return result_dict
7374

7475
@classmethod
@@ -258,7 +259,7 @@ class ExtraArgs(TypedDict, total=False):
258259
)
259260
query_vector = embedding.data[0].embedding
260261
# TODO: use optimizations from rag time journey 3
261-
return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding)
262+
return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field)
262263

263264
async def compute_image_embedding(self, q: str):
264265
endpoint = urljoin(self.vision_endpoint, "computervision/retrieval:vectorizeText")

app/backend/approaches/chatreadretrieveread.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
3535
embedding_model: str,
3636
embedding_dimensions: int,
37+
embedding_field: str,
3738
sourcepage_field: str,
3839
content_field: str,
3940
query_language: str,
@@ -48,6 +49,7 @@ def __init__(
4849
self.embedding_deployment = embedding_deployment
4950
self.embedding_model = embedding_model
5051
self.embedding_dimensions = embedding_dimensions
52+
self.embedding_field = embedding_field
5153
self.sourcepage_field = sourcepage_field
5254
self.content_field = content_field
5355
self.query_language = query_language

app/backend/approaches/retrievethenread.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
embedding_model: str,
2929
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
3030
embedding_dimensions: int,
31+
embedding_field: str,
3132
sourcepage_field: str,
3233
content_field: str,
3334
query_language: str,
@@ -43,6 +44,7 @@ def __init__(
4344
self.embedding_dimensions = embedding_dimensions
4445
self.chatgpt_deployment = chatgpt_deployment
4546
self.embedding_deployment = embedding_deployment
47+
self.embedding_field = embedding_field
4648
self.sourcepage_field = sourcepage_field
4749
self.content_field = content_field
4850
self.query_language = query_language

app/backend/prepdocslib/searchmanager.py

Lines changed: 87 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from azure.search.documents.indexes.models import (
77
AzureOpenAIVectorizer,
88
AzureOpenAIVectorizerParameters,
9+
BinaryQuantizationCompression,
910
HnswAlgorithmConfiguration,
1011
HnswParameters,
12+
RescoringOptions,
1113
SearchableField,
1214
SearchField,
1315
SearchFieldDataType,
@@ -18,8 +20,8 @@
1820
SemanticSearch,
1921
SimpleField,
2022
VectorSearch,
23+
VectorSearchCompressionRescoreStorageMethod,
2124
VectorSearchProfile,
22-
VectorSearchVectorizer,
2325
)
2426

2527
from .blobmanager import BlobManager
@@ -69,11 +71,44 @@ def __init__(
6971
self.embedding_field = embedding_field
7072
self.search_images = search_images
7173

72-
async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] = None):
74+
async def create_index(self):
7375
logger.info("Checking whether search index %s exists...", self.search_info.index_name)
7476

7577
async with self.search_info.create_search_index_client() as search_index_client:
7678

79+
vectorizer = None
80+
embedding_field = None
81+
if self.embeddings and isinstance(self.embeddings, AzureOpenAIEmbeddingService):
82+
vectorizer = AzureOpenAIVectorizer(
83+
vectorizer_name=f"{self.search_info.index_name}-vectorizer",
84+
parameters=AzureOpenAIVectorizerParameters(
85+
resource_url=self.embeddings.open_ai_endpoint,
86+
deployment_name=self.embeddings.open_ai_deployment,
87+
model_name=self.embeddings.open_ai_model_name,
88+
),
89+
)
90+
if self.embeddings:
91+
if self.embedding_dimensions is None:
92+
raise ValueError(
93+
"Embedding dimensions must be set in order to add an embedding field to the search index"
94+
)
95+
if self.embedding_field is None:
96+
raise ValueError(
97+
"Embedding field must be set in order to add an embedding field to the search index"
98+
)
99+
embedding_field = SearchField(
100+
name=self.embedding_field,
101+
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
102+
hidden=True,
103+
searchable=True,
104+
filterable=False,
105+
sortable=False,
106+
facetable=False,
107+
vector_search_dimensions=self.embedding_dimensions,
108+
vector_search_profile_name="embedding_config",
109+
stored=False,
110+
)
111+
77112
if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]:
78113
logger.info("Creating new search index %s", self.search_info.index_name)
79114
fields = [
@@ -95,17 +130,6 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
95130
type="Edm.String",
96131
analyzer_name=self.search_analyzer_name,
97132
),
98-
SearchField(
99-
name=self.embedding_field,
100-
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
101-
hidden=False,
102-
searchable=True,
103-
filterable=False,
104-
sortable=False,
105-
facetable=False,
106-
vector_search_dimensions=self.embedding_dimensions,
107-
vector_search_profile_name="embedding_config",
108-
),
109133
SimpleField(name="category", type="Edm.String", filterable=True, facetable=True),
110134
SimpleField(
111135
name="sourcepage",
@@ -160,27 +184,50 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
160184
),
161185
)
162186

163-
vectorizers = []
164-
if self.embeddings and isinstance(self.embeddings, AzureOpenAIEmbeddingService):
165-
logger.info(
166-
"Including vectorizer for search index %s, using Azure OpenAI service %s",
167-
self.search_info.index_name,
168-
self.embeddings.open_ai_service,
169-
)
170-
vectorizers.append(
171-
AzureOpenAIVectorizer(
172-
vectorizer_name=f"{self.search_info.index_name}-vectorizer",
173-
parameters=AzureOpenAIVectorizerParameters(
174-
resource_url=self.embeddings.open_ai_endpoint,
175-
deployment_name=self.embeddings.open_ai_deployment,
176-
model_name=self.embeddings.open_ai_model_name,
177-
),
187+
vector_search = None
188+
if self.embeddings:
189+
logger.info("Including embedding field in new index %s", self.search_info.index_name)
190+
fields.append(embedding_field)
191+
192+
vectorizers = []
193+
if vectorizer is not None:
194+
logger.info("Including vectorizer in new index %s", self.search_info.index_name)
195+
vectorizers.append(vectorizer)
196+
else:
197+
logger.info(
198+
"New index %s will not have vectorizer, since no Azure OpenAI service is set",
199+
self.search_info.index_name,
178200
)
179-
)
180-
else:
181-
logger.info(
182-
"Not including vectorizer for search index %s, no Azure OpenAI service found",
183-
self.search_info.index_name,
201+
202+
vector_search = VectorSearch(
203+
profiles=[
204+
VectorSearchProfile(
205+
name="embedding_config",
206+
algorithm_configuration_name="hnsw_config",
207+
compression_name="binary-quantization",
208+
**({"vectorizer_name": vectorizer.vectorizer_name if vectorizer else None}),
209+
),
210+
],
211+
algorithms=[
212+
HnswAlgorithmConfiguration(
213+
name="hnsw_config",
214+
parameters=HnswParameters(metric="cosine"),
215+
)
216+
],
217+
vectorizers=vectorizers,
218+
compressions=[
219+
BinaryQuantizationCompression(
220+
compression_name="binary-quantization",
221+
rescoring_options=RescoringOptions(
222+
enable_rescoring=True,
223+
default_oversampling=10,
224+
rescore_storage_method=VectorSearchCompressionRescoreStorageMethod.PRESERVE_ORIGINALS,
225+
),
226+
# Explicitly set deprecated parameters to None
227+
rerank_with_original_vectors=None,
228+
default_oversampling=None,
229+
)
230+
],
184231
)
185232

186233
index = SearchIndex(
@@ -196,22 +243,7 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
196243
)
197244
]
198245
),
199-
vector_search=VectorSearch(
200-
algorithms=[
201-
HnswAlgorithmConfiguration(
202-
name="hnsw_config",
203-
parameters=HnswParameters(metric="cosine"),
204-
)
205-
],
206-
profiles=[
207-
VectorSearchProfile(
208-
name="embedding_config",
209-
algorithm_configuration_name="hnsw_config",
210-
vectorizer_name=(f"{self.search_info.index_name}-vectorizer"),
211-
),
212-
],
213-
vectorizers=vectorizers,
214-
),
246+
vector_search=vector_search,
215247
)
216248

217249
await search_index_client.create_index(index)
@@ -229,45 +261,23 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
229261
),
230262
)
231263
await search_index_client.create_or_update_index(existing_index)
232-
# check if embedding field exists
233-
if not any(field.name == self.embedding_field for field in existing_index.fields):
264+
# check if embedding field exists - TODO: will this really work if we havent redfined vector search?
265+
if self.embeddings and not any(field.name == self.embedding_field for field in existing_index.fields):
234266
logger.info("Adding embedding field to index %s", self.search_info.index_name)
235-
existing_index.fields.append(
236-
SearchField(
237-
name=self.embedding_field,
238-
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
239-
hidden=False,
240-
searchable=True,
241-
filterable=False,
242-
sortable=False,
243-
facetable=False,
244-
# TODO: use optimizations here
245-
vector_search_dimensions=self.embedding_dimensions,
246-
vector_search_profile_name="embedding_config",
247-
),
248-
)
267+
existing_index.fields.append(embedding_field)
249268
await search_index_client.create_or_update_index(existing_index)
250269
if existing_index.vector_search is not None and (
251270
existing_index.vector_search.vectorizers is None
252271
or len(existing_index.vector_search.vectorizers) == 0
253272
):
254273
if self.embeddings is not None and isinstance(self.embeddings, AzureOpenAIEmbeddingService):
255274
logger.info("Adding vectorizer to search index %s", self.search_info.index_name)
256-
existing_index.vector_search.vectorizers = [
257-
AzureOpenAIVectorizer(
258-
vectorizer_name=f"{self.search_info.index_name}-vectorizer",
259-
parameters=AzureOpenAIVectorizerParameters(
260-
resource_url=self.embeddings.open_ai_endpoint,
261-
deployment_name=self.embeddings.open_ai_deployment,
262-
model_name=self.embeddings.open_ai_model_name,
263-
),
264-
)
265-
]
275+
existing_index.vector_search.vectorizers = [vectorizer]
266276
await search_index_client.create_or_update_index(existing_index)
267277
else:
268278
logger.info(
269-
"Can't add vectorizer to search index %s since no Azure OpenAI embeddings service is defined",
270-
self.search_info,
279+
"Search index %s will not have vectorizer, since no Azure OpenAI service is set",
280+
self.search_info.index_name,
271281
)
272282

273283
async def update_content(

app/backend/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ azure-monitor-opentelemetry==1.6.1
5757
# via -r requirements.in
5858
azure-monitor-opentelemetry-exporter==1.0.0b32
5959
# via azure-monitor-opentelemetry
60-
azure-search-documents==11.6.0b9
60+
azure-search-documents==11.6.0b11
6161
# via -r requirements.in
6262
azure-storage-blob==12.22.0
6363
# via

infra/main.bicep

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ param searchIndexName string // Set in main.parameters.json
2727
param searchQueryLanguage string // Set in main.parameters.json
2828
param searchQuerySpeller string // Set in main.parameters.json
2929
param searchServiceSemanticRankerLevel string // Set in main.parameters.json
30+
param searchFieldNameEmbedding string // Set in main.parameters.json
3031
var actualSearchServiceSemanticRankerLevel = (searchServiceSkuName == 'free')
3132
? 'disabled'
3233
: searchServiceSemanticRankerLevel
@@ -390,6 +391,7 @@ var appEnvVariables = {
390391
AZURE_VISION_ENDPOINT: useGPT4V ? computerVision.outputs.endpoint : ''
391392
AZURE_SEARCH_QUERY_LANGUAGE: searchQueryLanguage
392393
AZURE_SEARCH_QUERY_SPELLER: searchQuerySpeller
394+
AZURE_SEARCH_FIELD_NAME_EMBEDDING: searchFieldNameEmbedding
393395
APPLICATIONINSIGHTS_CONNECTION_STRING: useApplicationInsights
394396
? monitoring.outputs.applicationInsightsConnectionString
395397
: ''
@@ -1284,6 +1286,7 @@ output AZURE_SEARCH_SERVICE string = searchService.outputs.name
12841286
output AZURE_SEARCH_SERVICE_RESOURCE_GROUP string = searchServiceResourceGroup.name
12851287
output AZURE_SEARCH_SEMANTIC_RANKER string = actualSearchServiceSemanticRankerLevel
12861288
output AZURE_SEARCH_SERVICE_ASSIGNED_USERID string = searchService.outputs.principalId
1289+
output AZURE_SEARCH_FIELD_NAME_EMBEDDING string = searchFieldNameEmbedding
12871290

12881291
output AZURE_COSMOSDB_ACCOUNT string = (useAuthentication && useChatHistoryCosmos) ? cosmosDb.outputs.name : ''
12891292
output AZURE_CHAT_HISTORY_DATABASE string = chatHistoryDatabaseName

infra/main.parameters.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@
8383
"searchServiceQueryRewriting": {
8484
"value": "${AZURE_SEARCH_QUERY_REWRITING=false}"
8585
},
86+
"searchFieldNameEmbedding": {
87+
"value": "${AZURE_SEARCH_FIELD_NAME_EMBEDDING=embedding3}"
88+
},
8689
"storageAccountName": {
8790
"value": "${AZURE_STORAGE_ACCOUNT}"
8891
},

tests/e2e.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def run_server(port: int):
5757
"AZURE_SPEECH_SERVICE_LOCATION": "eastus",
5858
"AZURE_OPENAI_SERVICE": "test-openai-service",
5959
"AZURE_OPENAI_CHATGPT_MODEL": "gpt-4o-mini",
60+
"AZURE_OPENAI_EMB_MODEL_NAME": "text-embedding-3-large",
61+
"AZURE_OPENAI_EMB_DIMENSIONS": "3072",
6062
},
6163
clear=True,
6264
):

0 commit comments

Comments
 (0)