Skip to content

Commit 19a6cbf

Browse files
committed
Add additional dependencies to pyproject.toml
- Added specific version of transformers (==4.51.2)
- Added clickhouse-driver (>=0.2.9,<0.3.0)
- Added numpy (<2.0)
1 parent 585c032 commit 19a6cbf

File tree

7 files changed

+1300
-1281
lines changed

7 files changed

+1300
-1281
lines changed

.env.example

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@
33
###############################################
44

55
# LLM_PROVIDER=openai
6-
# OPEN_AI_LLM_KEY=
7-
# OPEN_AI_LLM_MODEL=gpt-4o
6+
# OPEN_AI_KEY=sk-proj-----
7+
# OPEN_AI_LLM_MODEL=gpt-4.1
88

99
# LLM_PROVIDER=gemini
1010
# GEMINI_API_KEY=
1111
# GEMINI_LLM_MODEL=gemini-2.0-flash-lite
1212

13-
# LLM_PROVIDER=azure
14-
# AZURE_OPENAI_LLM_ENDPOINT=
15-
# AZURE_OPENAI_LLM_KEY=
16-
# AZURE_OPENAI_LLM_MODEL=
17-
# AZURE_OPENAI_LLM_API_VERSION=
13+
LLM_PROVIDER=azure
14+
AZURE_OPENAI_LLM_ENDPOINT=https://-------.openai.azure.com/
15+
AZURE_OPENAI_LLM_KEY=-
16+
AZURE_OPENAI_LLM_MODEL=gpt4o
17+
AZURE_OPENAI_LLM_API_VERSION=2024-07-01-preview
1818

1919
# LLM_PROVIDER=ollama
2020
# OLLAMA_LLM_BASE_URL=
@@ -36,31 +36,38 @@
3636
########### Embedding API SELECTION ###########
3737
###############################################
3838
# Only used if you are using an LLM that does not natively support embedding (openai or Azure)
39-
# EMBEDDING_ENGINE='openai'
40-
# OPEN_AI_KEY=sk-xxxx
41-
# EMBEDDING_MODEL_PREF='text-embedding-ada-002'
39+
# EMBEDDING_PROVIDER='openai'
40+
# OPEN_AI_EMBEDDING_MODEL='text-embedding-ada-002'
4241

43-
# EMBEDDING_ENGINE='azure'
44-
# AZURE_OPENAI_ENDPOINT=
45-
# AZURE_OPENAI_KEY=
46-
# EMBEDDING_MODEL_PREF='my-embedder-model' # This is the "deployment" on Azure you want to use for embeddings. Not the base model. Valid base model is text-embedding-ada-002
42+
# EMBEDDING_PROVIDER=azure
43+
# AZURE_OPENAI_EMBEDDING_ENDPOINT=https://-------.openai.azure.com/openai/deployments
44+
# AZURE_OPENAI_EMBEDDING_KEY=-
45+
# AZURE_OPENAI_EMBEDDING_MODEL='textembeddingada002' # This is the "deployment" on Azure you want to use for embeddings. Not the base model. Valid base model is text-embedding-ada-002
46+
# AZURE_OPENAI_EMBEDDING_API_VERSION=2023-09-15-preview
4747

48-
# EMBEDDING_ENGINE='ollama'
48+
# EMBEDDING_PROVIDER='ollama'
4949
# EMBEDDING_BASE_PATH='http://host.docker.internal:11434'
50-
# EMBEDDING_MODEL_PREF='nomic-embed-text:latest'
50+
# EMBEDDING_MODEL='nomic-embed-text:latest'
5151
# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192
5252

53-
# EMBEDDING_ENGINE='bedrock'
54-
# AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID=
55-
# AWS_BEDROCK_EMBEDDING_ACCESS_KEY=
56-
# AWS_BEDROCK_EMBEDDING_REGION=us-west-2
57-
# AWS_BEDROCK_EMBEDDING_MODEL_PREF=amazon.embedding-embedding-ada-002:0
53+
EMBEDDING_PROVIDER='bedrock'
54+
AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID=--
55+
AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY=-/-+-+-
56+
AWS_BEDROCK_EMBEDDING_REGION=us-west-2
57+
AWS_BEDROCK_EMBEDDING_MODEL=amazon.titan-embed-text-v2:0
5858

59-
# EMBEDDING_ENGINE='gemini'
59+
# EMBEDDING_PROVIDER='gemini'
6060
# GEMINI_EMBEDDING_API_KEY=
61-
# EMBEDDING_MODEL_PREF='text-embedding-004'
61+
# EMBEDDING_MODEL='text-embedding-004'
6262

