1010
1111from llm_utils .chains import (
1212 query_refiner_chain ,
13- query_redefined_again_chain ,
1413 query_maker_chain ,
1514)
1615
1716from llm_utils .tools import get_info_from_db
1817
1918# 노드 식별자 정의
2019QUERY_REFINER = "query_refiner"
21- QUERY_REFINED_AGAIN = "query_redefined_again"
2220GET_TABLE_INFO = "get_table_info"
2321TOOL = "tool"
2422TABLE_FILTER = "table_filter"
@@ -32,7 +30,6 @@ class QueryMakerState(TypedDict):
3230 searched_tables : dict [str , dict [str , str ]]
3331 best_practice_query : str
3432 refined_input : str
35- refined_input_again : str
3633 generated_query : str
3734
3835
@@ -43,6 +40,7 @@ def query_refiner_node(state: QueryMakerState):
4340 "user_input" : [state ["messages" ][0 ].content ],
4441 "user_database_env" : [state ["user_database_env" ]],
4542 "best_practice_query" : [state ["best_practice_query" ]],
43+ "searched_tables" : [json .dumps (state ["searched_tables" ])],
4644 }
4745 )
4846 state ["messages" ].append (res )
@@ -66,9 +64,42 @@ def get_table_info_node(state: QueryMakerState):
6664 db = FAISS .from_documents (documents , embeddings )
6765 db .save_local (os .getcwd () + "/table_info_db" )
6866 print ("table_info_db not found" )
69- doc_res = db .similarity_search (state ["messages" ][- 1 ].content )
70- documents_dict = {}
7167
68+ retriever = db .as_retriever (search_kwargs = {"k" : 10 })
69+
70+ from langchain .retrievers import ContextualCompressionRetriever
71+ from langchain .retrievers .document_compressors import CrossEncoderReranker
72+ from langchain_community .cross_encoders import HuggingFaceCrossEncoder
73+ from transformers import AutoModelForSequenceClassification , AutoTokenizer
74+
75+ # Reranking 적용 여부 설정
76+ use_rerank = True # 필요에 따라 True 또는 False로 설정
77+
78+ if use_rerank :
79+ local_model_path = os .path .join (os .getcwd (), "ko_reranker_local" )
80+
81+ # 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장
82+ if os .path .exists (local_model_path ) and os .path .isdir (local_model_path ):
83+ print ("🔄 ko-reranker 모델 로컬에서 로드 중..." )
84+ else :
85+ print ("⬇️ ko-reranker 모델 다운로드 및 저장 중..." )
86+ model = AutoModelForSequenceClassification .from_pretrained (
87+ "Dongjin-kr/ko-reranker"
88+ )
89+ tokenizer = AutoTokenizer .from_pretrained ("Dongjin-kr/ko-reranker" )
90+ model .save_pretrained (local_model_path )
91+ tokenizer .save_pretrained (local_model_path )
92+ model = HuggingFaceCrossEncoder (model_name = local_model_path )
93+ compressor = CrossEncoderReranker (model = model , top_n = 3 )
94+ retriever = db .as_retriever (search_kwargs = {"k" : 10 })
95+ compression_retriever = ContextualCompressionRetriever (
96+ base_compressor = compressor , base_retriever = retriever
97+ )
98+
99+ doc_res = compression_retriever .invoke (state ["messages" ][0 ].content )
100+ else : # Reranking 미적용
101+ doc_res = db .similarity_search (state ["messages" ][0 ].content , k = 10 )
102+ documents_dict = {}
72103 for doc in doc_res :
73104 lines = doc .page_content .split ("\n " )
74105
@@ -93,19 +124,6 @@ def get_table_info_node(state: QueryMakerState):
93124 return state
94125
95126
96- def query_redefined_again_node (state : QueryMakerState ):
97- res = query_redefined_again_chain .invoke (
98- input = {
99- "user_input" : [state ["messages" ][0 ].content ],
100- "refined_input" : [state ["refined_input" ]],
101- "user_database_env" : [state ["user_database_env" ]],
102- "searched_tables" : [json .dumps (state ["searched_tables" ])],
103- }
104- )
105- state ["refined_input_again" ] = res
106- return state
107-
108-
109127# 노드 함수: QUERY_MAKER 노드
110128def query_maker_node (state : QueryMakerState ):
111129 res = query_maker_chain .invoke (
@@ -137,9 +155,7 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
137155 res = chain .invoke (
138156 input = {
139157 "input" : "\n \n ---\n \n " .join (
140- [state ["messages" ][0 ].content ]
141- # + [state["refined_input"].content]
142- + [state ["refined_input_again" ].content ]
158+ [state ["messages" ][0 ].content ] + [state ["refined_input" ].content ]
143159 ),
144160 "table_info" : [json .dumps (state ["searched_tables" ])],
145161 "top_k" : 10 ,
@@ -152,21 +168,16 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
152168
153169# StateGraph 생성 및 구성
154170builder = StateGraph (QueryMakerState )
155- builder .set_entry_point (QUERY_REFINER )
171+ builder .set_entry_point (GET_TABLE_INFO )
156172
157173# 노드 추가
158- builder .add_node (QUERY_REFINER , query_refiner_node )
159174builder .add_node (GET_TABLE_INFO , get_table_info_node )
160- # builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
161- builder .add_node (
162- QUERY_MAKER , query_maker_node_with_db_guide
163- ) # query_maker_node_with_db_guide
164- builder .add_node (QUERY_REFINED_AGAIN , query_redefined_again_node )
175+ builder .add_node (QUERY_REFINER , query_refiner_node )
176+ builder .add_node (QUERY_MAKER , query_maker_node_with_db_guide )
165177
166178# 기본 엣지 설정
167- builder .add_edge (QUERY_REFINER , GET_TABLE_INFO )
168- builder .add_edge (GET_TABLE_INFO , QUERY_REFINED_AGAIN )
169- builder .add_edge (QUERY_REFINED_AGAIN , QUERY_MAKER )
179+ builder .add_edge (GET_TABLE_INFO , QUERY_REFINER )
180+ builder .add_edge (QUERY_REFINER , QUERY_MAKER )
170181
171182# QUERY_MAKER 노드 후 종료
172183builder .add_edge (QUERY_MAKER , END )
0 commit comments