
Commit 5244fda

Add reranking functionality and new dependencies
- Added a reranking feature to the Streamlit app to improve search result accuracy.
- Added new dependencies to setup.py: transformers==4.51.2 and langchain-huggingface==0.1.2.
- Created a new retrieval module to handle vector database interactions and reranking logic.
1 parent 335fd9d commit 5244fda

File tree

4 files changed (+109, -70 lines)


interface/streamlit_app.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -28,6 +28,12 @@ def summarize_total_tokens(data):
     return total_tokens
 
 
+use_reranker = st.checkbox(
+    "리랭킹(Reranking) 기능 사용",
+    value=False,
+    help="검색 결과의 정확도를 높이기 위한 리랭킹 기능을 사용합니다.",
+)
+
 # Run when the button is clicked
 if st.button("쿼리 실행"):
     # Compile the graph and run the query
@@ -38,6 +44,7 @@ def summarize_total_tokens(data):
             "messages": [HumanMessage(content=user_query)],
             "user_database_env": user_database_env,
             "best_practice_query": "",
+            "use_rerank": use_reranker,
         }
     )
     total_tokens = summarize_total_tokens(res["messages"])
```
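For context, a minimal sketch of how the checkbox value travels into the graph call (the `res = graph.invoke(` opening and the definitions of `graph`, `user_query`, and `user_database_env` sit outside these hunks, so their exact shape here is an assumption):

```python
use_reranker = st.checkbox("리랭킹(Reranking) 기능 사용", value=False)

if st.button("쿼리 실행"):
    res = graph.invoke(  # call-site shape assumed; only the dict body appears in the diff
        {
            "messages": [HumanMessage(content=user_query)],
            "user_database_env": user_database_env,
            "best_practice_query": "",
            "use_rerank": use_reranker,  # new: forwards the checkbox state
        }
    )
```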

llm_utils/graph.py

Lines changed: 6 additions & 70 deletions
```diff
@@ -14,6 +14,7 @@
 )
 
 from llm_utils.tools import get_info_from_db
+from llm_utils.retrieval import search_tables
 
 # Define node identifiers
 QUERY_REFINER = "query_refiner"
@@ -31,6 +32,7 @@ class QueryMakerState(TypedDict):
     best_practice_query: str
     refined_input: str
     generated_query: str
+    use_rerank: bool
 
 
 # Node function: QUERY_REFINER node
@@ -49,76 +51,10 @@ def query_refiner_node(state: QueryMakerState):
 
 
 def get_table_info_node(state: QueryMakerState):
-    from langchain_community.vectorstores import FAISS
-    from langchain_openai import OpenAIEmbeddings
-
-    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
-    try:
-        db = FAISS.load_local(
-            os.getcwd() + "/table_info_db",
-            embeddings,
-            allow_dangerous_deserialization=True,
-        )
-    except:
-        documents = get_info_from_db()
-        db = FAISS.from_documents(documents, embeddings)
-        db.save_local(os.getcwd() + "/table_info_db")
-        print("table_info_db not found")
-
-    retriever = db.as_retriever(search_kwargs={"k": 10})
-
-    from langchain.retrievers import ContextualCompressionRetriever
-    from langchain.retrievers.document_compressors import CrossEncoderReranker
-    from langchain_community.cross_encoders import HuggingFaceCrossEncoder
-    from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-    # Whether to apply reranking
-    use_rerank = True  # set to True or False as needed
-
-    if use_rerank:
-        local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
-
-        # Load the locally saved model if present; otherwise download and save it
-        if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
-            print("🔄 ko-reranker 모델 로컬에서 로드 중...")
-        else:
-            print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
-            model = AutoModelForSequenceClassification.from_pretrained(
-                "Dongjin-kr/ko-reranker"
-            )
-            tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
-            model.save_pretrained(local_model_path)
-            tokenizer.save_pretrained(local_model_path)
-        model = HuggingFaceCrossEncoder(model_name=local_model_path)
-        compressor = CrossEncoderReranker(model=model, top_n=3)
-        retriever = db.as_retriever(search_kwargs={"k": 10})
-        compression_retriever = ContextualCompressionRetriever(
-            base_compressor=compressor, base_retriever=retriever
-        )
-
-        doc_res = compression_retriever.invoke(state["messages"][0].content)
-    else:  # no reranking applied
-        doc_res = db.similarity_search(state["messages"][0].content, k=10)
-    documents_dict = {}
-    for doc in doc_res:
-        lines = doc.page_content.split("\n")
-
-        # Extract the table name and description
-        table_name, table_desc = lines[0].split(": ", 1)
-
-        # Extract column information
-        columns = {}
-        if len(lines) > 2 and lines[1].strip() == "Columns:":
-            for line in lines[2:]:
-                if ": " in line:
-                    col_name, col_desc = line.split(": ", 1)
-                    columns[col_name.strip()] = col_desc.strip()
-
-        # Store in the dictionary
-        documents_dict[table_name] = {
-            "table_description": table_desc.strip(),
-            **columns,  # add column info
-        }
+    # Run the search using the use_rerank value from the state
+    documents_dict = search_tables(
+        state["messages"][0].content, use_rerank=state["use_rerank"]
+    )
     state["searched_tables"] = documents_dict
 
     return state
```
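With the inline FAISS and reranker logic moved into `llm_utils.retrieval`, the node reduces to a thin wrapper around `search_tables`. A hedged smoke-test sketch of the new contract (assumes `HumanMessage` comes from `langchain_core.messages`, and that `OPENAI_API_KEY` is set so the vector DB can be loaded or built):

```python
from langchain_core.messages import HumanMessage

from llm_utils.graph import get_table_info_node

# Only the keys the node actually reads are filled in here.
state = {
    "messages": [HumanMessage(content="주문 테이블 정보를 찾아줘")],
    "use_rerank": False,  # plain FAISS top-10, no cross-encoder pass
}
state = get_table_info_node(state)
print(list(state["searched_tables"].keys()))  # table names matched to the query
```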

llm_utils/retrieval.py

Lines changed: 94 additions & 0 deletions
```diff
@@ -0,0 +1,94 @@
+import os
+from langchain_community.vectorstores import FAISS
+from langchain_openai import OpenAIEmbeddings
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from .tools import get_info_from_db
+
+
+def get_vector_db():
+    """Load the vector database, creating it if it does not exist."""
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
+    try:
+        db = FAISS.load_local(
+            os.getcwd() + "/table_info_db",
+            embeddings,
+            allow_dangerous_deserialization=True,
+        )
+    except:
+        documents = get_info_from_db()
+        db = FAISS.from_documents(documents, embeddings)
+        db.save_local(os.getcwd() + "/table_info_db")
+        print("table_info_db not found")
+    return db
+
+
+def load_reranker_model():
+    """Load the Korean reranker model, downloading it if necessary."""
+    local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
+
+    # Load the locally saved model if present; otherwise download and save it
+    if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
+        print("🔄 ko-reranker 모델 로컬에서 로드 중...")
+    else:
+        print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "Dongjin-kr/ko-reranker"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
+        model.save_pretrained(local_model_path)
+        tokenizer.save_pretrained(local_model_path)
+
+    return HuggingFaceCrossEncoder(model_name=local_model_path)
+
+
+def get_retriever(use_rerank=False):
+    """Create a retriever; apply reranking when use_rerank is True."""
+    db = get_vector_db()
+    retriever = db.as_retriever(search_kwargs={"k": 10})
+
+    if use_rerank:
+        model = load_reranker_model()
+        compressor = CrossEncoderReranker(model=model, top_n=3)
+        return ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=retriever
+        )
+    else:
+        return retriever
+
+
+def search_tables(query, use_rerank=False):
+    """Search for table information matching the query."""
+    if use_rerank:
+        retriever = get_retriever(use_rerank=True)
+        doc_res = retriever.invoke(query)
+    else:
+        db = get_vector_db()
+        doc_res = db.similarity_search(query, k=10)
+
+    # Convert the results into a dictionary
+    documents_dict = {}
+    for doc in doc_res:
+        lines = doc.page_content.split("\n")
+
+        # Extract the table name and description
+        table_name, table_desc = lines[0].split(": ", 1)
+
+        # Extract column information
+        columns = {}
+        if len(lines) > 2 and lines[1].strip() == "Columns:":
+            for line in lines[2:]:
+                if ": " in line:
+                    col_name, col_desc = line.split(": ", 1)
+                    columns[col_name.strip()] = col_desc.strip()
+
+        # Store in the dictionary
+        documents_dict[table_name] = {
+            "table_description": table_desc.strip(),
+            **columns,  # add column info
+        }
+
+    return documents_dict
```
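Usage is identical with or without reranking: `use_rerank=True` pulls the FAISS top 10 and lets the cross-encoder keep the best 3 (`top_n=3`), while `use_rerank=False` returns the raw top 10. A minimal sketch, assuming OpenAI credentials and a buildable `table_info_db` as above:

```python
from llm_utils.retrieval import search_tables

tables = search_tables("월별 주문 수 집계", use_rerank=True)
for table_name, info in tables.items():
    # each value holds "table_description" plus one entry per column
    print(table_name, "-", info["table_description"])
```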

setup.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -24,6 +24,8 @@
         "streamlit==1.41.1",
         "python-dotenv==1.0.1",
         "faiss-cpu==1.10.0",
+        "transformers==4.51.2",
+        "langchain-huggingface==0.1.2",
     ],
     entry_points={
         "console_scripts": [
```