63-
# EMBEDDING_ENGINE='huggingface'
63+
# EMBEDDING_PROVIDER='huggingface'
6464
# HUGGING_FACE_EMBEDDING_REPO_ID=
6565
# HUGGING_FACE_EMBEDDING_MODEL=
6666
# HUGGING_FACE_EMBEDDING_API_TOKEN=
67+
68+
DATAHUB_SERVER = 'http://-.-.-.-:-'
69+
CLICKHOUSE_HOST = '-.-.-.-'
70+
CLICKHOUSE_DATABASE = 'main'
71+
CLICKHOUSE_USER = '-'
72+
CLICKHOUSE_PASSWORD = '-'
73+
CLICKHOUSE_PORT = 9000

evaluation/gen_persona.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from persona_class import PersonaList
55

66
from llm_utils.tools import _get_table_info
7-
from langchain_openai.chat_models import ChatOpenAI
7+
from llm_utils.llm_factory import get_llm
88
from langchain_core.prompts import ChatPromptTemplate
99
from argparse import ArgumentParser
1010

@@ -19,7 +19,7 @@ def get_table_des_string(tables_desc):
1919
def generate_persona(tables_desc):
2020
description_string = get_table_des_string(tables_desc)
2121

22-
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
22+
llm = get_llm(temperature=0)
2323
system_prompt = """주어진 Table description들을 참고하여 Text2SQL 서비스로 질문을 할만한 페르소나를 생성하세요"""
2424

