88import logging
99from pydantic import BaseModel , Field , ConfigDict , computed_field
1010from typing import Optional
11- from text_2_sql_core .utils .environment import IdentityType , get_identity_type
12- from openai import AsyncAzureOpenAI
13- from azure .identity import DefaultAzureCredential , get_bearer_token_provider
1411import random
1512import re
1613import networkx as nx
1714from text_2_sql_core .utils .database import DatabaseEngine
1815from tenacity import retry , stop_after_attempt , wait_exponential
16+ from text_2_sql_core .connectors .open_ai import OpenAIConnector
1917
2018logging .basicConfig (level = logging .INFO )
2119
@@ -279,6 +277,8 @@ def __init__(
279277 if output_directory is None :
280278 self .output_directory = "."
281279
280+ self .open_ai_connector = OpenAIConnector ()
281+
282282 load_dotenv (find_dotenv ())
283283
284284 @property
@@ -627,9 +627,6 @@ async def extract_columns_with_definitions(
627627
628628 return columns
629629
630- @retry (
631- stop = stop_after_attempt (3 ), wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 )
632- )
633630 async def send_request_to_llm (self , system_prompt : str , input : str ):
634631 """A method to use GPT to generate a definition for an entity.
635632
@@ -640,55 +637,23 @@ async def send_request_to_llm(self, system_prompt: str, input: str):
640637 Returns:
641638 str: The generated definition."""
642639
643- MAX_TOKENS = 2000
644-
645- api_version = os .environ ["OpenAI__ApiVersion" ]
646- model = os .environ ["OpenAI__CompletionDeployment" ]
647-
648- if get_identity_type () in [
649- IdentityType .SYSTEM_ASSIGNED ,
650- IdentityType .USER_ASSIGNED ,
651- ]:
652- token_provider = get_bearer_token_provider (
653- DefaultAzureCredential (), "https://cognitiveservices.azure.com/.default"
654- )
655- api_key = None
656- else :
657- token_provider = None
658- api_key = os .environ ["OpenAI__ApiKey" ]
659-
660- try :
661- async with AsyncAzureOpenAI (
662- api_key = api_key ,
663- api_version = api_version ,
664- azure_ad_token_provider = token_provider ,
665- azure_endpoint = os .environ .get ("OpenAI__Endpoint" ),
666- ) as client :
667- response = await client .chat .completions .create (
668- model = model ,
669- messages = [
670- {
671- "role" : "system" ,
672- "content" : system_prompt ,
673- },
674- {
675- "role" : "user" ,
676- "content" : [
677- {
678- "type" : "text" ,
679- "text" : input ,
680- },
681- ],
682- },
683- ],
684- max_tokens = MAX_TOKENS ,
685- )
640+ messages = [
641+ {
642+ "role" : "system" ,
643+ "content" : system_prompt ,
644+ },
645+ {
646+ "role" : "user" ,
647+ "content" : [
648+ {
649+ "type" : "text" ,
650+ "text" : input ,
651+ },
652+ ],
653+ },
654+ ]
686655
687- return response .choices [0 ].message .content
688- except Exception as e :
689- logging .error (f"Unable to generate definition for { input } " )
690- logging .error (f"Error generating definition: { e } " )
691- return None
656+ return await self .open_ai_connector .run_completion_request (messages )
692657
693658 async def generate_entity_definition (self , entity : EntityItem ):
694659 """A method to generate a definition for an entity.
0 commit comments