diff --git a/interface/streamlit_app.py b/interface/streamlit_app.py index 395b2b7..7846eed 100644 --- a/interface/streamlit_app.py +++ b/interface/streamlit_app.py @@ -28,6 +28,12 @@ def summarize_total_tokens(data): return total_tokens +use_reranker = st.checkbox( + "리랭킹(Reranking) 기능 사용", + value=False, + help="검색 결과의 정확도를 높이기 위한 리랭킹 기능을 사용합니다.", +) + # 버튼 클릭 시 실행 if st.button("쿼리 실행"): # 그래프 컴파일 및 쿼리 실행 @@ -38,6 +44,7 @@ def summarize_total_tokens(data): "messages": [HumanMessage(content=user_query)], "user_database_env": user_database_env, "best_practice_query": "", + "use_rerank": use_reranker, } ) total_tokens = summarize_total_tokens(res["messages"]) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index d9e5e6c..9806721 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -40,23 +40,28 @@ def create_query_refiner_chain(llm): 예시: 사용자가 "유저 이탈 원인이 궁금해요"라고 했다면, 재질문 형식이 아니라 - "최근 1개월 간의 접속·결제 로그를 기준으로, + "접속·결제 로그를 기준으로, 주로 어떤 사용자가 어떤 과정을 거쳐 이탈하는지를 분석해야 한다"처럼 분석 방향이 명확해진 질문 한 문장(또는 한 문단)으로 정리해 주세요. 최종 출력 형식 예시: ------------------------------ 구체화된 질문: - "최근 1개월 동안 고액 결제 경험이 있는 유저가 + "고액 결제 경험이 있는 유저가 행동 로그에서 이탈 전 어떤 패턴을 보였는지 분석" 가정한 조건: - - 최근 1개월치 행동 로그와 결제 로그 중심 + - 행동 로그와 결제 로그 중심 - 고액 결제자(월 결제액 10만 원 이상) 그룹 대상으로 한정 ------------------------------ """, ), MessagesPlaceholder(variable_name="user_input"), + ( + "system", + "다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:", + ), + MessagesPlaceholder(variable_name="searched_tables"), ( "system", """ diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 0aef51d..112b25f 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -14,6 +14,7 @@ ) from llm_utils.tools import get_info_from_db +from llm_utils.retrieval import search_tables # 노드 식별자 정의 QUERY_REFINER = "query_refiner" @@ -31,6 +32,7 @@ class QueryMakerState(TypedDict): best_practice_query: str refined_input: str generated_query: str + use_rerank: bool # 노드 함수: QUERY_REFINER 노드 @@ -40,6 +42,7 @@ def query_refiner_node(state: QueryMakerState): "user_input": [state["messages"][0].content], "user_database_env": [state["user_database_env"]], "best_practice_query": [state["best_practice_query"]], + "searched_tables": [json.dumps(state["searched_tables"])], } ) state["messages"].append(res) @@ -48,43 +51,10 @@ def query_refiner_node(state: QueryMakerState): def get_table_info_node(state: QueryMakerState): - from langchain_community.vectorstores import FAISS - from langchain_openai import OpenAIEmbeddings - - embeddings = OpenAIEmbeddings(model="text-embedding-3-small") - try: - db = FAISS.load_local( - os.getcwd() + "/table_info_db", - embeddings, - allow_dangerous_deserialization=True, - ) - except: - documents = get_info_from_db() - db = FAISS.from_documents(documents, embeddings) - db.save_local(os.getcwd() + "/table_info_db") - print("table_info_db not found") - doc_res = db.similarity_search(state["messages"][-1].content) - documents_dict = {} - - for doc in doc_res: - lines = doc.page_content.split("\n") - - # 테이블명 및 설명 추출 - table_name, table_desc = lines[0].split(": ", 1) - - # 컬럼 정보 추출 - columns = {} - if len(lines) > 2 and lines[1].strip() == "Columns:": - for line in lines[2:]: - if ": " in line: - col_name, col_desc = line.split(": ", 1) - columns[col_name.strip()] = col_desc.strip() - - # 딕셔너리 저장 - documents_dict[table_name] = { - "table_description": table_desc.strip(), - **columns, # 컬럼 정보 추가 - } + # state의 use_rerank 값을 이용하여 검색 수행 + documents_dict = search_tables( + state["messages"][0].content, use_rerank=state["use_rerank"] + ) state["searched_tables"] = documents_dict return state @@ -134,19 +104,16 @@ def query_maker_node_with_db_guide(state: QueryMakerState): # StateGraph 생성 및 구성 builder = StateGraph(QueryMakerState) -builder.set_entry_point(QUERY_REFINER) +builder.set_entry_point(GET_TABLE_INFO) # 노드 추가 -builder.add_node(QUERY_REFINER, query_refiner_node) builder.add_node(GET_TABLE_INFO, get_table_info_node) -# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide -builder.add_node( - QUERY_MAKER, query_maker_node_with_db_guide -) # query_maker_node_with_db_guide +builder.add_node(QUERY_REFINER, query_refiner_node) +builder.add_node(QUERY_MAKER, query_maker_node_with_db_guide) # 기본 엣지 설정 -builder.add_edge(QUERY_REFINER, GET_TABLE_INFO) -builder.add_edge(GET_TABLE_INFO, QUERY_MAKER) +builder.add_edge(GET_TABLE_INFO, QUERY_REFINER) +builder.add_edge(QUERY_REFINER, QUERY_MAKER) # QUERY_MAKER 노드 후 종료 builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/retrieval.py b/llm_utils/retrieval.py new file mode 100644 index 0000000..dbd1d2b --- /dev/null +++ b/llm_utils/retrieval.py @@ -0,0 +1,94 @@ +import os +from langchain_community.vectorstores import FAISS +from langchain_openai import OpenAIEmbeddings +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import CrossEncoderReranker +from langchain_community.cross_encoders import HuggingFaceCrossEncoder +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from .tools import get_info_from_db + + +def get_vector_db(): + """벡터 데이터베이스를 로드하거나 생성합니다.""" + embeddings = OpenAIEmbeddings(model="text-embedding-3-small") + try: + db = FAISS.load_local( + os.getcwd() + "/table_info_db", + embeddings, + allow_dangerous_deserialization=True, + ) + except: + documents = get_info_from_db() + db = FAISS.from_documents(documents, embeddings) + db.save_local(os.getcwd() + "/table_info_db") + print("table_info_db not found") + return db + + +def load_reranker_model(): + """한국어 reranker 모델을 로드하거나 다운로드합니다.""" + local_model_path = os.path.join(os.getcwd(), "ko_reranker_local") + + # 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장 + if os.path.exists(local_model_path) and os.path.isdir(local_model_path): + print("🔄 ko-reranker 모델 로컬에서 로드 중...") + else: + print("⬇️ ko-reranker 모델 다운로드 및 저장 중...") + model = AutoModelForSequenceClassification.from_pretrained( + "Dongjin-kr/ko-reranker" + ) + tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker") + model.save_pretrained(local_model_path) + tokenizer.save_pretrained(local_model_path) + + return HuggingFaceCrossEncoder(model_name=local_model_path) + + +def get_retriever(use_rerank=False): + """검색기를 생성합니다. use_rerank가 True이면 reranking을 적용합니다.""" + db = get_vector_db() + retriever = db.as_retriever(search_kwargs={"k": 10}) + + if use_rerank: + model = load_reranker_model() + compressor = CrossEncoderReranker(model=model, top_n=3) + return ContextualCompressionRetriever( + base_compressor=compressor, base_retriever=retriever + ) + else: + return retriever + + +def search_tables(query, use_rerank=False): + """쿼리에 맞는 테이블 정보를 검색합니다.""" + if use_rerank: + retriever = get_retriever(use_rerank=True) + doc_res = retriever.invoke(query) + else: + db = get_vector_db() + doc_res = db.similarity_search(query, k=10) + + # 결과를 사전 형태로 변환 + documents_dict = {} + for doc in doc_res: + lines = doc.page_content.split("\n") + + # 테이블명 및 설명 추출 + table_name, table_desc = lines[0].split(": ", 1) + + # 컬럼 정보 추출 + columns = {} + if len(lines) > 2 and lines[1].strip() == "Columns:": + for line in lines[2:]: + if ": " in line: + col_name, col_desc = line.split(": ", 1) + columns[col_name.strip()] = col_desc.strip() + + # 딕셔너리 저장 + documents_dict[table_name] = { + "table_description": table_desc.strip(), + **columns, # 컬럼 정보 추가 + } + + return documents_dict diff --git a/requirements.txt b/requirements.txt index 2c506a8..2998e38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,5 @@ pre_commit==4.1.0 setuptools wheel twine +langchain-huggingface==0.1.2 +transformers==4.51.2 \ No newline at end of file diff --git a/setup.py b/setup.py index d5e4805..d02bc0f 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,8 @@ "streamlit==1.41.1", "python-dotenv==1.0.1", "faiss-cpu==1.10.0", + "transformers==4.51.2", + "langchain-huggingface==0.1.2", ], entry_points={ "console_scripts": [ diff --git a/table_info_db/index.faiss b/table_info_db/index.faiss deleted file mode 100644 index 2ece6ff..0000000 Binary files a/table_info_db/index.faiss and /dev/null differ diff --git a/table_info_db/index.pkl b/table_info_db/index.pkl deleted file mode 100644 index 66662a6..0000000 Binary files a/table_info_db/index.pkl and /dev/null differ