2525
prompt = ChatPromptTemplate.from_messages(

interface/lang2sql.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.graph import builder
1414

15+
import re
16+
1517
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1618
SIDEBAR_OPTIONS = {
1719
"show_total_token_usage": "Show Total Token Usage",
@@ -115,7 +117,16 @@ def display_result(
115117
if st.session_state.get("show_referenced_tables", True):
116118
st.write("참고한 테이블 목록:", res["searched_tables"])
117119
if st.session_state.get("show_table", True):
118-
sql = res["generated_query"]
120+
try:
121+
sql = re.findall(r"```sql(.*?)```", res["generated_query"].content, re.DOTALL)
122+
sql = sql[0].strip()
123+
except ValueError:
124+
st.error("SQL 쿼리를 찾을 수 없습니다.")
125+
return
126+
127+
if not sql:
128+
st.error("SQL 쿼리가 비어 있습니다.")
129+
return
119130
df = database.run_sql(sql)
120131
st.dataframe(df.head(10) if len(df) > 10 else df)
121132

llm_utils/llm_factory.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,88 +18,96 @@
1818
AzureChatOpenAI,
1919
OpenAIEmbeddings,
2020
)
21-
from langchain_community.llms.bedrock import Bedrock
2221

23-
# .env 파일 로딩
24-
load_dotenv()
22+
env_path = os.path.join(os.getcwd(), ".env")
2523

24+
if os.path.exists(env_path):
25+
load_dotenv(env_path, override=True)
26+
print(f"✅ 환경변수 파일(.env)이 {os.getcwd()}에 로드되었습니다!")
27+
else:
28+
print(f"⚠️ 환경변수 파일(.env)이 {os.getcwd()}에 없습니다!")
2629

27-
def get_llm() -> BaseLanguageModel:
30+
def get_llm(**kwargs) -> BaseLanguageModel:
2831
"""
2932
return chat model interface
3033
"""
3134
provider = os.getenv("LLM_PROVIDER")
35+
print(os.environ["LLM_PROVIDER"])
3236

3337
if provider is None:
3438
raise ValueError("LLM_PROVIDER environment variable is not set.")
3539

3640
if provider == "openai":
37-
return get_llm_openai()
41+
return get_llm_openai(**kwargs)
3842

3943
elif provider == "azure":
40-
return get_llm_azure()
44+
return get_llm_azure(**kwargs)
4145

4246
elif provider == "bedrock":
43-
return get_llm_bedrock()
47+
return get_llm_bedrock(**kwargs)
4448

4549
elif provider == "gemini":
46-
return get_llm_gemini()
50+
return get_llm_gemini(**kwargs)
4751

4852
elif provider == "ollama":
49-
return get_llm_ollama()
53+
return get_llm_ollama(**kwargs)
5054

5155
elif provider == "huggingface":
52-
return get_llm_huggingface()
56+
return get_llm_huggingface(**kwargs)
5357

5458
else:
5559
raise ValueError(f"Invalid LLM API Provider: {provider}")
5660

5761

58-
def get_llm_openai() -> BaseLanguageModel:
62+
def get_llm_openai(**kwargs) -> BaseLanguageModel:
5963
return ChatOpenAI(
60-
model=os.getenv("OPEN_MODEL_PREF", "gpt-4o"),
64+
model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"),
6165
api_key=os.getenv("OPEN_AI_KEY"),
66+
**kwargs,
6267
)
6368

6469

65-
def get_llm_azure() -> BaseLanguageModel:
70+
def get_llm_azure(**kwargs) -> BaseLanguageModel:
6671
return AzureChatOpenAI(
6772
api_key=os.getenv("AZURE_OPENAI_LLM_KEY"),
6873
azure_endpoint=os.getenv("AZURE_OPENAI_LLM_ENDPOINT"),
6974
azure_deployment=os.getenv("AZURE_OPENAI_LLM_MODEL"), # Deployment name
7075
api_version=os.getenv("AZURE_OPENAI_LLM_API_VERSION", "2023-07-01-preview"),
76+
**kwargs,
7177
)
7278

7379

74-
def get_llm_bedrock() -> BaseLanguageModel:
80+
def get_llm_bedrock(**kwargs) -> BaseLanguageModel:
7581
return ChatBedrockConverse(
7682
model=os.getenv("AWS_BEDROCK_LLM_MODEL"),
7783
aws_access_key_id=os.getenv("AWS_BEDROCK_LLM_ACCESS_KEY_ID"),
7884
aws_secret_access_key=os.getenv("AWS_BEDROCK_LLM_SECRET_ACCESS_KEY"),
7985
region_name=os.getenv("AWS_BEDROCK_LLM_REGION", "us-east-1"),
86+
**kwargs,
8087
)
8188

8289

83-
def get_llm_gemini() -> BaseLanguageModel:
84-
return ChatGoogleGenerativeAI(model=os.getenv("GEMINI_LLM_MODEL"))
90+
def get_llm_gemini(**kwargs) -> BaseLanguageModel:
91+
return ChatGoogleGenerativeAI(model=os.getenv("GEMINI_LLM_MODEL"), **kwargs)
8592

8693

87-
def get_llm_ollama() -> BaseLanguageModel:
94+
def get_llm_ollama(**kwargs) -> BaseLanguageModel:
8895
base_url = os.getenv("OLLAMA_LLM_BASE_URL")
8996
if base_url:
90-
return ChatOllama(base_url=base_url, model=os.getenv("OLLAMA_LLM_MODEL"))
97+
return ChatOllama(base_url=base_url, model=os.getenv("OLLAMA_LLM_MODEL"), **kwargs)
9198
else:
92-
return ChatOllama(model=os.getenv("OLLAMA_LLM_MODEL"))
99+
return ChatOllama(model=os.getenv("OLLAMA_LLM_MODEL"), **kwargs)
93100

94101

95-
def get_llm_huggingface() -> BaseLanguageModel:
102+
def get_llm_huggingface(**kwargs) -> BaseLanguageModel:
96103
return ChatHuggingFace(
97104
llm=HuggingFaceEndpoint(
98105
model=os.getenv("HUGGING_FACE_LLM_MODEL"),
99106
repo_id=os.getenv("HUGGING_FACE_LLM_REPO_ID"),
100107
task="text-generation",
101108
endpoint_url=os.getenv("HUGGING_FACE_LLM_ENDPOINT"),
102109
huggingfacehub_api_token=os.getenv("HUGGING_FACE_LLM_API_TOKEN"),
110+
**kwargs,
103111
)
104112
)
105113

@@ -109,6 +117,7 @@ def get_embeddings() -> Optional[BaseLanguageModel]:
109117
return embedding model interface
110118
"""
111119
provider = os.getenv("EMBEDDING_PROVIDER")
120+
print(provider)
112121

113122
if provider is None:
114123
raise ValueError("EMBEDDING_PROVIDER environment variable is not set.")
@@ -135,7 +144,7 @@ def get_embeddings() -> Optional[BaseLanguageModel]:
135144
def get_embeddings_openai() -> BaseLanguageModel:
136145
return OpenAIEmbeddings(
137146
model=os.getenv("OPEN_AI_EMBEDDING_MODEL"),
138-
openai_api_key=os.getenv("OPEN_AI_EMBEDDING_KEY"),
147+
openai_api_key=os.getenv("OPEN_AI_KEY"),
139148
)
140149

141150

llm_utils/retrieval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from transformers import AutoModelForSequenceClassification, AutoTokenizer
88

99
from .tools import get_info_from_db
10+
from .llm_factory import get_embeddings
1011

1112

1213
def get_vector_db():
1314
"""벡터 데이터베이스를 로드하거나 생성합니다."""
14-
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
15+
embeddings = get_embeddings()
1516
try:
1617
db = FAISS.load_local(
1718
os.getcwd() + "/table_info_db",

0 commit comments

Comments
 (0)