Skip to content

Commit 6688775

Browse files
Change OpenAI to use Identity Provider (#122)
1 parent d507b15 commit 6688775

File tree

5 files changed

+91
-133
lines changed

5 files changed

+91
-133
lines changed

deploy_ai_search/src/deploy_ai_search/environment.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,11 @@ def ai_search_credential(self) -> DefaultAzureCredential | AzureKeyCredential:
109109
Returns:
110110
DefaultAzureCredential | AzureKeyCredential: The ai search credential
111111
"""
112-
if self.identity_type == IdentityType.SYSTEM_ASSIGNED:
112+
if self.identity_type in [
113+
IdentityType.SYSTEM_ASSIGNED,
114+
IdentityType.USER_ASSIGNED,
115+
]:
113116
return DefaultAzureCredential()
114-
elif self.identity_type == IdentityType.USER_ASSIGNED:
115-
return DefaultAzureCredential(
116-
managed_identity_client_id=self.ai_search_identity_id
117-
)
118117
else:
119118
return AzureKeyCredential(
120119
os.environ.get("AIService__AzureSearchOptions__Key")
@@ -180,7 +179,7 @@ def storage_account_blob_container_name(self) -> str:
180179
if container is None:
181180
raise ValueError(
182181
f"""Populate environment variable 'StorageAccount__{
183-
self.normalised_indexer_type}__Container' with container name."""
182+
self.normalised_indexer_type}__Container' with container name."""
184183
)
185184

186185
return container

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_model_creator.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from autogen_ext.models import AzureOpenAIChatCompletionClient
3+
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
44
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
55

66
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
@@ -30,21 +30,15 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
3030

