diff --git a/cli/__init__.py b/cli/__init__.py index 191141b..5670638 100644 --- a/cli/__init__.py +++ b/cli/__init__.py @@ -159,3 +159,91 @@ def run_streamlit_cli_command(port: int) -> None: logger.info("Executing 'run-streamlit' command on port %d...", port) run_streamlit_command(port) + + +@cli.command(name="query") +@click.argument("question", type=str) +@click.option( + "--database-env", + default="clickhouse", + help="사용할 데이터베이스 환경 (기본값: clickhouse)", +) +@click.option( + "--retriever-name", + default="기본", + help="테이블 검색기 이름 (기본값: 기본)", +) +@click.option( + "--top-n", + type=int, + default=5, + help="검색된 상위 테이블 수 제한 (기본값: 5)", +) +@click.option( + "--device", + default="cpu", + help="LLM 실행에 사용할 디바이스 (기본값: cpu)", +) +@click.option( + "--use-enriched-graph", + is_flag=True, + help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부", +) +def query_command( + question: str, + database_env: str, + retriever_name: str, + top_n: int, + device: str, + use_enriched_graph: bool, +) -> None: + """ + 자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다. + + 이 명령은 사용자가 입력한 자연어 질문을 받아서 SQL 쿼리로 변환하고, + 생성된 SQL 쿼리만을 표준 출력으로 출력합니다. + + 매개변수: + question (str): SQL로 변환할 자연어 질문 + database_env (str): 사용할 데이터베이스 환경 + retriever_name (str): 테이블 검색기 이름 + top_n (int): 검색된 상위 테이블 수 제한 + device (str): LLM 실행에 사용할 디바이스 + use_enriched_graph (bool): 확장된 그래프 사용 여부 + + 예시: + lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" + lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" --use-enriched-graph + """ + + try: + from llm_utils.query_executor import execute_query, extract_sql_from_result + + # 공용 함수를 사용하여 쿼리 실행 + res = execute_query( + query=question, + database_env=database_env, + retriever_name=retriever_name, + top_n=top_n, + device=device, + use_enriched_graph=use_enriched_graph, + ) + + # SQL 추출 및 출력 + sql = extract_sql_from_result(res) + if sql: + print(sql) + else: + # SQL 추출 실패 시 원본 쿼리 텍스트 출력 + generated_query = res.get("generated_query") + if generated_query: + query_text = ( + generated_query.content + if hasattr(generated_query, "content") + else str(generated_query) + ) + print(query_text) + + except Exception as e: + logger.error("쿼리 처리 중 오류 발생: %s", e) + raise diff --git a/evaluation/gen_answer.py b/evaluation/gen_answer.py index 65feb91..653edab 100644 --- a/evaluation/gen_answer.py +++ b/evaluation/gen_answer.py @@ -6,7 +6,7 @@ from tqdm import tqdm import uuid -from llm_utils.graph import builder +from llm_utils.graph_utils.basic_graph import builder def get_eval_result( diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 9d94268..07d381b 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -7,14 +7,15 @@ import streamlit as st from langchain.chains.sql_database.prompt import SQL_PROMPTS -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage from llm_utils.connect_db import ConnectDB from llm_utils.display_chart import DisplayChart -from llm_utils.enriched_graph import builder as enriched_builder -from llm_utils.graph import builder +from llm_utils.query_executor import execute_query as execute_query_common from llm_utils.llm_response_parser import LLMResponseParser from llm_utils.token_utils import TokenUtils +from llm_utils.graph_utils.enriched_graph import builder as enriched_builder +from llm_utils.graph_utils.basic_graph import builder TITLE = "Lang2SQL" DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" @@ -40,9 +41,8 @@ def execute_query( """ 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다. - 이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤, - 사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다. - 내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다. + 이 함수는 공용 execute_query 함수를 호출하여 Lang2SQL 파이프라인을 실행합니다. + Streamlit 세션 상태를 활용하여 그래프를 재사용합니다. Args: query (str): 사용자가 입력한 자연어 기반 질문. @@ -59,27 +59,16 @@ def execute_query( - "searched_tables": 참조된 테이블 목록 등 추가 정보 """ - graph = st.session_state.get("graph") - if graph is None: - graph_builder = ( - enriched_builder if st.session_state.get("use_enriched") else builder - ) - graph = graph_builder.compile() - st.session_state["graph"] = graph - - res = graph.invoke( - input={ - "messages": [HumanMessage(content=query)], - "user_database_env": database_env, - "best_practice_query": "", - "retriever_name": retriever_name, - "top_n": top_n, - "device": device, - } + return execute_query_common( + query=query, + database_env=database_env, + retriever_name=retriever_name, + top_n=top_n, + device=device, + use_enriched_graph=st.session_state.get("use_enriched", False), + session_state=st.session_state, ) - return res - def display_result( *, @@ -120,40 +109,50 @@ def should_show(_key: str) -> bool: if should_show("show_sql"): st.markdown("---") generated_query = res.get("generated_query") - query_text = ( - generated_query.content - if isinstance(generated_query, AIMessage) - else str(generated_query) - ) + if generated_query: + query_text = ( + generated_query.content + if isinstance(generated_query, AIMessage) + else str(generated_query) + ) - try: - sql = LLMResponseParser.extract_sql(query_text) - st.markdown("**생성된 SQL 쿼리:**") - st.code(sql, language="sql") - except ValueError: - st.warning("SQL 블록을 추출할 수 없습니다.") - st.text(query_text) - - interpretation = LLMResponseParser.extract_interpretation(query_text) - if interpretation: - st.markdown("**결과 해석:**") - st.code(interpretation) + # query_text가 문자열인지 확인 + if isinstance(query_text, str): + try: + sql = LLMResponseParser.extract_sql(query_text) + st.markdown("**생성된 SQL 쿼리:**") + st.code(sql, language="sql") + except ValueError: + st.warning("SQL 블록을 추출할 수 없습니다.") + st.text(query_text) + + interpretation = LLMResponseParser.extract_interpretation(query_text) + if interpretation: + st.markdown("**결과 해석:**") + st.code(interpretation) + else: + st.warning("쿼리 텍스트가 문자열이 아닙니다.") + st.text(str(query_text)) if should_show("show_result_description"): st.markdown("---") st.markdown("**결과 설명:**") result_message = res["messages"][-1].content - try: - sql = LLMResponseParser.extract_sql(result_message) - st.code(sql, language="sql") - except ValueError: - st.warning("SQL 블록을 추출할 수 없습니다.") - st.text(result_message) - - interpretation = LLMResponseParser.extract_interpretation(result_message) - if interpretation: - st.code(interpretation, language="plaintext") + if isinstance(result_message, str): + try: + sql = LLMResponseParser.extract_sql(result_message) + st.code(sql, language="sql") + except ValueError: + st.warning("SQL 블록을 추출할 수 없습니다.") + st.text(result_message) + + interpretation = LLMResponseParser.extract_interpretation(result_message) + if interpretation: + st.code(interpretation, language="plaintext") + else: + st.warning("결과 메시지가 문자열이 아닙니다.") + st.text(str(result_message)) if should_show("show_question_reinterpreted_by_ai"): st.markdown("---") @@ -173,26 +172,41 @@ def should_show(_key: str) -> bool: if isinstance(res["generated_query"], AIMessage) else str(res["generated_query"]) ) - sql = LLMResponseParser.extract_sql(sql_raw) - df = database.run_sql(sql) - st.dataframe(df.head(10) if len(df) > 10 else df) + if isinstance(sql_raw, str): + sql = LLMResponseParser.extract_sql(sql_raw) + df = database.run_sql(sql) + st.dataframe(df.head(10) if len(df) > 10 else df) + else: + st.error("SQL 원본이 문자열이 아닙니다.") except Exception as e: st.error(f"쿼리 실행 중 오류 발생: {e}") if should_show("show_chart"): st.markdown("---") - df = database.run_sql(sql) - st.markdown("**쿼리 결과 시각화:**") - display_code = DisplayChart( - question=res["refined_input"].content, - sql=sql, - df_metadata=f"Running df.dtypes gives:\n{df.dtypes}", - ) - # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다 - fig = display_code.get_plotly_figure( - plotly_code=display_code.generate_plotly_code(), df=df - ) - st.plotly_chart(fig) + try: + sql_raw = ( + res["generated_query"].content + if isinstance(res["generated_query"], AIMessage) + else str(res["generated_query"]) + ) + if isinstance(sql_raw, str): + sql = LLMResponseParser.extract_sql(sql_raw) + df = database.run_sql(sql) + st.markdown("**쿼리 결과 시각화:**") + display_code = DisplayChart( + question=res["refined_input"].content, + sql=sql, + df_metadata=f"Running df.dtypes gives:\n{df.dtypes}", + ) + # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다 + fig = display_code.get_plotly_figure( + plotly_code=display_code.generate_plotly_code(), df=df + ) + st.plotly_chart(fig) + else: + st.error("SQL 원본이 문자열이 아닙니다.") + except Exception as e: + st.error(f"차트 생성 중 오류 발생: {e}") db = ConnectDB() diff --git a/llm_utils/graph_utils/__init__.py b/llm_utils/graph_utils/__init__.py new file mode 100644 index 0000000..41b539b --- /dev/null +++ b/llm_utils/graph_utils/__init__.py @@ -0,0 +1,45 @@ +""" +그래프 관련 유틸리티 모듈입니다. + +이 패키지는 Lang2SQL의 워크플로우 그래프 구성과 관련된 모듈들을 포함합니다. +""" + +from .base import ( + QueryMakerState, + GET_TABLE_INFO, + QUERY_REFINER, + QUERY_MAKER, + PROFILE_EXTRACTION, + CONTEXT_ENRICHMENT, + get_table_info_node, + query_refiner_node, + query_maker_node, + profile_extraction_node, + query_refiner_with_profile_node, + context_enrichment_node, + query_maker_node_with_db_guide, +) + +from .basic_graph import builder as basic_builder +from .enriched_graph import builder as enriched_builder + +__all__ = [ + # 상태 및 노드 식별자 + "QueryMakerState", + "GET_TABLE_INFO", + "QUERY_REFINER", + "QUERY_MAKER", + "PROFILE_EXTRACTION", + "CONTEXT_ENRICHMENT", + # 노드 함수들 + "get_table_info_node", + "query_refiner_node", + "query_maker_node", + "profile_extraction_node", + "query_refiner_with_profile_node", + "context_enrichment_node", + "query_maker_node_with_db_guide", + # 그래프 빌더들 + "basic_builder", + "enriched_builder", +] diff --git a/llm_utils/graph.py b/llm_utils/graph_utils/base.py similarity index 91% rename from llm_utils/graph.py rename to llm_utils/graph_utils/base.py index 598671a..d725651 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph_utils/base.py @@ -6,7 +6,7 @@ from langgraph.graph.message import add_messages from langchain.chains.sql_database.prompt import SQL_PROMPTS from pydantic import BaseModel, Field -from .llm_factory import get_llm +from llm_utils.llm_factory import get_llm from llm_utils.chains import ( query_refiner_chain, @@ -119,7 +119,7 @@ def context_enrichment_node(state: QueryMakerState): 주요 작업: - 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다. - 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안"). - - 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’ → ‘USA’). + - 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: '미국' → 'USA'). - 보강된 질문을 출력합니다. Args: @@ -207,23 +207,3 @@ def query_maker_node_with_db_guide(state: QueryMakerState): state["generated_query"] = res.sql state["messages"].append(res.explanation) return state - - -# StateGraph 생성 및 구성 -builder = StateGraph(QueryMakerState) -builder.set_entry_point(GET_TABLE_INFO) - -# 노드 추가 -builder.add_node(GET_TABLE_INFO, get_table_info_node) -builder.add_node(QUERY_REFINER, query_refiner_node) -builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide -# builder.add_node( -# QUERY_MAKER, query_maker_node_with_db_guide -# ) # query_maker_node_with_db_guide - -# 기본 엣지 설정 -builder.add_edge(GET_TABLE_INFO, QUERY_REFINER) -builder.add_edge(QUERY_REFINER, QUERY_MAKER) - -# QUERY_MAKER 노드 후 종료 -builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/graph_utils/basic_graph.py b/llm_utils/graph_utils/basic_graph.py new file mode 100644 index 0000000..8f28264 --- /dev/null +++ b/llm_utils/graph_utils/basic_graph.py @@ -0,0 +1,33 @@ +import json + +from langgraph.graph import StateGraph, END +from llm_utils.graph_utils.base import ( + QueryMakerState, + GET_TABLE_INFO, + QUERY_REFINER, + QUERY_MAKER, + get_table_info_node, + query_refiner_node, + query_maker_node, +) + +""" +기본 워크플로우를 위한 StateGraph 구성입니다. +GET_TABLE_INFO -> QUERY_REFINER -> QUERY_MAKER 순서로 실행됩니다. +""" + +# StateGraph 생성 및 구성 +builder = StateGraph(QueryMakerState) +builder.set_entry_point(GET_TABLE_INFO) + +# 노드 추가 +builder.add_node(GET_TABLE_INFO, get_table_info_node) +builder.add_node(QUERY_REFINER, query_refiner_node) +builder.add_node(QUERY_MAKER, query_maker_node) + +# 기본 엣지 설정 +builder.add_edge(GET_TABLE_INFO, QUERY_REFINER) +builder.add_edge(QUERY_REFINER, QUERY_MAKER) + +# QUERY_MAKER 노드 후 종료 +builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/enriched_graph.py b/llm_utils/graph_utils/enriched_graph.py similarity index 96% rename from llm_utils/enriched_graph.py rename to llm_utils/graph_utils/enriched_graph.py index 1018ec6..4645a32 100644 --- a/llm_utils/enriched_graph.py +++ b/llm_utils/graph_utils/enriched_graph.py @@ -1,7 +1,7 @@ import json from langgraph.graph import StateGraph, END -from llm_utils.graph import ( +from llm_utils.graph_utils.base import ( QueryMakerState, GET_TABLE_INFO, PROFILE_EXTRACTION, @@ -26,8 +26,8 @@ # 노드 추가 builder.add_node(GET_TABLE_INFO, get_table_info_node) -builder.add_node(QUERY_REFINER, query_refiner_with_profile_node) builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) +builder.add_node(QUERY_REFINER, query_refiner_with_profile_node) builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node) builder.add_node(QUERY_MAKER, query_maker_node) diff --git a/llm_utils/query_executor.py b/llm_utils/query_executor.py new file mode 100644 index 0000000..7a87e3c --- /dev/null +++ b/llm_utils/query_executor.py @@ -0,0 +1,111 @@ +""" +Lang2SQL 쿼리 실행을 위한 공용 모듈입니다. + +이 모듈은 CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있는 +쿼리 실행 함수를 제공합니다. +""" + +import logging +from typing import Dict, Any, Optional, Union + +from langchain_core.messages import HumanMessage + +from llm_utils.graph_utils.enriched_graph import builder as enriched_builder +from llm_utils.graph_utils.basic_graph import builder as basic_builder +from llm_utils.llm_response_parser import LLMResponseParser + +logger = logging.getLogger(__name__) + + +def execute_query( + *, + query: str, + database_env: str, + retriever_name: str = "기본", + top_n: int = 5, + device: str = "cpu", + use_enriched_graph: bool = False, + session_state: Optional[Union[Dict[str, Any], Any]] = None, +) -> Dict[str, Any]: + """ + 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 공용 함수입니다. + + 이 함수는 Lang2SQL 파이프라인(graph)을 사용하여 사용자의 자연어 질문을 + SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다. + CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있습니다. + + Args: + query (str): 사용자가 입력한 자연어 기반 질문. + database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod"). + retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본". + top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5. + device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu". + use_enriched_graph (bool, optional): 확장된 그래프 사용 여부. 기본값은 False. + session_state (Optional[Union[Dict[str, Any], Any]], optional): Streamlit 세션 상태 (Streamlit에서만 사용). + + Returns: + Dict[str, Any]: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리: + - "generated_query": 생성된 SQL 쿼리 (`AIMessage`) + - "messages": 전체 LLM 응답 메시지 목록 + - "refined_input": AI가 재구성한 입력 질문 + - "searched_tables": 참조된 테이블 목록 등 추가 정보 + """ + + logger.info("Processing query: %s", query) + logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic") + + # 그래프 선택 및 컴파일 + if session_state is not None: + # Streamlit 환경: 세션 상태에서 그래프 재사용 + graph = session_state.get("graph") + if graph is None: + graph_builder = enriched_builder if use_enriched_graph else basic_builder + graph = graph_builder.compile() + session_state["graph"] = graph + else: + # CLI 환경: 매번 새로운 그래프 컴파일 + graph_builder = enriched_builder if use_enriched_graph else basic_builder + graph = graph_builder.compile() + + # 그래프 실행 + res = graph.invoke( + input={ + "messages": [HumanMessage(content=query)], + "user_database_env": database_env, + "best_practice_query": "", + "retriever_name": retriever_name, + "top_n": top_n, + "device": device, + } + ) + + return res + + +def extract_sql_from_result(res: Dict[str, Any]) -> Optional[str]: + """ + Lang2SQL 실행 결과에서 SQL 쿼리를 추출합니다. + + Args: + res (Dict[str, Any]): execute_query 함수의 반환 결과 + + Returns: + Optional[str]: 추출된 SQL 쿼리 문자열. 추출 실패 시 None + """ + generated_query = res.get("generated_query") + if not generated_query: + logger.error("생성된 쿼리가 없습니다.") + return None + + query_text = ( + generated_query.content + if hasattr(generated_query, "content") + else str(generated_query) + ) + + try: + sql = LLMResponseParser.extract_sql(query_text) + return sql + except ValueError: + logger.error("SQL을 추출할 수 없습니다.") + return None diff --git a/main.py b/main.py deleted file mode 100644 index b16c9bb..0000000 --- a/main.py +++ /dev/null @@ -1,11 +0,0 @@ -from langchain_core.messages import HumanMessage - -from llm_utils.graph import builder - -if __name__ == "__main__": - graph = builder.compile() - user_query = """ - 고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리 - """ - human_message = HumanMessage(content=user_query) - res = graph.invoke(input=human_message)