diff --git a/.gitignore b/.gitignore index caebd49..de2c5ad 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ test_lhm/ .cursorignore .vscode table_info_db -ko_reranker_local \ No newline at end of file +ko_reranker_local diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 6006228..a41068f 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -11,10 +11,10 @@ from llm_utils.connect_db import ConnectDB from llm_utils.graph import builder +from llm_utils.enriched_graph import builder as enriched_builder from llm_utils.display_chart import DisplayChart from llm_utils.llm_response_parser import LLMResponseParser - DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" SIDEBAR_OPTIONS = { "show_total_token_usage": "Show Total Token Usage", @@ -77,7 +77,10 @@ def execute_query( graph = st.session_state.get("graph") if graph is None: - graph = builder.compile() + graph_builder = ( + enriched_builder if st.session_state.get("use_enriched") else builder + ) + graph = graph_builder.compile() st.session_state["graph"] = graph res = graph.invoke( @@ -198,14 +201,29 @@ def should_show(_key: str) -> bool: st.title("Lang2SQL") +# 워크플로우 선택(UI) +use_enriched = st.sidebar.checkbox( + "프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False +) + # 세션 상태 초기화 -if "graph" not in st.session_state: - st.session_state["graph"] = builder.compile() +if ( + "graph" not in st.session_state + or st.session_state.get("use_enriched") != use_enriched +): + graph_builder = enriched_builder if use_enriched else builder + st.session_state["graph"] = graph_builder.compile() + + # 프로파일 추출 & 컨텍스트 보강 그래프 + st.session_state["use_enriched"] = use_enriched st.info("Lang2SQL이 성공적으로 시작되었습니다.") # 새로고침 버튼 추가 if st.sidebar.button("Lang2SQL 새로고침"): - st.session_state["graph"] = builder.compile() + graph_builder = ( + enriched_builder if st.session_state.get("use_enriched") else builder + ) + st.session_state["graph"] = graph_builder.compile() st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.") user_query = st.text_area( diff --git a/llm_utils/chains.py b/llm_utils/chains.py index a0a5f27..587538c 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -4,12 +4,14 @@ MessagesPlaceholder, SystemMessagePromptTemplate, ) +from pydantic import BaseModel, Field from .llm_factory import get_llm from dotenv import load_dotenv from prompt.template_loader import get_prompt_template + env_path = os.path.join(os.getcwd(), ".env") if os.path.exists(env_path): @@ -20,6 +22,16 @@ llm = get_llm() +class QuestionProfile(BaseModel): + is_timeseries: bool = Field(description="시계열 분석 필요 여부") + is_aggregation: bool = Field(description="집계 함수 필요 여부") + has_filter: bool = Field(description="조건 필터 필요 여부") + is_grouped: bool = Field(description="그룹화 필요 여부") + has_ranking: bool = Field(description="정렬/순위 필요 여부") + has_temporal_comparison: bool = Field(description="기간 비교 포함 여부") + intent_type: str = Field(description="질문의 주요 의도 유형") + + def create_query_refiner_chain(llm): prompt = get_prompt_template("query_refiner_prompt") tool_choice_prompt = ChatPromptTemplate.from_messages( @@ -72,8 +84,66 @@ def create_query_maker_chain(llm): return query_maker_prompt | llm +def create_query_refiner_with_profile_chain(llm): + prompt = get_prompt_template("query_refiner_prompt") + + tool_choice_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(prompt), + MessagesPlaceholder(variable_name="user_input"), + SystemMessagePromptTemplate.from_template( + "다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:" + ), + MessagesPlaceholder(variable_name="searched_tables"), + # 프로파일 정보 입력 + SystemMessagePromptTemplate.from_template( + "다음은 사용자의 질문을 분석한 프로파일 정보입니다." + ), + MessagesPlaceholder("profile_prompt"), + SystemMessagePromptTemplate.from_template( + """ + 위 사용자의 입력과 위 조건을 바탕으로 + 분석 관점에서 **충분히 답변 가능한 형태**로 + "구체화된 질문"을 작성하세요. + """, + ), + ] + ) + + return tool_choice_prompt | llm + + +def create_query_enrichment_chain(llm): + prompt = get_prompt_template("query_enrichment_prompt") + + enrichment_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(prompt), + ] + ) + + chain = enrichment_prompt | llm + return chain + + +def create_profile_extraction_chain(llm): + prompt = get_prompt_template("profile_extraction_prompt") + + profile_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(prompt), + ] + ) + + chain = profile_prompt | llm.with_structured_output(QuestionProfile) + return chain + + query_refiner_chain = create_query_refiner_chain(llm) query_maker_chain = create_query_maker_chain(llm) +profile_extraction_chain = create_profile_extraction_chain(llm) +query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm) +query_enrichment_chain = create_query_enrichment_chain(llm) if __name__ == "__main__": query_refiner_chain.invoke() diff --git a/llm_utils/enriched_graph.py b/llm_utils/enriched_graph.py new file mode 100644 index 0000000..1018ec6 --- /dev/null +++ b/llm_utils/enriched_graph.py @@ -0,0 +1,41 @@ +import json + +from langgraph.graph import StateGraph, END +from llm_utils.graph import ( + QueryMakerState, + GET_TABLE_INFO, + PROFILE_EXTRACTION, + QUERY_REFINER, + CONTEXT_ENRICHMENT, + QUERY_MAKER, + get_table_info_node, + profile_extraction_node, + query_refiner_with_profile_node, + context_enrichment_node, + query_maker_node, +) + +""" +기본 워크플로우에 '프로파일 추출(PROFILE_EXTRACTION)'과 '컨텍스트 보강(CONTEXT_ENRICHMENT)'를 +추가한 확장된 그래프입니다. +""" + +# StateGraph 생성 및 구성 +builder = StateGraph(QueryMakerState) +builder.set_entry_point(GET_TABLE_INFO) + +# 노드 추가 +builder.add_node(GET_TABLE_INFO, get_table_info_node) +builder.add_node(QUERY_REFINER, query_refiner_with_profile_node) +builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) +builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node) +builder.add_node(QUERY_MAKER, query_maker_node) + +# 기본 엣지 설정 +builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION) +builder.add_edge(PROFILE_EXTRACTION, QUERY_REFINER) +builder.add_edge(QUERY_REFINER, CONTEXT_ENRICHMENT) +builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER) + +# QUERY_MAKER 노드 후 종료 +builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 69a10b9..598671a 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -11,10 +11,14 @@ from llm_utils.chains import ( query_refiner_chain, query_maker_chain, + query_refiner_with_profile_chain, + profile_extraction_chain, + query_enrichment_chain, ) from llm_utils.tools import get_info_from_db from llm_utils.retrieval import search_tables +from llm_utils.utils import profile_to_text # 노드 식별자 정의 QUERY_REFINER = "query_refiner" @@ -22,6 +26,8 @@ TOOL = "tool" TABLE_FILTER = "table_filter" QUERY_MAKER = "query_maker" +PROFILE_EXTRACTION = "profile_extraction" +CONTEXT_ENRICHMENT = "context_enrichment" # 상태 타입 정의 (추가 상태 정보와 메시지들을 포함) @@ -31,12 +37,38 @@ class QueryMakerState(TypedDict): searched_tables: dict[str, dict[str, str]] best_practice_query: str refined_input: str + question_profile: dict generated_query: str retriever_name: str top_n: int device: str +# 노드 함수: PROFILE_EXTRACTION 노드 +def profile_extraction_node(state: QueryMakerState): + """ + 자연어 쿼리로부터 질문 유형(PROFILE)을 추출하는 노드입니다. + + 이 노드는 주어진 자연어 쿼리에서 질문의 특성을 분석하여, 해당 질문이 시계열 분석, 집계 함수 사용, 조건 필터 필요 여부, + 그룹화, 정렬/순위, 기간 비교 등 다양한 특성을 갖는지 여부를 추출합니다. + + 추출된 정보는 `QuestionProfile` 모델에 맞춰 저장됩니다. `QuestionProfile` 모델의 필드는 다음과 같습니다: + - `is_timeseries`: 시계열 분석 필요 여부 + - `is_aggregation`: 집계 함수 필요 여부 + - `has_filter`: 조건 필터 필요 여부 + - `is_grouped`: 그룹화 필요 여부 + - `has_ranking`: 정렬/순위 필요 여부 + - `has_temporal_comparison`: 기간 비교 포함 여부 + - `intent_type`: 질문의 주요 의도 유형 + + """ + result = profile_extraction_chain.invoke({"question": state["messages"][0].content}) + + state["question_profile"] = result + print("profile_extraction_node : ", result) + return state + + # 노드 함수: QUERY_REFINER 노드 def query_refiner_node(state: QueryMakerState): res = query_refiner_chain.invoke( @@ -52,6 +84,80 @@ def query_refiner_node(state: QueryMakerState): return state +# 노드 함수: QUERY_REFINER 노드 +def query_refiner_with_profile_node(state: QueryMakerState): + """ + 자연어 쿼리로부터 질문 유형(PROFILE)을 사용해 자연어 질의를 확장하는 노드입니다. + + """ + + profile_bullets = profile_to_text(state["question_profile"]) + res = query_refiner_with_profile_chain.invoke( + input={ + "user_input": [state["messages"][0].content], + "user_database_env": [state["user_database_env"]], + "best_practice_query": [state["best_practice_query"]], + "searched_tables": [json.dumps(state["searched_tables"])], + "profile_prompt": [profile_bullets], + } + ) + state["messages"].append(res) + state["refined_input"] = res + + print("refined_input before context enrichment : ", res.content) + return state + + +# 노드 함수: CONTEXT_ENRICHMENT 노드 +def context_enrichment_node(state: QueryMakerState): + """ + 주어진 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드입니다. + + 이 함수는 `refined_question`, `profiles`, `related_tables` 정보를 이용하여 자연어 질문을 보강합니다. + 보강 과정에서는 질문의 의도를 유지하면서, 추가적인 세부 정보를 제공하거나 잘못된 용어를 수정합니다. + + 주요 작업: + - 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다. + - 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안"). + - 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’ → ‘USA’). + - 보강된 질문을 출력합니다. + + Args: + state (QueryMakerState): 쿼리와 관련된 상태 정보를 담고 있는 객체. + 상태 객체는 `refined_input`, `question_profile`, `searched_tables` 등의 정보를 포함합니다. + + Returns: + QueryMakerState: 보강된 질문이 포함된 상태 객체. + + Example: + Given the refined question "What are the total sales in the last month?", + the function would enrich it with additional information such as: + - Ensuring the time period is specified correctly. + - Correcting any column names if necessary. + - Returning the enriched version of the question. + """ + + searched_tables = state["searched_tables"] + searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2) + + question_profile = state["question_profile"].model_dump() + question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2) + + enriched_text = query_enrichment_chain.invoke( + input={ + "refined_question": state["refined_input"], + "profiles": question_profile_json, + "related_tables": searched_tables_json, + } + ) + + state["refined_input"] = enriched_text + state["messages"].append(enriched_text) + print("After context enrichment : ", enriched_text.content) + + return state + + def get_table_info_node(state: QueryMakerState): # retriever_name과 top_n을 이용하여 검색 수행 documents_dict = search_tables( diff --git a/llm_utils/utils.py b/llm_utils/utils.py new file mode 100644 index 0000000..2057b5c --- /dev/null +++ b/llm_utils/utils.py @@ -0,0 +1,17 @@ +def profile_to_text(profile_obj) -> str: + mapping = { + "is_timeseries": "• 시계열 분석 필요", + "is_aggregation": "• 집계 함수 필요", + "has_filter": "• WHERE 조건 필요", + "is_grouped": "• GROUP BY 필요", + "has_ranking": "• 정렬/순위 필요", + "has_temporal_comparison": "• 기간 비교 필요", + } + bullets = [ + text for field, text in mapping.items() if getattr(profile_obj, field, False) + ] + intent = getattr(profile_obj, "intent_type", None) + if intent: + bullets.append(f"• 의도 유형 → {intent}") + + return "\n".join(bullets) diff --git a/prompt/profile_extraction_prompt.md b/prompt/profile_extraction_prompt.md new file mode 100644 index 0000000..606e037 --- /dev/null +++ b/prompt/profile_extraction_prompt.md @@ -0,0 +1,19 @@ +# Role + +You are an assistant that analyzes a user question and extracts the following profiles as JSON: +- is_timeseries (boolean) +- is_aggregation (boolean) +- has_filter (boolean) +- is_grouped (boolean) +- has_ranking (boolean) +- has_temporal_comparison (boolean) +- intent_type (one of: trend, lookup, comparison, distribution) + +# Input + +Question: +{question} + +# Output Example + +The output must be a valid JSON matching the QuestionProfile schema. diff --git a/prompt/query_enrichment_prompt.md b/prompt/query_enrichment_prompt.md new file mode 100644 index 0000000..98fbb6f --- /dev/null +++ b/prompt/query_enrichment_prompt.md @@ -0,0 +1,22 @@ +# Role + +You are a smart assistant that takes a user question and enriches it using: +1. Question profiles: {profiles} +2. Table metadata (names, columns, descriptions): + {related_tables} + +# Tasks + +- Correct any wrong terms by matching them to actual column names. +- If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). +- If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). +- Output the enriched question only. + +# Input + +Refined question: +{refined_question} + +# Notes + +Using the refined version for enrichment, but keep the original intent in mind. diff --git a/pyproject.toml b/pyproject.toml_ similarity index 100% rename from pyproject.toml rename to pyproject.toml_