1414)
1515
1616from llm_utils .tools import get_info_from_db
17+ from llm_utils .retrieval import search_tables
1718
1819# 노드 식별자 정의
1920QUERY_REFINER = "query_refiner"
@@ -31,6 +32,8 @@ class QueryMakerState(TypedDict):
3132 best_practice_query : str
3233 refined_input : str
3334 generated_query : str
35+ retriever_name : str
36+ top_n : int
3437
3538
3639# 노드 함수: QUERY_REFINER 노드
@@ -40,6 +43,7 @@ def query_refiner_node(state: QueryMakerState):
4043 "user_input" : [state ["messages" ][0 ].content ],
4144 "user_database_env" : [state ["user_database_env" ]],
4245 "best_practice_query" : [state ["best_practice_query" ]],
46+ "searched_tables" : [json .dumps (state ["searched_tables" ])],
4347 }
4448 )
4549 state ["messages" ].append (res )
@@ -48,42 +52,12 @@ def query_refiner_node(state: QueryMakerState):
4852
4953
5054def get_table_info_node (state : QueryMakerState ):
51- from langchain_community .vectorstores import FAISS
52- from langchain_openai import OpenAIEmbeddings
53-
54- embeddings = OpenAIEmbeddings (model = "text-embedding-3-small" )
55- try :
56- db = FAISS .load_local (
57- os .getcwd () + "/table_info_db" ,
58- embeddings ,
59- allow_dangerous_deserialization = True ,
60- )
61- except :
62- documents = get_info_from_db ()
63- db = FAISS .from_documents (documents , embeddings )
64- db .save_local (os .getcwd () + "/table_info_db" )
65- doc_res = db .similarity_search (state ["messages" ][- 1 ].content )
66- documents_dict = {}
67-
68- for doc in doc_res :
69- lines = doc .page_content .split ("\n " )
70-
71- # 테이블명 및 설명 추출
72- table_name , table_desc = lines [0 ].split (": " , 1 )
73-
74- # 컬럼 정보 추출
75- columns = {}
76- if len (lines ) > 2 and lines [1 ].strip () == "Columns:" :
77- for line in lines [2 :]:
78- if ": " in line :
79- col_name , col_desc = line .split (": " , 1 )
80- columns [col_name .strip ()] = col_desc .strip ()
81-
82- # 딕셔너리 저장
83- documents_dict [table_name ] = {
84- "table_description" : table_desc .strip (),
85- ** columns , # 컬럼 정보 추가
86- }
55+ # retriever_name과 top_n을 이용하여 검색 수행
56+ documents_dict = search_tables (
57+ query = state ["messages" ][0 ].content ,
58+ retriever_name = state ["retriever_name" ],
59+ top_n = state ["top_n" ],
60+ )
8761 state ["searched_tables" ] = documents_dict
8862
8963 return state
@@ -129,19 +103,19 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
129103
130104# StateGraph 생성 및 구성
131105builder = StateGraph (QueryMakerState )
132- builder .set_entry_point (QUERY_REFINER )
106+ builder .set_entry_point (GET_TABLE_INFO )
133107
134108# 노드 추가
135- builder .add_node (QUERY_REFINER , query_refiner_node )
136109builder .add_node (GET_TABLE_INFO , get_table_info_node )
110+ builder .add_node (QUERY_REFINER , query_refiner_node )
137111builder .add_node (QUERY_MAKER , query_maker_node ) # query_maker_node_with_db_guide
138112# builder.add_node(
139113# QUERY_MAKER, query_maker_node_with_db_guide
140114# ) # query_maker_node_with_db_guide
141115
142116# 기본 엣지 설정
143- builder .add_edge (QUERY_REFINER , GET_TABLE_INFO )
144- builder .add_edge (GET_TABLE_INFO , QUERY_MAKER )
117+ builder .add_edge (GET_TABLE_INFO , QUERY_REFINER )
118+ builder .add_edge (QUERY_REFINER , QUERY_MAKER )
145119
146120# QUERY_MAKER 노드 후 종료
147121builder .add_edge (QUERY_MAKER , END )
0 commit comments