 )
 
 from llm_utils.tools import get_info_from_db
+from llm_utils.retrieval import search_tables
 
 # Define node identifiers
 QUERY_REFINER = "query_refiner"
@@ -31,6 +32,7 @@ class QueryMakerState(TypedDict):
     best_practice_query: str
     refined_input: str
     generated_query: str
+    use_rerank: bool
 
 
 # Node function: QUERY_REFINER node
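Since `use_rerank` now lives in the graph state instead of being hard-coded inside the node, callers choose at invocation time whether reranking runs. A minimal sketch of the initial state a caller might build, assuming LangChain-style messages (the compiled graph object itself is not part of this diff):

```python
from langchain_core.messages import HumanMessage

# Hypothetical initial state; field names follow QueryMakerState above.
initial_state = {
    "messages": [HumanMessage(content="Which table stores monthly revenue?")],
    "use_rerank": True,  # False falls back to plain similarity search
}
# The compiled graph would then be invoked with this dict, e.g. graph.invoke(initial_state).
```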
@@ -49,76 +51,10 @@ def query_refiner_node(state: QueryMakerState):
 
 
 def get_table_info_node(state: QueryMakerState):
-    from langchain_community.vectorstores import FAISS
-    from langchain_openai import OpenAIEmbeddings
-
-    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
-    try:
-        db = FAISS.load_local(
-            os.getcwd() + "/table_info_db",
-            embeddings,
-            allow_dangerous_deserialization=True,
-        )
-    except:
-        documents = get_info_from_db()
-        db = FAISS.from_documents(documents, embeddings)
-        db.save_local(os.getcwd() + "/table_info_db")
-        print("table_info_db not found")
-
-    retriever = db.as_retriever(search_kwargs={"k": 10})
-
-    from langchain.retrievers import ContextualCompressionRetriever
-    from langchain.retrievers.document_compressors import CrossEncoderReranker
-    from langchain_community.cross_encoders import HuggingFaceCrossEncoder
-    from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-    # Whether to apply reranking
-    use_rerank = True  # set to True or False as needed
-
-    if use_rerank:
-        local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
-
-        # Load the locally saved model if it exists; otherwise download and save it
-        if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
-            print("🔄 Loading ko-reranker model from local storage...")
-        else:
-            print("⬇️ Downloading and saving ko-reranker model...")
-            model = AutoModelForSequenceClassification.from_pretrained(
-                "Dongjin-kr/ko-reranker"
-            )
-            tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
-            model.save_pretrained(local_model_path)
-            tokenizer.save_pretrained(local_model_path)
-        model = HuggingFaceCrossEncoder(model_name=local_model_path)
-        compressor = CrossEncoderReranker(model=model, top_n=3)
-        retriever = db.as_retriever(search_kwargs={"k": 10})
-        compression_retriever = ContextualCompressionRetriever(
-            base_compressor=compressor, base_retriever=retriever
-        )
-
-        doc_res = compression_retriever.invoke(state["messages"][0].content)
-    else:  # without reranking
-        doc_res = db.similarity_search(state["messages"][0].content, k=10)
-    documents_dict = {}
-    for doc in doc_res:
-        lines = doc.page_content.split("\n")
-
-        # Extract the table name and description
-        table_name, table_desc = lines[0].split(": ", 1)
-
-        # Extract column information
-        columns = {}
-        if len(lines) > 2 and lines[1].strip() == "Columns:":
-            for line in lines[2:]:
-                if ": " in line:
-                    col_name, col_desc = line.split(": ", 1)
-                    columns[col_name.strip()] = col_desc.strip()
-
-        # Store in the dictionary
-        documents_dict[table_name] = {
-            "table_description": table_desc.strip(),
-            **columns,  # add the column info
-        }
+    # Run the table search using the use_rerank value from the state
+    documents_dict = search_tables(
+        state["messages"][0].content, use_rerank=state["use_rerank"]
+    )
     state["searched_tables"] = documents_dict
 
     return state
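The deleted retrieval logic now presumably lives behind `search_tables` in `llm_utils/retrieval.py`, which this diff does not show. Below is a minimal sketch reconstructed only from the code removed above (FAISS index under `table_info_db`, `text-embedding-3-small` embeddings, the `Dongjin-kr/ko-reranker` cross-encoder, k=10 retrieval, top 3 after reranking). The helper `_load_table_info_db`, the default value of `use_rerank`, and the omission of the old local model-caching step are assumptions, not the project's actual implementation.

```python
# Hypothetical sketch of llm_utils/retrieval.py — the real module may differ.
import os

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

from llm_utils.tools import get_info_from_db


def _load_table_info_db(embeddings):
    """Load the FAISS index of table descriptions, rebuilding it from the DB if missing."""
    index_path = os.path.join(os.getcwd(), "table_info_db")
    try:
        return FAISS.load_local(
            index_path, embeddings, allow_dangerous_deserialization=True
        )
    except Exception:
        documents = get_info_from_db()
        db = FAISS.from_documents(documents, embeddings)
        db.save_local(index_path)
        return db


def search_tables(query: str, use_rerank: bool = False) -> dict:
    """Return {table_name: {"table_description": ..., <column>: <description>}}."""
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    db = _load_table_info_db(embeddings)

    if use_rerank:
        # Rerank the top-10 FAISS hits with the cross-encoder and keep the best 3.
        model = HuggingFaceCrossEncoder(model_name="Dongjin-kr/ko-reranker")
        compressor = CrossEncoderReranker(model=model, top_n=3)
        retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=db.as_retriever(search_kwargs={"k": 10}),
        )
        doc_res = retriever.invoke(query)
    else:
        doc_res = db.similarity_search(query, k=10)

    # Each document reads "table_name: description" followed by a "Columns:" section.
    documents_dict = {}
    for doc in doc_res:
        lines = doc.page_content.split("\n")
        table_name, table_desc = lines[0].split(": ", 1)
        columns = {}
        if len(lines) > 2 and lines[1].strip() == "Columns:":
            for line in lines[2:]:
                if ": " in line:
                    col_name, col_desc = line.split(": ", 1)
                    columns[col_name.strip()] = col_desc.strip()
        documents_dict[table_name] = {
            "table_description": table_desc.strip(),
            **columns,
        }
    return documents_dict
```

Moving these details behind `search_tables` keeps `get_table_info_node` a thin wrapper over the graph state, and the same helper can be reused outside the LangGraph workflow.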