Skip to content

Commit 5f38cfe

Browse files
committed
Use openai connector
1 parent 4fcd7fa commit 5f38cfe

File tree

2 files changed

+32
-57
lines changed

2 files changed

+32
-57
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,28 @@ def get_authentication_properties(cls) -> dict:
3333

3434
return token_provider, api_key
3535

36-
async def run_completion_request(self, messages: list[dict], temperature=0):
36+
async def run_completion_request(
37+
self, messages: list[dict], temperature=0, max_tokens=2000, model="4o-mini"
38+
) -> str:
39+
if model == "4o-mini":
40+
model_deployment = os.environ["OpenAI__MiniCompletionDeployment"]
41+
elif model == "4o":
42+
model_deployment = os.environ["OpenAI__CompletionDeployment"]
43+
else:
44+
raise ValueError(f"Model {model} not found")
45+
3746
token_provider, api_key = self.get_authentication_properties()
3847
async with AsyncAzureOpenAI(
39-
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
48+
azure_deployment=model_deployment,
4049
api_version=os.environ["OpenAI__ApiVersion"],
4150
azure_endpoint=os.environ["OpenAI__Endpoint"],
4251
azure_ad_token_provider=token_provider,
4352
api_key=api_key,
4453
) as open_ai_client:
4554
response = await open_ai_client.chat.completions.create(
46-
model=os.environ["OpenAI__MiniCompletionDeployment"],
55+
model=model_deployment,
4756
messages=messages,
4857
temperature=temperature,
58+
max_tokens=max_tokens,
4959
)
5060
return response.choices[0].message.content

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py

Lines changed: 19 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
import logging
99
from pydantic import BaseModel, Field, ConfigDict, computed_field
1010
from typing import Optional
11-
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
12-
from openai import AsyncAzureOpenAI
13-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
1411
import random
1512
import re
1613
import networkx as nx
1714
from text_2_sql_core.utils.database import DatabaseEngine
1815
from tenacity import retry, stop_after_attempt, wait_exponential
16+
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1917

2018
logging.basicConfig(level=logging.INFO)
2119

@@ -279,6 +277,8 @@ def __init__(
279277
if output_directory is None:
280278
self.output_directory = "."
281279

280+
self.open_ai_connector = OpenAIConnector()
281+
282282
load_dotenv(find_dotenv())
283283

284284
@property
@@ -627,9 +627,6 @@ async def extract_columns_with_definitions(
627627

628628
return columns
629629

630-
@retry(
631-
stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10)
632-
)
633630
async def send_request_to_llm(self, system_prompt: str, input: str):
634631
"""A method to use GPT to generate a definition for an entity.
635632
@@ -640,55 +637,23 @@ async def send_request_to_llm(self, system_prompt: str, input: str):
640637
Returns:
641638
str: The generated definition."""
642639

643-
MAX_TOKENS = 2000
644-
645-
api_version = os.environ["OpenAI__ApiVersion"]
646-
model = os.environ["OpenAI__CompletionDeployment"]
647-
648-
if get_identity_type() in [
649-
IdentityType.SYSTEM_ASSIGNED,
650-
IdentityType.USER_ASSIGNED,
651-
]:
652-
token_provider = get_bearer_token_provider(
653-
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
654-
)
655-
api_key = None
656-
else:
657-
token_provider = None
658-
api_key = os.environ["OpenAI__ApiKey"]
659-
660-
try:
661-
async with AsyncAzureOpenAI(
662-
api_key=api_key,
663-
api_version=api_version,
664-
azure_ad_token_provider=token_provider,
665-
azure_endpoint=os.environ.get("OpenAI__Endpoint"),
666-
) as client:
667-
response = await client.chat.completions.create(
668-
model=model,
669-
messages=[
670-
{
671-
"role": "system",
672-
"content": system_prompt,
673-
},
674-
{
675-
"role": "user",
676-
"content": [
677-
{
678-
"type": "text",
679-
"text": input,
680-
},
681-
],
682-
},
683-
],
684-
max_tokens=MAX_TOKENS,
685-
)
640+
messages = [
641+
{
642+
"role": "system",
643+
"content": system_prompt,
644+
},
645+
{
646+
"role": "user",
647+
"content": [
648+
{
649+
"type": "text",
650+
"text": input,
651+
},
652+
],
653+
},
654+
]
686655

687-
return response.choices[0].message.content
688-
except Exception as e:
689-
logging.error(f"Unable to generate definition for {input}")
690-
logging.error(f"Error generating definition: {e}")
691-
return None
656+
return await self.open_ai_connector.run_completion_request(messages)
692657

693658
async def generate_entity_definition(self, entity: EntityItem):
694659
"""A method to generate a definition for an entity.

0 commit comments

Comments
 (0)