Skip to content

Commit 0284039

Browse files
committed
Fix
1 parent df6240d commit 0284039

File tree

2 files changed

+92
-30
lines changed

2 files changed

+92
-30
lines changed

src/api/search_index_manager.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any, Dict, Optional
22

33
import csv
44
import glob
@@ -7,7 +7,7 @@
77
import time
88

99
from azure.core.credentials_async import AsyncTokenCredential
10-
from azure.search.documents.aio import SearchClient
10+
from azure.search.documents.aio import AsyncSearchItemPaged, SearchClient
1111
from azure.search.documents.indexes.aio import SearchIndexClient
1212
from azure.core.exceptions import HttpResponseError
1313
from azure.search.documents.indexes.models import (
@@ -50,6 +50,11 @@ class SearchIndexManager:
5050
MIN_DIFF_CHARACTERS_IN_LINE = 5
5151
MIN_LINE_LENGTH = 5
5252

53+
_SEMANTIC_CONFIG = "semantic_search"
54+
_EMBEDDING_CONFIG = "embedding_config"
55+
_VECTORIZER = "search_vectorizer"
56+
57+
5358
def __init__(
5459
self,
5560
endpoint: str,
@@ -141,6 +146,32 @@ def _check_dimensions(self, vector_index_dimensions: Optional[int] = None) -> in
141146
raise ValueError("vector_index_dimensions is different from dimensions provided to constructor.")
142147
return vector_index_dimensions
143148

149+
async def _format_search_results(self, response: AsyncSearchItemPaged[Dict]) -> str:
150+
"""
151+
Format the output of search.
152+
153+
:param response: The search results.
154+
:return: The formatted response string.
155+
"""
156+
results = [f"{result['token']}, source: {result['document_reference']}" async for result in response]
157+
return "\n------\n".join(results)
158+
159+
async def semantic_search(self, message: str) -> str:
160+
"""
161+
Perform the semantic search on the search resource.
162+
163+
:param message: The customer question.
164+
:return: The context for the question.
165+
"""
166+
self._raise_if_no_index()
167+
response = await self._get_client().search(
168+
search_text=message,
169+
query_type="semantic",
170+
semantic_configuration_name=SearchIndexManager._SEMANTIC_CONFIG,
171+
)
172+
return await self._format_search_results(response)
173+
174+
144175
async def search(self, message: str) -> str:
145176
"""
146177
Search the message in the vector store.
@@ -160,8 +191,7 @@ async def search(self, message: str) -> str:
160191
)
161192
# This lag is necessary, even though it is not described in the documentation.
162193
time.sleep(1)
163-
results = [f"{result['token']}, source: {result['document_reference']}" async for result in response]
164-
return "\n------\n".join(results)
194+
return await self._format_search_results(response)
165195

166196
async def create_index(
167197
self,
@@ -185,7 +215,7 @@ async def create_index(
185215
"""
186216
vector_index_dimensions = self._check_dimensions(vector_index_dimensions)
187217
try:
188-
self._index = await self._index_create()
218+
self._index = await self._index_create(vector_index_dimensions)
189219
return True
190220
except HttpResponseError:
191221
if raise_on_error:
@@ -194,33 +224,44 @@ async def create_index(
194224
self._index = await ix_client.get_index(self._index_name)
195225
return False
196226

197-
async def _index_create(self) -> SearchIndex:
198-
"""Create the index."""
227+
async def _index_create(self, vector_index_dimensions: int) -> SearchIndex:
228+
"""
229+
Create the index.
230+
231+
:param vector_index_dimensions: The number of dimensions in the vector index. This parameter is
232+
needed if the embedding parameter cannot be set for the given model. It can be
233+
figured out by loading the embeddings file, generated by build_embeddings_file,
234+
loading the contents of the first row and 'embedding' column as a JSON and calculating
235+
the length of the list obtained.
236+
Also please see the embedding model documentation
237+
https://platform.openai.com/docs/models#embeddings
238+
:return: The newly created search index.
239+
"""
199240
async with SearchIndexClient(endpoint=self._endpoint, credential=self._credential) as ix_client:
200241
fields = [
201242
SimpleField(name="embedId", type=SearchFieldDataType.String, key=True),
202243
SearchField(
203244
name="embedding",
204245
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
205-
vector_search_dimensions=self._dimensions,
246+
vector_search_dimensions=vector_index_dimensions,
206247
searchable=True,
207-
vector_search_profile_name="embedding_config"
248+
vector_search_profile_name=SearchIndexManager._EMBEDDING_CONFIG
208249
),
209250
SearchField(name="token", searchable=True, type=SearchFieldDataType.String, hidden=False),
210251
SearchField(name="document_reference", type=SearchFieldDataType.String, hidden=False),
211252
]
212253
vector_search = VectorSearch(
213254
profiles=[
214255
VectorSearchProfile(
215-
name="embedding_config",
256+
name=SearchIndexManager._EMBEDDING_CONFIG,
216257
algorithm_configuration_name="embed-algorithms-config",
217-
vectorizer_name="search_vectorizer"
258+
vectorizer_name=SearchIndexManager._VECTORIZER
218259
)
219260
],
220261
algorithms=[HnswAlgorithmConfiguration(name="embed-algorithms-config")],
221262
vectorizers=[
222263
AzureOpenAIVectorizer(
223-
vectorizer_name="search_vectorizer",
264+
vectorizer_name=SearchIndexManager._VECTORIZER,
224265
parameters=AzureOpenAIVectorizerParameters(
225266
resource_url=self._embeddings_endpoint,
226267
deployment_name=self._embedding_deployment,
@@ -231,10 +272,10 @@ async def _index_create(self) -> SearchIndex:
231272
]
232273
)
233274
semantic_search = SemanticSearch(
234-
default_configuration_name="index_search",
275+
default_configuration_name=SearchIndexManager._SEMANTIC_CONFIG,
235276
configurations=[
236277
SemanticConfiguration(
237-
name="index_search",
278+
name=SearchIndexManager._SEMANTIC_CONFIG,
238279
prioritized_fields=SemanticPrioritizedFields(
239280
title_field=SemanticField(field_name="embedId"),
240281
content_fields=[SemanticField(field_name="token")]

src/gunicorn.conf.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Dict
55

66
import asyncio
7+
import csv
78
import json
89
import logging
910
import multiprocessing
@@ -101,6 +102,18 @@ async def create_index_maybe(
101102
await search_mgr.close()
102103

103104

105+
def _get_file_path(file_name: str) -> str:
106+
"""
107+
Get absolute file path.
108+
109+
:param file_name: The file name.
110+
"""
111+
return os.path.abspath(
112+
os.path.join(os.path.dirname(__file__),
113+
'files',
114+
file_name))
115+
116+
104117
async def get_available_toolset(
105118
ai_client: AIProjectClient,
106119
creds: AsyncTokenCredential) -> AsyncToolSet:
@@ -111,10 +124,12 @@ async def get_available_toolset(
111124
:param creds: The credentials, used for the index.
112125
:return: The tool set, available based on the environment.
113126
"""
127+
# File name -> {"id": file_id, "path": file_path}
128+
files: Dict[str, Dict[str, str]] = {}
114129
# First try to get an index search.
115130
conn_id = ""
116131
if os.environ.get('AZURE_AI_SEARCH_INDEX_NAME'):
117-
conn_list = ai_client.connections.list()
132+
conn_list = await ai_client.connections.list()
118133
for conn in conn_list:
119134
if conn.connection_type == ConnectionType.AZURE_AI_SEARCH:
120135
conn_id = conn.id
@@ -130,26 +145,28 @@ async def get_available_toolset(
130145

131146
toolset.add(ai_search)
132147
logger.info("agent: initialized index")
148+
# Populate file links.
149+
embeddings_path = os.path.join(
150+
os.path.dirname(__file__), 'data', 'embeddings.csv')
151+
with open(embeddings_path, newline='') as fp:
152+
reader = csv.DictReader(fp)
153+
for row in reader:
154+
if row['document_reference'] in FILES_NAMES:
155+
files[row['document_reference']] = {
156+
"id": row['document_reference'],
157+
"path": _get_file_path(row['document_reference'])
158+
}
133159
else:
134160
logger.info(
135161
"agent: index was not initialized, falling back to file search.")
136162
# Upload files for file search
137-
# File name -> {"id": file_id, "path": file_path}
138-
files: Dict[str, Dict[str, str]] = {}
139163
for file_name in FILES_NAMES:
140-
file_path = os.path.abspath(
141-
os.path.join(
142-
os.path.dirname(__file__),
143-
'files',
144-
file_name))
164+
file_path = _get_file_path(file_name)
145165
file = await ai_client.agents.upload_file_and_poll(
146166
file_path=file_path, purpose=FilePurpose.AGENTS)
147167
# Store both file id and the file path using the file name as key.
148168
files[file_name] = {"id": file.id, "path": file_path}
149169

150-
# Serialize and store files information in the environment variable (so
151-
# workers see it)
152-
os.environ["UPLOADED_FILE_MAP"] = json.dumps(files)
153170
logger.info(
154171
f"Set env UPLOADED_FILE_MAP = {os.environ['UPLOADED_FILE_MAP']}")
155172

@@ -162,6 +179,9 @@ async def get_available_toolset(
162179

163180
file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
164181
toolset.add(file_search_tool)
182+
# Serialize and store files information in the environment variable (so
183+
# workers see it)
184+
os.environ["UPLOADED_FILE_MAP"] = json.dumps(files)
165185
return toolset
166186

167187

@@ -173,7 +193,7 @@ async def create_agent(ai_client: AIProjectClient,
173193
model=os.environ["AZURE_AI_AGENT_DEPLOYMENT_NAME"],
174194
name=os.environ["AZURE_AI_AGENT_NAME"],
175195
instructions="You are helpful assistant",
176-
toolset=await get_available_toolset()
196+
toolset=await get_available_toolset(ai_client, creds)
177197
)
178198
return agent
179199

@@ -209,7 +229,7 @@ async def initialize_resources():
209229
os.environ["AZURE_AI_AGENT_ID"])
210230
logger.info(f"Found agent by ID: {agent.id}")
211231
# Update the agent with the latest resources
212-
agent = await update_agent(agent, ai_client)
232+
agent = await update_agent(agent, ai_client, creds)
213233
return
214234
except Exception as e:
215235
logger.warning(
@@ -221,18 +241,19 @@ async def initialize_resources():
221241
if agent_list.data:
222242
for agent_object in agent_list.data:
223243
if agent_object.name == os.environ[
224-
"AZURE_AI_AGENT_NAME"]:
244+
"AZURE_AI_AGENT_NAME"]:
225245
logger.info(
226246
"Found existing agent named "
227247
f"'{agent_object.name}'"
228248
f", ID: {agent_object.id}")
229249
os.environ["AZURE_AI_AGENT_ID"] = agent_object.id
230250
# Update the agent with the latest resources
231-
agent = await update_agent(agent_object, ai_client)
251+
agent = await update_agent(
252+
agent_object, ai_client, creds)
232253
return
233254

234255
# Create a new agent
235-
agent = await create_agent(ai_client)
256+
agent = await create_agent(ai_client, creds)
236257
os.environ["AZURE_AI_AGENT_ID"] = agent.id
237258
logger.info(f"Created agent, agent ID: {agent.id}")
238259

0 commit comments

Comments
 (0)