Commit a6d7a76

feat: split retrieval into a separate module and add a reranker
#46, #52
1 parent: 755b574

6 files changed: +154 additions, −40 deletions


interface/lang2sql.py

Lines changed: 29 additions & 0 deletions
@@ -46,13 +46,17 @@ def execute_query(
     *,
     query: str,
     database_env: str,
+    retriever_name: str = "기본",
+    top_n: int = 5,
 ) -> dict:
     """
     Runs the Lang2SQL graph to convert a natural-language query into a SQL query and returns the result.

     Args:
         query (str): User query written in natural language.
         database_env (str): Name of the database environment configuration to use.
+        retriever_name (str): Name of the retriever to use.
+        top_n (int): Number of table-info documents to retrieve.

     Returns:
         dict: Result dictionary containing the generated SQL query and related metadata.
@@ -64,6 +68,8 @@ def execute_query(
             "messages": [HumanMessage(content=query)],
             "user_database_env": database_env,
             "best_practice_query": "",
+            "retriever_name": retriever_name,
+            "top_n": top_n,
         }
     )
@@ -123,6 +129,27 @@ def display_result(
         index=0,
     )

+    retriever_options = {
+        "기본": "벡터 검색 (기본)",
+        "Reranker": "Reranker 검색 (정확도 향상)",
+    }
+
+    user_retriever = st.selectbox(
+        "검색기 유형을 선택하세요:",
+        options=list(retriever_options.keys()),
+        format_func=lambda x: retriever_options[x],
+        index=0,
+    )
+
+    user_top_n = st.slider(
+        "검색할 테이블 정보 개수:",
+        min_value=1,
+        max_value=20,
+        value=5,
+        step=1,
+        help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
+    )
+
     st.sidebar.title("Output Settings")
     for key, label in SIDEBAR_OPTIONS.items():
         st.sidebar.checkbox(label, value=True, key=key)
@@ -131,5 +158,7 @@ def display_result(
     result = execute_query(
         query=user_query,
         database_env=user_database_env,
+        retriever_name=user_retriever,
+        top_n=user_top_n,
    )
     display_result(res=result, database=db)
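For context, a minimal sketch of exercising the new parameters without the Streamlit controls. The database_env value and the printed generated_query key are illustrative assumptions (the key mirrors QueryMakerState in llm_utils/graph.py), not something this diff pins down:

from interface.lang2sql import execute_query

# Hypothetical programmatic call; retriever_name must be a key of
# retriever_options ("기본" or "Reranker"), otherwise the retrieval
# layer falls back to "기본".
result = execute_query(
    query="지난 분기 매출 상위 10개 제품은?",
    database_env="clickhouse",  # assumed environment name
    retriever_name="Reranker",
    top_n=10,
)
print(result["generated_query"])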

llm_utils/chains.py

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@ def create_query_refiner_chain(llm):
         [
             SystemMessagePromptTemplate.from_template(prompt),
             MessagesPlaceholder(variable_name="user_input"),
+            SystemMessagePromptTemplate.from_template(
+                "다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:"
+            ),
+            MessagesPlaceholder(variable_name="searched_tables"),
             SystemMessagePromptTemplate.from_template(
                 """
                 위 사용자의 입력을 바탕으로
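At invoke time the new searched_tables placeholder receives a one-element list holding a JSON dump of the search results (see query_refiner_node in llm_utils/graph.py below). A rough sketch of calling the chain directly, with an assumed ChatOpenAI model and made-up table data:

import json

from langchain_openai import ChatOpenAI
from llm_utils.chains import create_query_refiner_chain

llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model choice
chain = create_query_refiner_chain(llm)

# Illustrative search result in the shape produced by search_tables().
searched_tables = {
    "orders": {"table_description": "주문 내역 테이블", "order_id": "주문 ID"},
}
res = chain.invoke(
    {
        "user_input": ["지난달 주문 건수를 알려줘"],
        "user_database_env": ["clickhouse"],
        "best_practice_query": [""],
        # LangChain coerces the plain string into a single message.
        "searched_tables": [json.dumps(searched_tables)],
    }
)
print(res.content)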

llm_utils/graph.py

Lines changed: 14 additions & 40 deletions
@@ -14,6 +14,7 @@
 )

 from llm_utils.tools import get_info_from_db
+from llm_utils.retrieval import search_tables

 # Node identifiers
 QUERY_REFINER = "query_refiner"
@@ -31,6 +32,8 @@ class QueryMakerState(TypedDict):
     best_practice_query: str
     refined_input: str
     generated_query: str
+    retriever_name: str
+    top_n: int


 # Node function: QUERY_REFINER node
@@ -40,6 +43,7 @@ def query_refiner_node(state: QueryMakerState):
             "user_input": [state["messages"][0].content],
             "user_database_env": [state["user_database_env"]],
             "best_practice_query": [state["best_practice_query"]],
+            "searched_tables": [json.dumps(state["searched_tables"])],
         }
     )
     state["messages"].append(res)
@@ -48,42 +52,12 @@ def query_refiner_node(state: QueryMakerState):


 def get_table_info_node(state: QueryMakerState):
-    from langchain_community.vectorstores import FAISS
-    from langchain_openai import OpenAIEmbeddings
-
-    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
-    try:
-        db = FAISS.load_local(
-            os.getcwd() + "/table_info_db",
-            embeddings,
-            allow_dangerous_deserialization=True,
-        )
-    except:
-        documents = get_info_from_db()
-        db = FAISS.from_documents(documents, embeddings)
-        db.save_local(os.getcwd() + "/table_info_db")
-    doc_res = db.similarity_search(state["messages"][-1].content)
-    documents_dict = {}
-
-    for doc in doc_res:
-        lines = doc.page_content.split("\n")
-
-        # Extract the table name and description
-        table_name, table_desc = lines[0].split(": ", 1)
-
-        # Extract column information
-        columns = {}
-        if len(lines) > 2 and lines[1].strip() == "Columns:":
-            for line in lines[2:]:
-                if ": " in line:
-                    col_name, col_desc = line.split(": ", 1)
-                    columns[col_name.strip()] = col_desc.strip()
-
-        # Store in the dictionary
-        documents_dict[table_name] = {
-            "table_description": table_desc.strip(),
-            **columns,  # add column info
-        }
+    # Run the search using retriever_name and top_n
+    documents_dict = search_tables(
+        query=state["messages"][0].content,
+        retriever_name=state["retriever_name"],
+        top_n=state["top_n"],
+    )
     state["searched_tables"] = documents_dict

     return state
@@ -129,19 +103,19 @@ def query_maker_node_with_db_guide(state: QueryMakerState):

 # Create and configure the StateGraph
 builder = StateGraph(QueryMakerState)
-builder.set_entry_point(QUERY_REFINER)
+builder.set_entry_point(GET_TABLE_INFO)

 # Add nodes
-builder.add_node(QUERY_REFINER, query_refiner_node)
 builder.add_node(GET_TABLE_INFO, get_table_info_node)
+builder.add_node(QUERY_REFINER, query_refiner_node)
 builder.add_node(QUERY_MAKER, query_maker_node)  # query_maker_node_with_db_guide
 # builder.add_node(
 #     QUERY_MAKER, query_maker_node_with_db_guide
 # )  # query_maker_node_with_db_guide

 # Default edges
-builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
-builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
+builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
+builder.add_edge(QUERY_REFINER, QUERY_MAKER)

 # End after the QUERY_MAKER node
 builder.add_edge(QUERY_MAKER, END)
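Moving the entry point means tables are retrieved before refinement, so the refiner prompt (chains.py above) can ground itself in the retrieved schema. A minimal sketch of running the reordered graph, assuming the module-level builder is importable and compiled here for illustration:

from langchain_core.messages import HumanMessage
from llm_utils.graph import builder

graph = builder.compile()

# Field values are illustrative; the keys mirror QueryMakerState.
state = graph.invoke(
    {
        "messages": [HumanMessage(content="고객별 평균 구매 금액은?")],
        "user_database_env": "clickhouse",  # assumed environment name
        "best_practice_query": "",
        "retriever_name": "Reranker",
        "top_n": 5,
    }
)
# New execution order: GET_TABLE_INFO -> QUERY_REFINER -> QUERY_MAKER -> END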

llm_utils/retrieval.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+import os
+from langchain_community.vectorstores import FAISS
+from langchain_openai import OpenAIEmbeddings
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from .tools import get_info_from_db
+
+
+def get_vector_db():
+    """Load the vector database, building and saving it if it does not exist."""
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
+    try:
+        db = FAISS.load_local(
+            os.getcwd() + "/table_info_db",
+            embeddings,
+            allow_dangerous_deserialization=True,
+        )
+    except:
+        documents = get_info_from_db()
+        db = FAISS.from_documents(documents, embeddings)
+        db.save_local(os.getcwd() + "/table_info_db")
+        print("table_info_db not found")
+    return db
+
+
+def load_reranker_model():
+    """Load the Korean reranker model, downloading it on first use."""
+    local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
+
+    # Use the locally saved model if present; otherwise download and save it
+    if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
+        print("🔄 ko-reranker 모델 로컬에서 로드 중...")
+    else:
+        print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "Dongjin-kr/ko-reranker",
+        )
+        tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
+        model.save_pretrained(local_model_path)
+        tokenizer.save_pretrained(local_model_path)
+
+    return HuggingFaceCrossEncoder(model_name=local_model_path)
+
+
+def get_retriever(retriever_name: str = "기본", top_n: int = 5):
+    """Create the appropriate retriever for the requested type.
+
+    Args:
+        retriever_name: Name of the retriever to use ("기본" or "Reranker")
+        top_n: Number of top results to return
+    """
+    retrievers = {
+        "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}),
+        "Reranker": lambda: ContextualCompressionRetriever(
+            base_compressor=CrossEncoderReranker(
+                model=load_reranker_model(), top_n=top_n
+            ),
+            base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}),
+        ),
+    }
+
+    if retriever_name not in retrievers:
+        print(
+            f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다."
+        )
+        retriever_name = "기본"
+
+    return retrievers[retriever_name]()
+
+
+def search_tables(query: str, retriever_name: str = "기본", top_n: int = 5):
+    """Search for table information relevant to the query."""
+    if retriever_name == "기본":
+        db = get_vector_db()
+        doc_res = db.similarity_search(query, k=top_n)
+    else:
+        retriever = get_retriever(retriever_name=retriever_name, top_n=top_n)
+        doc_res = retriever.invoke(query)
+
+    # Convert the results into a dictionary
+    documents_dict = {}
+    for doc in doc_res:
+        lines = doc.page_content.split("\n")
+
+        # Extract the table name and description
+        table_name, table_desc = lines[0].split(": ", 1)
+
+        # Extract column information
+        columns = {}
+        if len(lines) > 2 and lines[1].strip() == "Columns:":
+            for line in lines[2:]:
+                if ": " in line:
+                    col_name, col_desc = line.split(": ", 1)
+                    columns[col_name.strip()] = col_desc.strip()
+
+        # Store per table
+        documents_dict[table_name] = {
+            "table_description": table_desc.strip(),
+            **columns,  # add column info
+        }
+
+    return documents_dict
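A short usage sketch of the new module on its own; the query strings are illustrative:

from llm_utils.retrieval import search_tables

# Plain FAISS similarity search (the "기본" path).
tables = search_tables(query="일별 활성 사용자 수", top_n=5)

# Same candidates, reordered by the ko-reranker cross-encoder.
reranked = search_tables(query="일별 활성 사용자 수", retriever_name="Reranker", top_n=5)

for name, info in reranked.items():
    print(name, "-", info["table_description"])

One design note: the Reranker path fetches k=top_n candidates from FAISS and then keeps top_n of them, so the cross-encoder can only reorder those hits; fetching a wider k on the base retriever would let it promote tables the vector search ranked below the cutoff.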

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@ pre_commit==4.1.0
 setuptools
 wheel
 twine
+transformers==4.51.2

setup.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
         "langchain-google-genai>=2.1.3,<3.0.0",
         "langchain-ollama>=0.3.2,<0.4.0",
         "langchain-huggingface>=0.1.2,<0.2.0",
+        "transformers==4.51.2",
     ],
     entry_points={
         "console_scripts": [
