diff --git a/interface/lang2sql.py b/interface/lang2sql.py index b7a5905..8704d95 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -46,6 +46,9 @@ def execute_query( *, query: str, database_env: str, + retriever_name: str = "기본", + top_n: int = 5, + device: str = "cpu", ) -> dict: """ Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다. @@ -53,6 +56,8 @@ def execute_query( Args: query (str): 자연어로 작성된 사용자 쿼리. database_env (str): 사용할 데이터베이스 환경 설정 이름. + retriever_name (str): 사용할 검색기 이름. + top_n (int): 검색할 테이블 정보의 개수. Returns: dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리. @@ -64,6 +69,9 @@ def execute_query( "messages": [HumanMessage(content=query)], "user_database_env": database_env, "best_practice_query": "", + "retriever_name": retriever_name, + "top_n": top_n, + "device": device, } ) @@ -123,6 +131,33 @@ def display_result( index=0, ) +device = st.selectbox( + "모델 실행 장치를 선택하세요:", + options=["cpu", "cuda"], + index=0, +) + +retriever_options = { + "기본": "벡터 검색 (기본)", + "Reranker": "Reranker 검색 (정확도 향상)", +} + +user_retriever = st.selectbox( + "검색기 유형을 선택하세요:", + options=list(retriever_options.keys()), + format_func=lambda x: retriever_options[x], + index=0, +) + +user_top_n = st.slider( + "검색할 테이블 정보 개수:", + min_value=1, + max_value=20, + value=5, + step=1, + help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.", +) + st.sidebar.title("Output Settings") for key, label in SIDEBAR_OPTIONS.items(): st.sidebar.checkbox(label, value=True, key=key) @@ -131,5 +166,8 @@ def display_result( result = execute_query( query=user_query, database_env=user_database_env, + retriever_name=user_retriever, + top_n=user_top_n, + device=device, ) display_result(res=result, database=db) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index 81d957e..a0a5f27 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -26,6 +26,10 @@ def create_query_refiner_chain(llm): [ SystemMessagePromptTemplate.from_template(prompt), MessagesPlaceholder(variable_name="user_input"), + SystemMessagePromptTemplate.from_template( + "다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:" + ), + MessagesPlaceholder(variable_name="searched_tables"), SystemMessagePromptTemplate.from_template( """ 위 사용자의 입력을 바탕으로 diff --git a/llm_utils/graph.py b/llm_utils/graph.py index a6f5137..69a10b9 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -14,6 +14,7 @@ ) from llm_utils.tools import get_info_from_db +from llm_utils.retrieval import search_tables # 노드 식별자 정의 QUERY_REFINER = "query_refiner" @@ -31,6 +32,9 @@ class QueryMakerState(TypedDict): best_practice_query: str refined_input: str generated_query: str + retriever_name: str + top_n: int + device: str # 노드 함수: QUERY_REFINER 노드 @@ -40,6 +44,7 @@ def query_refiner_node(state: QueryMakerState): "user_input": [state["messages"][0].content], "user_database_env": [state["user_database_env"]], "best_practice_query": [state["best_practice_query"]], + "searched_tables": [json.dumps(state["searched_tables"])], } ) state["messages"].append(res) @@ -48,42 +53,13 @@ def query_refiner_node(state: QueryMakerState): def get_table_info_node(state: QueryMakerState): - from langchain_community.vectorstores import FAISS - from langchain_openai import OpenAIEmbeddings - - embeddings = OpenAIEmbeddings(model="text-embedding-3-small") - try: - db = FAISS.load_local( - os.getcwd() + "/table_info_db", - embeddings, - allow_dangerous_deserialization=True, - ) - except: - documents = get_info_from_db() - db = FAISS.from_documents(documents, embeddings) - db.save_local(os.getcwd() + "/table_info_db") - doc_res = db.similarity_search(state["messages"][-1].content) - documents_dict = {} - - for doc in doc_res: - lines = doc.page_content.split("\n") - - # 테이블명 및 설명 추출 - table_name, table_desc = lines[0].split(": ", 1) - - # 컬럼 정보 추출 - columns = {} - if len(lines) > 2 and lines[1].strip() == "Columns:": - for line in lines[2:]: - if ": " in line: - col_name, col_desc = line.split(": ", 1) - columns[col_name.strip()] = col_desc.strip() - - # 딕셔너리 저장 - documents_dict[table_name] = { - "table_description": table_desc.strip(), - **columns, # 컬럼 정보 추가 - } + # retriever_name과 top_n을 이용하여 검색 수행 + documents_dict = search_tables( + query=state["messages"][0].content, + retriever_name=state["retriever_name"], + top_n=state["top_n"], + device=state["device"], + ) state["searched_tables"] = documents_dict return state @@ -129,19 +105,19 @@ def query_maker_node_with_db_guide(state: QueryMakerState): # StateGraph 생성 및 구성 builder = StateGraph(QueryMakerState) -builder.set_entry_point(QUERY_REFINER) +builder.set_entry_point(GET_TABLE_INFO) # 노드 추가 -builder.add_node(QUERY_REFINER, query_refiner_node) builder.add_node(GET_TABLE_INFO, get_table_info_node) +builder.add_node(QUERY_REFINER, query_refiner_node) builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide # builder.add_node( # QUERY_MAKER, query_maker_node_with_db_guide # ) # query_maker_node_with_db_guide # 기본 엣지 설정 -builder.add_edge(QUERY_REFINER, GET_TABLE_INFO) -builder.add_edge(GET_TABLE_INFO, QUERY_MAKER) +builder.add_edge(GET_TABLE_INFO, QUERY_REFINER) +builder.add_edge(QUERY_REFINER, QUERY_MAKER) # QUERY_MAKER 노드 후 종료 builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/retrieval.py b/llm_utils/retrieval.py new file mode 100644 index 0000000..728141f --- /dev/null +++ b/llm_utils/retrieval.py @@ -0,0 +1,113 @@ +import os +from langchain_community.vectorstores import FAISS +from langchain_openai import OpenAIEmbeddings +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import CrossEncoderReranker +from langchain_community.cross_encoders import HuggingFaceCrossEncoder +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from .tools import get_info_from_db + + +def get_vector_db(): + """벡터 데이터베이스를 로드하거나 생성합니다.""" + embeddings = OpenAIEmbeddings(model="text-embedding-3-small") + try: + db = FAISS.load_local( + os.getcwd() + "/table_info_db", + embeddings, + allow_dangerous_deserialization=True, + ) + except: + documents = get_info_from_db() + db = FAISS.from_documents(documents, embeddings) + db.save_local(os.getcwd() + "/table_info_db") + print("table_info_db not found") + return db + + +def load_reranker_model(device: str = "cpu"): + """한국어 reranker 모델을 로드하거나 다운로드합니다.""" + local_model_path = os.path.join(os.getcwd(), "ko_reranker_local") + + # 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장 + if os.path.exists(local_model_path) and os.path.isdir(local_model_path): + print("🔄 ko-reranker 모델 로컬에서 로드 중...") + else: + print("⬇️ ko-reranker 모델 다운로드 및 저장 중...") + model = AutoModelForSequenceClassification.from_pretrained( + "Dongjin-kr/ko-reranker" + ) + tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker") + model.save_pretrained(local_model_path) + tokenizer.save_pretrained(local_model_path) + + return HuggingFaceCrossEncoder( + model_name=local_model_path, + model_kwargs={"device": device}, + ) + + +def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"): + """검색기 타입에 따라 적절한 검색기를 생성합니다. + + Args: + retriever_name: 사용할 검색기 이름 ("기본", "재순위", 등) + top_n: 반환할 상위 결과 개수 + """ + print(device) + retrievers = { + "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}), + "Reranker": lambda: ContextualCompressionRetriever( + base_compressor=CrossEncoderReranker( + model=load_reranker_model(device), top_n=top_n + ), + base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}), + ), + } + + if retriever_name not in retrievers: + print( + f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다." + ) + retriever_name = "기본" + + return retrievers[retriever_name]() + + +def search_tables( + query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu" +): + """쿼리에 맞는 테이블 정보를 검색합니다.""" + if retriever_name == "기본": + db = get_vector_db() + doc_res = db.similarity_search(query, k=top_n) + else: + retriever = get_retriever( + retriever_name=retriever_name, top_n=top_n, device=device + ) + doc_res = retriever.invoke(query) + + # 결과를 사전 형태로 변환 + documents_dict = {} + for doc in doc_res: + lines = doc.page_content.split("\n") + + # 테이블명 및 설명 추출 + table_name, table_desc = lines[0].split(": ", 1) + + # 컬럼 정보 추출 + columns = {} + if len(lines) > 2 and lines[1].strip() == "Columns:": + for line in lines[2:]: + if ": " in line: + col_name, col_desc = line.split(": ", 1) + columns[col_name.strip()] = col_desc.strip() + + # 딕셔너리 저장 + documents_dict[table_name] = { + "table_description": table_desc.strip(), + **columns, # 컬럼 정보 추가 + } + + return documents_dict diff --git a/requirements.txt b/requirements.txt index b7f912a..0f1ecb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ pre_commit==4.1.0 setuptools wheel twine +transformers==4.51.2 langchain-aws>=0.2.21,<0.3.0 langchain-google-genai>=2.1.3,<3.0.0 langchain-ollama>=0.3.2,<0.4.0 diff --git a/setup.py b/setup.py index 71a31ac..78f0612 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "langchain-google-genai>=2.1.3,<3.0.0", "langchain-ollama>=0.3.2,<0.4.0", "langchain-huggingface>=0.1.2,<0.2.0", + "transformers==4.51.2", ], entry_points={ "console_scripts": [