Skip to content

Commit 335fd9d

Browse files
committed
Refactor query handling in llm_utils
- Add Reranker - Added new dependencies: langchain-huggingface==0.1.2 and transformers==4.51.2 to requirements.txt. - Removed the QueryRefinedAgainChain and its associated logic from chains.py and graph.py to streamline the query refinement process.
1 parent c59211f commit 335fd9d

File tree

3 files changed

+52
-91
lines changed

3 files changed

+52
-91
lines changed

llm_utils/chains.py

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,28 @@ def create_query_refiner_chain(llm):
4040
예시:
4141
사용자가 "유저 이탈 원인이 궁금해요"라고 했다면,
4242
재질문 형식이 아니라
43-
"최근 1개월 간의 접속·결제 로그를 기준으로,
43+
"접속·결제 로그를 기준으로,
4444
주로 어떤 사용자가 어떤 과정을 거쳐 이탈하는지를 분석해야 한다"처럼
4545
분석 방향이 명확해진 질문 한 문장(또는 한 문단)으로 정리해 주세요.
4646
4747
최종 출력 형식 예시:
4848
------------------------------
4949
구체화된 질문:
50-
"최근 1개월 동안 고액 결제 경험이 있는 유저가
50+
"고액 결제 경험이 있는 유저가
5151
행동 로그에서 이탈 전 어떤 패턴을 보였는지 분석"
5252
5353
가정한 조건:
54-
- 최근 1개월치 행동 로그와 결제 로그 중심
54+
- 행동 로그와 결제 로그 중심
5555
- 고액 결제자(월 결제액 10만 원 이상) 그룹 대상으로 한정
5656
------------------------------
5757
""",
5858
),
5959
MessagesPlaceholder(variable_name="user_input"),
60+
(
61+
"system",
62+
"다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:",
63+
),
64+
MessagesPlaceholder(variable_name="searched_tables"),
6065
(
6166
"system",
6267
"""
@@ -72,61 +77,6 @@ def create_query_refiner_chain(llm):
7277
return tool_choice_prompt | llm
7378

7479

75-
# QueryRefinedAgainChain
76-
def create_query_redefined_again_chain(llm):
77-
query_redefined_again_prompt = ChatPromptTemplate.from_messages(
78-
[
79-
(
80-
"system",
81-
"""
82-
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
83-
사용자의 질문과 이미 구체화된 질문을 바탕으로, 실제 사용 가능한 테이블과 컬럼 정보를 검토하여
84-
더욱 정교하게 질문을 재정의해 주세요.
85-
86-
주의사항:
87-
- 이전에 구체화된 질문을 기반으로 하되, 실제 DB 환경에서 사용 가능한 테이블/컬럼을 고려해 현실적인 분석 방향을 제시하세요.
88-
- 불필요한 재질문 없이, 주어진 데이터로 최대한 분석 가능한 형태로 질문을 구체화하세요.
89-
- 테이블 구조에 맞게 분석 질문을 조정하고, 필요한 가정을 추가하세요.
90-
- 최종 출력 형식은 반드시 아래와 같아야 합니다.
91-
92-
최종 형태 예시:
93-
94-
<최종 구체화된 질문>
95-
```
96-
최근 30일간 결제 금액이 10만원 이상인 사용자들의 서비스 이용 패턴과 이탈율을 분석하여,
97-
어떤 활동 패턴을 보이는 고액 결제자가 이탈하는지 파악
98-
```
99-
100-
<분석 접근 방향>
101-
```
102-
1. subscription_activities와 contract_activities 테이블을 조인하여 고액 결제자 식별
103-
2. 해당 사용자들의 activity_type 분포 확인
104-
3. 이탈 사용자(30일 이상 미접속)와 활성 사용자의 행동 패턴 비교 분석
105-
4. 주요 이탈 지점 식별
106-
```
107-
""",
108-
),
109-
(
110-
"system",
111-
"아래는 사용자의 원래 질문 및 1차 구체화된 질문입니다:",
112-
),
113-
MessagesPlaceholder(variable_name="user_input"),
114-
MessagesPlaceholder(variable_name="refined_input"),
115-
(
116-
"system",
117-
"다음은 사용자의 DB 환경정보와 실제 사용 가능한 테이블 및 컬럼 정보입니다:",
118-
),
119-
MessagesPlaceholder(variable_name="user_database_env"),
120-
MessagesPlaceholder(variable_name="searched_tables"),
121-
(
122-
"system",
123-
"위 정보를 바탕으로 DB 구조에 맞게 더욱 구체화된 최종 질문과 분석 접근 방향을 최종 형태 예시와 같은 형식으로 작성해주세요.",
124-
),
125-
]
126-
)
127-
return query_redefined_again_prompt | llm
128-
129-
13080
# QueryMakerChain
13181
def create_query_maker_chain(llm):
13282
query_maker_prompt = ChatPromptTemplate.from_messages(
@@ -165,7 +115,6 @@ def create_query_maker_chain(llm):
165115
),
166116
MessagesPlaceholder(variable_name="user_input"),
167117
MessagesPlaceholder(variable_name="refined_input"),
168-
MessagesPlaceholder(variable_name="refined_input_again"),
169118
(
170119
"system",
171120
"다음은 사용자의 db 환경정보와 사용 가능한 테이블 및 컬럼 정보입니다:",
@@ -182,5 +131,4 @@ def create_query_maker_chain(llm):
182131

183132

184133
query_refiner_chain = create_query_refiner_chain(llm)
185-
query_redefined_again_chain = create_query_redefined_again_chain(llm)
186134
query_maker_chain = create_query_maker_chain(llm)

llm_utils/graph.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@
1010

1111
from llm_utils.chains import (
1212
query_refiner_chain,
13-
query_redefined_again_chain,
1413
query_maker_chain,
1514
)
1615

1716
from llm_utils.tools import get_info_from_db
1817

1918
# 노드 식별자 정의
2019
QUERY_REFINER = "query_refiner"
21-
QUERY_REFINED_AGAIN = "query_redefined_again"
2220
GET_TABLE_INFO = "get_table_info"
2321
TOOL = "tool"
2422
TABLE_FILTER = "table_filter"
@@ -32,7 +30,6 @@ class QueryMakerState(TypedDict):
3230
searched_tables: dict[str, dict[str, str]]
3331
best_practice_query: str
3432
refined_input: str
35-
refined_input_again: str
3633
generated_query: str
3734

3835

@@ -43,6 +40,7 @@ def query_refiner_node(state: QueryMakerState):
4340
"user_input": [state["messages"][0].content],
4441
"user_database_env": [state["user_database_env"]],
4542
"best_practice_query": [state["best_practice_query"]],
43+
"searched_tables": [json.dumps(state["searched_tables"])],
4644
}
4745
)
4846
state["messages"].append(res)
@@ -66,9 +64,42 @@ def get_table_info_node(state: QueryMakerState):
6664
db = FAISS.from_documents(documents, embeddings)
6765
db.save_local(os.getcwd() + "/table_info_db")
6866
print("table_info_db not found")
69-
doc_res = db.similarity_search(state["messages"][-1].content)
70-
documents_dict = {}
7167

