Skip to content

Commit 041fa38

Browse files
committed
Add common embedding method
1 parent 5f38cfe commit 041fa38

File tree

2 files changed

+35
-36
lines changed

2 files changed

+35
-36
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from azure.identity import DefaultAzureCredential
4-
from openai import AsyncAzureOpenAI
54
from azure.core.credentials import AzureKeyCredential
6-
from azure.search.documents.models import VectorizedQuery, QueryType
5+
from azure.search.documents.models import QueryType, VectorizableTextQuery
76
from azure.search.documents.aio import SearchClient
87
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
98
import os
@@ -12,9 +11,13 @@
1211
from datetime import datetime, timezone
1312
import json
1413
from typing import Annotated
14+
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1515

1616

1717
class AISearchConnector:
18+
def __init__(self):
19+
self.open_ai_connector = OpenAIConnector()
20+
1821
async def run_ai_search_query(
1922
self,
2023
query,
@@ -30,35 +33,18 @@ async def run_ai_search_query(
3033
identity_type = get_identity_type()
3134

3235
if len(vector_fields) > 0:
33-
async with AsyncAzureOpenAI(
34-
# This is the default and can be omitted
35-
api_key=os.environ["OpenAI__ApiKey"],
36-
azure_endpoint=os.environ["OpenAI__Endpoint"],
37-
api_version=os.environ["OpenAI__ApiVersion"],
38-
) as open_ai_client:
39-
embeddings = await open_ai_client.embeddings.create(
40-
model=os.environ["OpenAI__EmbeddingModel"], input=query
41-
)
42-
43-
# Extract the embedding vector
44-
embedding_vector = embeddings.data[0].embedding
45-
4636
vector_query = [
47-
VectorizedQuery(
48-
vector=embedding_vector,
37+
VectorizableTextQuery(
38+
text=query,
4939
k_nearest_neighbors=7,
5040
fields=",".join(vector_fields),
5141
)
5242
]
5343
else:
5444
vector_query = None
5545

56-
if identity_type == IdentityType.SYSTEM_ASSIGNED:
46+
if identity_type in [IdentityType.SYSTEM_ASSIGNED, IdentityType.USER_ASSIGNED]:
5747
credential = DefaultAzureCredential()
58-
elif identity_type == IdentityType.USER_ASSIGNED:
59-
credential = DefaultAzureCredential(
60-
managed_identity_client_id=os.environ["ClientID"]
61-
)
6248
else:
6349
credential = AzureKeyCredential(
6450
os.environ["AIService__AzureSearchOptions__Key"]
@@ -253,7 +239,9 @@ async def get_entity_schemas(
253239
logging.info("Filtered Schemas: %s", filtered_schemas)
254240
return filtered_schemas
255241

256-
async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
242+
async def add_entry_to_index(
243+
self, document: dict, vector_fields: dict, index_name: str
244+
):
257245
"""Add an entry to the search index."""
258246

259247
logging.info("Document: %s", document)
@@ -270,20 +258,13 @@ async def add_entry_to_index(document: dict, vector_fields: dict, index_name: st
270258
document["DateLastModified"] = datetime.now(timezone.utc)
271259

272260
try:
273-
async with AsyncAzureOpenAI(
274-
# This is the default and can be omitted
275-
api_key=os.environ["OpenAI__ApiKey"],
276-
azure_endpoint=os.environ["OpenAI__Endpoint"],
277-
api_version=os.environ["OpenAI__ApiVersion"],
278-
) as open_ai_client:
279-
embeddings = await open_ai_client.embeddings.create(
280-
model=os.environ["OpenAI__EmbeddingModel"],
281-
input=fields_to_embed.values(),
282-
)
261+
embeddings = await self.open_ai_connector.run_embedding_request(
262+
list(fields_to_embed.values())
263+
)
283264

284-
# Extract the embedding vector
285-
for i, field in enumerate(vector_fields.values()):
286-
document[field] = embeddings.data[i].embedding
265+
# Extract the embedding vector
266+
for i, field in enumerate(vector_fields.values()):
267+
document[field] = embeddings.data[i].embedding
287268

288269
document["Id"] = base64.urlsafe_b64encode(
289270
document["Question"].encode()

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,21 @@ async def run_completion_request(
5858
max_tokens=max_tokens,
5959
)
6060
return response.choices[0].message.content
61+
62+
async def run_embedding_request(self, batch: list[str]):
63+
token_provider, api_key = self.get_authentication_properties()
64+
65+
model_deployment = os.environ["OpenAI__EmbeddingModel"]
66+
async with AsyncAzureOpenAI(
67+
azure_deployment=model_deployment,
68+
api_version=os.environ["OpenAI__ApiVersion"],
69+
azure_endpoint=os.environ["OpenAI__Endpoint"],
70+
azure_ad_token_provider=token_provider,
71+
api_key=api_key,
72+
) as open_ai_client:
73+
embeddings = await open_ai_client.embeddings.create(
74+
model=os.environ["OpenAI__EmbeddingModel"],
75+
input=batch,
76+
)
77+
78+
return embeddings

0 commit comments

Comments
 (0)