Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions interface/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def summarize_total_tokens(data):

# 결과 출력
st.write("총 토큰 사용량:", total_tokens)
# st.write("결과:", res["generated_query"].content)
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
st.write("결과:", res["generated_query"].content)
# st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
st.write("결과 설명:\n\n", res["messages"][-1].content)
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
st.write("참고한 테이블 목록:", res["searched_tables"])
78 changes: 13 additions & 65 deletions llm_utils/chains.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, load_prompt, SystemMessagePromptTemplate

from .llm_factory import get_llm

from dotenv import load_dotenv
from prompt.template_loader import get_prompt_template
import yaml

env_path = os.path.join(os.getcwd(), ".env")

Expand All @@ -12,6 +14,7 @@
else:
print(f"⚠️ 환경변수 파일(.env)이 {os.getcwd()}에 없습니다!")


llm = get_llm(
model_type="openai",
model_name="gpt-4o-mini",
Expand All @@ -20,45 +23,12 @@


def create_query_refiner_chain(llm):
prompt = get_prompt_template('query_refiner_prompt')
tool_choice_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
현재 subscription_activities, contract_activities, marketing_activities,
sales_activities, success_activities, support_activities, trial_activities 데이터를
보유하고 있으며, 사용자의 질문이 모호할 경우에도 우리가 가진 데이터를 기반으로
충분히 답변 가능한 형태로 질문을 구체화해 주세요.

주의:
- 사용자에게 추가 정보를 요구하는 ‘재질문(추가 질문)’을 하지 마세요.
- 질문에 포함해야 할 요소(예: 특정 기간, 대상 유저 그룹, 분석 대상 로그 종류 등)가
불충분하더라도, 합리적으로 추론해 가정한 뒤
답변에 충분한 질문 형태로 완성해 주세요.

예시:
사용자가 "유저 이탈 원인이 궁금해요"라고 했다면,
재질문 형식이 아니라
"최근 1개월 간의 접속·결제 로그를 기준으로,
주로 어떤 사용자가 어떤 과정을 거쳐 이탈하는지를 분석해야 한다"처럼
분석 방향이 명확해진 질문 한 문장(또는 한 문단)으로 정리해 주세요.

최종 출력 형식 예시:
------------------------------
구체화된 질문:
"최근 1개월 동안 고액 결제 경험이 있는 유저가
행동 로그에서 이탈 전 어떤 패턴을 보였는지 분석"

가정한 조건:
- 최근 1개월치 행동 로그와 결제 로그 중심
- 고액 결제자(월 결제액 10만 원 이상) 그룹 대상으로 한정
------------------------------
""",
),
SystemMessagePromptTemplate.from_template(prompt),
MessagesPlaceholder(variable_name="user_input"),
(
"system",
SystemMessagePromptTemplate.from_template(
"""
위 사용자의 입력을 바탕으로
분석 관점에서 **충분히 답변 가능한 형태**로
Expand All @@ -74,36 +44,11 @@ def create_query_refiner_chain(llm):

# QueryMakerChain
def create_query_maker_chain(llm):
# SystemPrompt만 yaml 파일로 불러와서 사용
prompt = get_prompt_template('query_maker_prompt')
query_maker_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요.

주의사항:
- 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요.
- 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요.
- 최종 출력 형식은 반드시 아래와 같아야 합니다.

최종 형태 예시:

<SQL>
```sql
SELECT COUNT(DISTINCT user_id)
FROM stg_users
```

<해석>
```plaintext (max_length_per_line=100)
이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다.
사용자는 유니크한 user_id를 가지고 있으며
중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다.
```

""",
),
SystemMessagePromptTemplate.from_template(prompt),
(
"system",
"아래는 사용자의 질문 및 구체화된 질문입니다:",
Expand All @@ -127,3 +72,6 @@ def create_query_maker_chain(llm):

query_refiner_chain = create_query_refiner_chain(llm)
query_maker_chain = create_query_maker_chain(llm)

if __name__ == "__main__":
query_refiner_chain.invoke()
10 changes: 6 additions & 4 deletions llm_utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class QueryMakerState(TypedDict):

# 노드 함수: QUERY_REFINER 노드
def query_refiner_node(state: QueryMakerState):
print('query_refiner_node 진입 [md]')
res = query_refiner_chain.invoke(
input={
"user_input": [state["messages"][0].content],
Expand All @@ -60,6 +61,7 @@ def get_table_info_node(state: QueryMakerState):
)
except:
documents = get_info_from_db()
print("db_embedding 진입")
db = FAISS.from_documents(documents, embeddings)
db.save_local(os.getcwd() + "/table_info_db")
print("table_info_db not found")
Expand Down Expand Up @@ -139,10 +141,10 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
# 노드 추가
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(GET_TABLE_INFO, get_table_info_node)
# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
builder.add_node(
QUERY_MAKER, query_maker_node_with_db_guide
) # query_maker_node_with_db_guide
builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
# builder.add_node(
# QUERY_MAKER, query_maker_node_with_db_guide
# ) # query_maker_node_with_db_guide

# 기본 엣지 설정
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
Expand Down
33 changes: 33 additions & 0 deletions llm_utils/prompts_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from langchain.chains.sql_database.prompt import SQL_PROMPTS
import os

from langchain_core.prompts import load_prompt


class SQLPrompt():
def __init__(self):
# os library를 확인해서 SQL_PROMPTS key에 해당하는ㅁ prompt가 있으면, 이를 교체
self.sql_prompts = SQL_PROMPTS
self.target_db_list = list(SQL_PROMPTS.keys())
self.prompt_path = '../prompt'

def update_prompt_from_path(self):
if os.path.exists(self.prompt_path):
path_list = os.listdir(self.prompt_path)
# yaml 파일만 가져옴
file_list = [file for file in path_list if file.endswith('.yaml')]
key_path_dict = {key.split('.')[0]: os.path.join(self.prompt_path, key) for key in file_list if key.split('.')[0] in self.target_db_list}
# file_list에서 sql_prompts의 key에 해당하는 파일이 있는 것만 가져옴
for key, path in key_path_dict.items():
self.sql_prompts[key] = load_prompt(path, encoding='utf-8')
else:
raise FileNotFoundError(f"Prompt file not found in {self.prompt_path}")
return False

if __name__ == '__main__':
sql_prompts_class = SQLPrompt()
print(sql_prompts_class.sql_prompts['mysql'])
print(sql_prompts_class.update_prompt_from_path())

print(sql_prompts_class.sql_prompts['mysql'])
print(sql_prompts_class.sql_prompts)
1 change: 1 addition & 0 deletions llm_utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _get_table_info() -> Dict[str, str]:
table_description = fetcher.get_table_description(urn)
if table_name and table_description:
table_info[table_name] = table_description
print(f'table_name {urn}')
return table_info


Expand Down
Empty file added prompt/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions prompt/query_maker_prompt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Role

당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요.

# 주의사항
- 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요.
- 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요.
- 최종 출력 형식은 반드시 아래와 같아야 합니다.

# Output Example
최종 형태 예시:
<SQL>
```sql
SELECT COUNT(DISTINCT user_id)
FROM stg_users
```

<해석>
```plaintext (max_length_per_line=100)
이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다.
사용자는 유니크한 user_id를 가지고 있으며
중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다.
32 changes: 32 additions & 0 deletions prompt/query_refiner_prompt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Role

당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
현재 subscription_activities, contract_activities, marketing_activities,
sales_activities, success_activities, support_activities, trial_activities 데이터를
보유하고 있으며, 사용자의 질문이 모호할 경우에도 우리가 가진 데이터를 기반으로
충분히 답변 가능한 형태로 질문을 구체화해 주세요.

# 주의사항:
- 사용자에게 추가 정보를 요구하는 ‘재질문(추가 질문)’을 하지 마세요.
- 질문에 포함해야 할 요소(예: 특정 기간, 대상 유저 그룹, 분석 대상 로그 종류 등)가
불충분하더라도, 합리적으로 추론해 가정한 뒤
답변에 충분한 질문 형태로 완성해 주세요.
예시:
사용자가 "유저 이탈 원인이 궁금해요"라고 했다면,
재질문 형식이 아니라
"최근 1개월 간의 접속·결제 로그를 기준으로,
주로 어떤 사용자가 어떤 과정을 거쳐 이탈하는지를 분석해야 한다"처럼
분석 방향이 명확해진 질문 한 문장(또는 한 문단)으로 정리해 주세요.

# Output Example

최종 출력 형식 예시:
------------------------------
구체화된 질문:
"최근 1개월 동안 고액 결제 경험이 있는 유저가
행동 로그에서 이탈 전 어떤 패턴을 보였는지 분석"

가정한 조건:
- 최근 1개월치 행동 로그와 결제 로그 중심
- 고액 결제자(월 결제액 10만 원 이상) 그룹 대상으로 한정
------------------------------
19 changes: 19 additions & 0 deletions prompt/template_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import re
from datetime import datetime

from langchain_core.prompts import PromptTemplate
from langgraph.prebuilt.chat_agent_executor import AgentState


def get_prompt_template(prompt_name: str) -> str:
try:
with open(os.path.join(os.path.dirname(__file__), f"{prompt_name}.md"), "r", encoding="utf-8") as f:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❗제 환경에서는 에러가 나네요...ㅜㅜ

    raise FileNotFoundError(f"경고: '{prompt_name}.md' 파일을 찾을 수 없습니다.")

디버깅
ls -ltra /home/pseudo.dwlee038/miniconda3/lib/python3.12/site-packages/prompt/

-rw-rw-r--   1 pseudo.dwlee038 pseudo.dwlee038   790 Apr 21 10:31 template_loader.py
-rw-rw-r--   1 pseudo.dwlee038 pseudo.dwlee038     0 Apr 21 10:31 __init__.py
drwxrwxr-x   2 pseudo.dwlee038 pseudo.dwlee038  4096 Apr 21 10:31 __pycache__
drwxrwxr-x   3 pseudo.dwlee038 pseudo.dwlee038  4096 Apr 21 10:31 .
drwxrwxr-x 437 pseudo.dwlee038 pseudo.dwlee038 20480 Apr 21 10:31 ..
  • 왜 md파일은 안생기는걸까요!?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

혹시 해당 코드를 받으신 이후 pip install . 을 해보셨을까요?
아니면, pip install -e .을 해서 새로 build가 필요할 것 같습니다.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 pip install -e . 로 진행하니 됩니다!

template = f.read()
except FileNotFoundError:
raise FileNotFoundError(f"경고: '{prompt_name}.md' 파일을 찾을 수 없습니다.")
return template

if __name__ == "__main__":
print(get_prompt_template("system_prompt"))
# print(apply_prompt_template("prompt_md_sample", {"messages": []}))
Loading