68+
retriever = db.as_retriever(search_kwargs={"k": 10})
69+
70+
from langchain.retrievers import ContextualCompressionRetriever
71+
from langchain.retrievers.document_compressors import CrossEncoderReranker
72+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
73+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
74+
75+
# Reranking 적용 여부 설정
76+
use_rerank = True # 필요에 따라 True 또는 False로 설정
77+
78+
if use_rerank:
79+
local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
80+
81+
# 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장
82+
if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
83+
print("🔄 ko-reranker 모델 로컬에서 로드 중...")
84+
else:
85+
print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
86+
model = AutoModelForSequenceClassification.from_pretrained(
87+
"Dongjin-kr/ko-reranker"
88+
)
89+
tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
90+
model.save_pretrained(local_model_path)
91+
tokenizer.save_pretrained(local_model_path)
92+
model = HuggingFaceCrossEncoder(model_name=local_model_path)
93+
compressor = CrossEncoderReranker(model=model, top_n=3)
94+
retriever = db.as_retriever(search_kwargs={"k": 10})
95+
compression_retriever = ContextualCompressionRetriever(
96+
base_compressor=compressor, base_retriever=retriever
97+
)
98+
99+
doc_res = compression_retriever.invoke(state["messages"][0].content)
100+
else: # Reranking 미적용
101+
doc_res = db.similarity_search(state["messages"][0].content, k=10)
102+
documents_dict = {}
72103
for doc in doc_res:
73104
lines = doc.page_content.split("\n")
74105

@@ -93,19 +124,6 @@ def get_table_info_node(state: QueryMakerState):
93124
return state
94125

95126

96-
def query_redefined_again_node(state: QueryMakerState):
97-
res = query_redefined_again_chain.invoke(
98-
input={
99-
"user_input": [state["messages"][0].content],
100-
"refined_input": [state["refined_input"]],
101-
"user_database_env": [state["user_database_env"]],
102-
"searched_tables": [json.dumps(state["searched_tables"])],
103-
}
104-
)
105-
state["refined_input_again"] = res
106-
return state
107-
108-
109127
# 노드 함수: QUERY_MAKER 노드
110128
def query_maker_node(state: QueryMakerState):
111129
res = query_maker_chain.invoke(
@@ -137,9 +155,7 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
137155
res = chain.invoke(
138156
input={
139157
"input": "\n\n---\n\n".join(
140-
[state["messages"][0].content]
141-
# + [state["refined_input"].content]
142-
+ [state["refined_input_again"].content]
158+
[state["messages"][0].content] + [state["refined_input"].content]
143159
),
144160
"table_info": [json.dumps(state["searched_tables"])],
145161
"top_k": 10,
@@ -152,21 +168,16 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
152168

153169
# StateGraph 생성 및 구성
154170
builder = StateGraph(QueryMakerState)
155-
builder.set_entry_point(QUERY_REFINER)
171+
builder.set_entry_point(GET_TABLE_INFO)
156172

157173
# 노드 추가
158-
builder.add_node(QUERY_REFINER, query_refiner_node)
159174
builder.add_node(GET_TABLE_INFO, get_table_info_node)
160-
# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
161-
builder.add_node(
162-
QUERY_MAKER, query_maker_node_with_db_guide
163-
) # query_maker_node_with_db_guide
164-
builder.add_node(QUERY_REFINED_AGAIN, query_redefined_again_node)
175+
builder.add_node(QUERY_REFINER, query_refiner_node)
176+
builder.add_node(QUERY_MAKER, query_maker_node_with_db_guide)
165177

166178
# 기본 엣지 설정
167-
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
168-
builder.add_edge(GET_TABLE_INFO, QUERY_REFINED_AGAIN)
169-
builder.add_edge(QUERY_REFINED_AGAIN, QUERY_MAKER)
179+
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
180+
builder.add_edge(QUERY_REFINER, QUERY_MAKER)
170181

171182
# QUERY_MAKER 노드 후 종료
172183
builder.add_edge(QUERY_MAKER, END)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ pre_commit==4.1.0
1111
setuptools
1212
wheel
1313
twine
14+
langchain-huggingface==0.1.2
15+
transformers==4.51.2

0 commit comments

Comments
 (0)