11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT License.
33from azure .identity import DefaultAzureCredential
4- from openai import AsyncAzureOpenAI
54from azure .core .credentials import AzureKeyCredential
6- from azure .search .documents .models import VectorizedQuery , QueryType
5+ from azure .search .documents .models import QueryType , VectorizableTextQuery
76from azure .search .documents .aio import SearchClient
87from text_2_sql_core .utils .environment import IdentityType , get_identity_type
98import os
1211from datetime import datetime , timezone
1312import json
1413from typing import Annotated
14+ from text_2_sql_core .connectors .open_ai import OpenAIConnector
1515
1616
1717class AISearchConnector :
18+ def __init__ (self ):
19+ self .open_ai_connector = OpenAIConnector ()
20+
1821 async def run_ai_search_query (
1922 self ,
2023 query ,
@@ -30,35 +33,18 @@ async def run_ai_search_query(
3033 identity_type = get_identity_type ()
3134
3235 if len (vector_fields ) > 0 :
33- async with AsyncAzureOpenAI (
34- # This is the default and can be omitted
35- api_key = os .environ ["OpenAI__ApiKey" ],
36- azure_endpoint = os .environ ["OpenAI__Endpoint" ],
37- api_version = os .environ ["OpenAI__ApiVersion" ],
38- ) as open_ai_client :
39- embeddings = await open_ai_client .embeddings .create (
40- model = os .environ ["OpenAI__EmbeddingModel" ], input = query
41- )
42-
43- # Extract the embedding vector
44- embedding_vector = embeddings .data [0 ].embedding
45-
4636 vector_query = [
47- VectorizedQuery (
48- vector = embedding_vector ,
37+ VectorizableTextQuery (
38+ text = query ,
4939 k_nearest_neighbors = 7 ,
5040 fields = "," .join (vector_fields ),
5141 )
5242 ]
5343 else :
5444 vector_query = None
5545
56- if identity_type == IdentityType .SYSTEM_ASSIGNED :
46+ if identity_type in [ IdentityType .SYSTEM_ASSIGNED , IdentityType . USER_ASSIGNED ] :
5747 credential = DefaultAzureCredential ()
58- elif identity_type == IdentityType .USER_ASSIGNED :
59- credential = DefaultAzureCredential (
60- managed_identity_client_id = os .environ ["ClientID" ]
61- )
6248 else :
6349 credential = AzureKeyCredential (
6450 os .environ ["AIService__AzureSearchOptions__Key" ]
@@ -253,7 +239,9 @@ async def get_entity_schemas(
253239 logging .info ("Filtered Schemas: %s" , filtered_schemas )
254240 return filtered_schemas
255241
256- async def add_entry_to_index (document : dict , vector_fields : dict , index_name : str ):
242+ async def add_entry_to_index (
243+ self , document : dict , vector_fields : dict , index_name : str
244+ ):
257245 """Add an entry to the search index."""
258246
259247 logging .info ("Document: %s" , document )
@@ -270,20 +258,13 @@ async def add_entry_to_index(document: dict, vector_fields: dict, index_name: st
270258 document ["DateLastModified" ] = datetime .now (timezone .utc )
271259
272260 try :
273- async with AsyncAzureOpenAI (
274- # This is the default and can be omitted
275- api_key = os .environ ["OpenAI__ApiKey" ],
276- azure_endpoint = os .environ ["OpenAI__Endpoint" ],
277- api_version = os .environ ["OpenAI__ApiVersion" ],
278- ) as open_ai_client :
279- embeddings = await open_ai_client .embeddings .create (
280- model = os .environ ["OpenAI__EmbeddingModel" ],
281- input = fields_to_embed .values (),
282- )
261+ embeddings = await self .open_ai_connector .run_embedding_request (
262+ list (fields_to_embed .values ())
263+ )
283264
284- # Extract the embedding vector
285- for i , field in enumerate (vector_fields .values ()):
286- document [field ] = embeddings .data [i ].embedding
265+ # Extract the embedding vector
266+ for i , field in enumerate (vector_fields .values ()):
267+ document [field ] = embeddings .data [i ].embedding
287268
288269 document ["Id" ] = base64 .urlsafe_b64encode (
289270 document ["Question" ].encode ()
0 commit comments