3131
@classmethod
3232
def get_authentication_properties(cls) -> dict:
33-
if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
33+
if get_identity_type() in [
34+
IdentityType.SYSTEM_ASSIGNED,
35+
IdentityType.USER_ASSIGNED,
36+
]:
3437
# Create the token provider
3538
api_key = None
3639
token_provider = get_bearer_token_provider(
3740
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
3841
)
39-
elif get_identity_type() == IdentityType.USER_ASSIGNED:
40-
# Create the token provider
41-
api_key = None
42-
token_provider = get_bearer_token_provider(
43-
DefaultAzureCredential(
44-
managed_identity_client_id=os.environ["ClientId"]
45-
),
46-
"https://cognitiveservices.azure.com/.default",
47-
)
4842
else:
4943
token_provider = None
5044
api_key = os.environ["OpenAI__ApiKey"]
@@ -57,7 +51,7 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
5751
return AzureOpenAIChatCompletionClient(
5852
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
5953
model=os.environ["OpenAI__MiniCompletionDeployment"],
60-
api_version="2024-08-01-preview",
54+
api_version=os.environ["OpenAI__ApiVersion"],
6155
azure_endpoint=os.environ["OpenAI__Endpoint"],
6256
azure_ad_token_provider=token_provider,
6357
api_key=api_key,
@@ -75,7 +69,7 @@ def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
7569
return AzureOpenAIChatCompletionClient(
7670
azure_deployment=os.environ["OpenAI__CompletionDeployment"],
7771
model=os.environ["OpenAI__CompletionDeployment"],
78-
api_version="2024-08-01-preview",
72+
api_version=os.environ["OpenAI__ApiVersion"],
7973
azure_endpoint=os.environ["OpenAI__Endpoint"],
8074
azure_ad_token_provider=token_provider,
8175
api_key=api_key,

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from azure.identity import DefaultAzureCredential
4-
from openai import AsyncAzureOpenAI
54
from azure.core.credentials import AzureKeyCredential
6-
from azure.search.documents.models import VectorizedQuery, QueryType
5+
from azure.search.documents.models import QueryType, VectorizableTextQuery
76
from azure.search.documents.aio import SearchClient
87
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
98
import os
@@ -12,9 +11,13 @@
1211
from datetime import datetime, timezone
1312
import json
1413
from typing import Annotated
14+
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1515

1616

1717
class AISearchConnector:
18+
def __init__(self):
19+
self.open_ai_connector = OpenAIConnector()
20+
1821
async def run_ai_search_query(
1922
self,
2023
query,
@@ -30,35 +33,18 @@ async def run_ai_search_query(
3033
identity_type = get_identity_type()
3134

3235
if len(vector_fields) > 0:
33-
async with AsyncAzureOpenAI(
34-
# This is the default and can be omitted
35-
api_key=os.environ["OpenAI__ApiKey"],
36-
azure_endpoint=os.environ["OpenAI__Endpoint"],
37-
api_version=os.environ["OpenAI__ApiVersion"],
38-
) as open_ai_client:
39-
embeddings = await open_ai_client.embeddings.create(
40-
model=os.environ["OpenAI__EmbeddingModel"], input=query
41-
)
42-
43-
# Extract the embedding vector
44-
embedding_vector = embeddings.data[0].embedding
45-
4636
vector_query = [
47-
VectorizedQuery(
48-
vector=embedding_vector,
37+
VectorizableTextQuery(
38+
text=query,
4939
k_nearest_neighbors=7,
5040
fields=",".join(vector_fields),
5141
)
5242
]
5343
else:
5444
vector_query = None
5545

56-
if identity_type == IdentityType.SYSTEM_ASSIGNED:
46+
if identity_type in [IdentityType.SYSTEM_ASSIGNED, IdentityType.USER_ASSIGNED]:
5747
credential = DefaultAzureCredential()
58-
elif identity_type == IdentityType.USER_ASSIGNED:
59-
credential = DefaultAzureCredential(
60-
managed_identity_client_id=os.environ["ClientID"]
61-
)
6248
else:
6349
credential = AzureKeyCredential(
6450
os.environ["AIService__AzureSearchOptions__Key"]
@@ -253,7 +239,9 @@ async def get_entity_schemas(
253239
logging.info("Filtered Schemas: %s", filtered_schemas)
254240
return filtered_schemas
255241

256-
async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
242+
async def add_entry_to_index(
243+
self, document: dict, vector_fields: dict, index_name: str
244+
):
257245
"""Add an entry to the search index."""
258246

259247
logging.info("Document: %s", document)
@@ -270,31 +258,23 @@ async def add_entry_to_index(document: dict, vector_fields: dict, index_name: st
270258
document["DateLastModified"] = datetime.now(timezone.utc)
271259

272260
try:
273-
async with AsyncAzureOpenAI(
274-
# This is the default and can be omitted
275-
api_key=os.environ["OpenAI__ApiKey"],
276-
azure_endpoint=os.environ["OpenAI__Endpoint"],
277-
api_version=os.environ["OpenAI__ApiVersion"],
278-
) as open_ai_client:
279-
embeddings = await open_ai_client.embeddings.create(
280-
model=os.environ["OpenAI__EmbeddingModel"],
281-
input=fields_to_embed.values(),
282-
)
261+
embeddings = await self.open_ai_connector.run_embedding_request(
262+
list(fields_to_embed.values())
263+
)
283264

284-
# Extract the embedding vector
285-
for i, field in enumerate(vector_fields.values()):
286-
document[field] = embeddings.data[i].embedding
265+
# Extract the embedding vector
266+
for i, field in enumerate(vector_fields.values()):
267+
document[field] = embeddings.data[i].embedding
287268

288269
document["Id"] = base64.urlsafe_b64encode(
289270
document["Question"].encode()
290271
).decode("utf-8")
291272

292-
if identity_type == IdentityType.SYSTEM_ASSIGNED:
273+
if identity_type in [
274+
IdentityType.SYSTEM_ASSIGNED,
275+
IdentityType.USER_ASSIGNED,
276+
]:
293277
credential = DefaultAzureCredential()
294-
elif identity_type == IdentityType.USER_ASSIGNED:
295-
credential = DefaultAzureCredential(
296-
managed_identity_client_id=os.environ["ClientID"]
297-
)
298278
else:
299279
credential = AzureKeyCredential(
300280
os.environ["AIService__AzureSearchOptions__Key"]

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,61 @@
1212
class OpenAIConnector:
1313
@classmethod
1414
def get_authentication_properties(cls) -> dict:
15-
if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
15+
if get_identity_type() in [
16+
IdentityType.SYSTEM_ASSIGNED,
17+
IdentityType.USER_ASSIGNED,
18+
]:
1619
# Create the token provider
1720
api_key = None
1821
token_provider = get_bearer_token_provider(
1922
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
2023
)
21-
elif get_identity_type() == IdentityType.USER_ASSIGNED:
22-
# Create the token provider
23-
api_key = None
24-
token_provider = get_bearer_token_provider(
25-
DefaultAzureCredential(
26-
managed_identity_client_id=os.environ["ClientId"]
27-
),
28-
"https://cognitiveservices.azure.com/.default",
29-
)
3024
else:
3125
token_provider = None
3226
api_key = os.environ["OpenAI__ApiKey"]
3327

3428
return token_provider, api_key
3529

36-
async def run_completion_request(self, messages: list[dict], temperature=0):
30+
async def run_completion_request(
31+
self, messages: list[dict], temperature=0, max_tokens=2000, model="4o-mini"
32+
) -> str:
33+
if model == "4o-mini":
34+
model_deployment = os.environ["OpenAI__MiniCompletionDeployment"]
35+
elif model == "4o":
36+
model_deployment = os.environ["OpenAI__CompletionDeployment"]
37+
else:
38+
raise ValueError(f"Model {model} not found")
39+
40+
token_provider, api_key = self.get_authentication_properties()
3741
async with AsyncAzureOpenAI(
38-
api_key=os.environ["OpenAI__ApiKey"],
39-
azure_endpoint=os.environ["OpenAI__Endpoint"],
42+
azure_deployment=model_deployment,
4043
api_version=os.environ["OpenAI__ApiVersion"],
44+
azure_endpoint=os.environ["OpenAI__Endpoint"],
45+
azure_ad_token_provider=token_provider,
46+
api_key=api_key,
4147
) as open_ai_client:
4248
response = await open_ai_client.chat.completions.create(
43-
model=os.environ["OpenAI__MiniCompletionDeployment"],
49+
model=model_deployment,
4450
messages=messages,
4551
temperature=temperature,
52+
max_tokens=max_tokens,
4653
)
4754
return response.choices[0].message.content
55+
56+
async def run_embedding_request(self, batch: list[str]):
57+
token_provider, api_key = self.get_authentication_properties()
58+
59+
model_deployment = os.environ["OpenAI__EmbeddingModel"]
60+
async with AsyncAzureOpenAI(
61+
azure_deployment=model_deployment,
62+
api_version=os.environ["OpenAI__ApiVersion"],
63+
azure_endpoint=os.environ["OpenAI__Endpoint"],
64+
azure_ad_token_provider=token_provider,
65+
api_key=api_key,
66+
) as open_ai_client:
67+
embeddings = await open_ai_client.embeddings.create(
68+
model=os.environ["OpenAI__EmbeddingModel"],
69+
input=batch,
70+
)
71+
72+
return embeddings

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
import logging
99
from pydantic import BaseModel, Field, ConfigDict, computed_field
1010
from typing import Optional
11-
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
12-
from openai import AsyncAzureOpenAI
13-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
1411
import random
1512
import re
1613
import networkx as nx
1714
from text_2_sql_core.utils.database import DatabaseEngine
1815
from tenacity import retry, stop_after_attempt, wait_exponential
16+
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1917

2018
logging.basicConfig(level=logging.INFO)
2119

@@ -279,6 +277,8 @@ def __init__(
279277
if output_directory is None:
280278
self.output_directory = "."
281279

280+
self.open_ai_connector = OpenAIConnector()
281+
282282
load_dotenv(find_dotenv())
283283

284284
@property
@@ -627,9 +627,6 @@ async def extract_columns_with_definitions(
627627

628628
return columns
629629

630-
@retry(
631-
stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10)
632-
)
633630
async def send_request_to_llm(self, system_prompt: str, input: str):
634631
"""A method to use GPT to generate a definition for an entity.
635632
@@ -640,60 +637,23 @@ async def send_request_to_llm(self, system_prompt: str, input: str):
640637
Returns:
641638
str: The generated definition."""
642639

643-
MAX_TOKENS = 2000
644-
645-
api_version = os.environ["OpenAI__ApiVersion"]
646-
model = os.environ["OpenAI__CompletionDeployment"]
647-
648-
if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
649-
token_provider = get_bearer_token_provider(
650-
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
651-
)
652-
api_key = None
653-
elif get_identity_type() == IdentityType.USER_ASSIGNED:
654-
token_provider = get_bearer_token_provider(
655-
DefaultAzureCredential(
656-
managed_identity_client_id=os.environ["FunctionApp__ClientId"]
657-
),
658-
"https://cognitiveservices.azure.com/.default",
659-
)
660-
api_key = None
661-
else:
662-
token_provider = None
663-
api_key = os.environ["OpenAI__ApiKey"]
664-
665-
try:
666-
async with AsyncAzureOpenAI(
667-
api_key=api_key,
668-
api_version=api_version,
669-
azure_ad_token_provider=token_provider,
670-
azure_endpoint=os.environ.get("OpenAI__Endpoint"),
671-
) as client:
672-
response = await client.chat.completions.create(
673-
model=model,
674-
messages=[
675-
{
676-
"role": "system",
677-
"content": system_prompt,
678-
},
679-
{
680-
"role": "user",
681-
"content": [
682-
{
683-
"type": "text",
684-
"text": input,
685-
},
686-
],
687-
},
688-
],
689-
max_tokens=MAX_TOKENS,
690-
)
640+
messages = [
641+
{
642+
"role": "system",
643+
"content": system_prompt,
644+
},
645+
{
646+
"role": "user",
647+
"content": [
648+
{
649+
"type": "text",
650+
"text": input,
651+
},
652+
],
653+
},
654+
]
691655

692-
return response.choices[0].message.content
693-
except Exception as e:
694-
logging.error(f"Unable to generate definition for {input}")
695-
logging.error(f"Error generating definition: {e}")
696-
return None
656+
return await self.open_ai_connector.run_completion_request(messages)
697657

698658
async def generate_entity_definition(self, entity: EntityItem):
699659
"""A method to generate a definition for an entity.

0 commit comments

Comments
 (0)