|
7 | 7 |
|
8 | 8 | import streamlit as st |
9 | 9 | from langchain.chains.sql_database.prompt import SQL_PROMPTS |
10 | | -from langchain_core.messages import HumanMessage |
| 10 | +from langchain_core.messages import AIMessage, HumanMessage |
11 | 11 |
|
12 | 12 | from llm_utils.connect_db import ConnectDB |
13 | 13 | from llm_utils.graph import builder |
14 | 14 | from llm_utils.enriched_graph import builder as enriched_builder |
| 15 | +from llm_utils.display_chart import DisplayChart |
| 16 | +from llm_utils.llm_response_parser import LLMResponseParser |
15 | 17 |
|
16 | 18 | DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" |
17 | 19 | SIDEBAR_OPTIONS = { |
@@ -52,18 +54,27 @@ def execute_query( |
52 | 54 | device: str = "cpu", |
53 | 55 | ) -> dict: |
54 | 56 | """ |
55 | | - Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다. |
| 57 | + 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다. |
| 58 | +
|
| 59 | + 이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤, |
| 60 | + 사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다. |
| 61 | + 내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다. |
56 | 62 |
|
57 | 63 | Args: |
58 | | - query (str): 자연어로 작성된 사용자 쿼리. |
59 | | - database_env (str): 사용할 데이터베이스 환경 설정 이름. |
60 | | - retriever_name (str): 사용할 검색기 이름. |
61 | | - top_n (int): 검색할 테이블 정보의 개수. |
| 64 | + query (str): 사용자가 입력한 자연어 기반 질문. |
| 65 | + database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod"). |
| 66 | + retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본". |
| 67 | + top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5. |
| 68 | + device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu". |
62 | 69 |
|
63 | 70 | Returns: |
64 | | - dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리. |
| 71 | + dict: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리: |
| 72 | + - "generated_query": 생성된 SQL 쿼리 (`AIMessage`) |
| 73 | + - "messages": 전체 LLM 응답 메시지 목록 |
| 74 | + - "refined_input": AI가 재구성한 입력 질문 |
| 75 | + - "searched_tables": 참조된 테이블 목록 등 추가 정보 |
65 | 76 | """ |
66 | | - # 세션 상태에서 그래프 가져오기 |
| 77 | + |
67 | 78 | graph = st.session_state.get("graph") |
68 | 79 | if graph is None: |
69 | 80 | graph_builder = ( |
@@ -106,22 +117,84 @@ def display_result( |
106 | 117 | - 참조된 테이블 목록 |
107 | 118 | - 쿼리 실행 결과 테이블 |
108 | 119 | """ |
109 | | - total_tokens = summarize_total_tokens(res["messages"]) |
110 | | - |
111 | | - if st.session_state.get("show_total_token_usage", True): |
112 | | - st.write("총 토큰 사용량:", total_tokens) |
113 | | - if st.session_state.get("show_sql", True): |
114 | | - st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```") |
115 | | - if st.session_state.get("show_result_description", True): |
116 | | - st.write("결과 설명:\n\n", res["messages"][-1].content) |
117 | | - if st.session_state.get("show_question_reinterpreted_by_ai", True): |
118 | | - st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content) |
119 | | - if st.session_state.get("show_referenced_tables", True): |
120 | | - st.write("참고한 테이블 목록:", res["searched_tables"]) |
121 | | - if st.session_state.get("show_table", True): |
122 | | - sql = res["generated_query"] |
| 120 | + |
| 121 | + def should_show(_key: str) -> bool: |
| 122 | + st.markdown("---") |
| 123 | + return st.session_state.get(_key, True) |
| 124 | + |
| 125 | + if should_show("show_total_token_usage"): |
| 126 | + total_tokens = summarize_total_tokens(res["messages"]) |
| 127 | + st.write("**총 토큰 사용량:**", total_tokens) |
| 128 | + |
| 129 | + if should_show("show_sql"): |
| 130 | + generated_query = res.get("generated_query") |
| 131 | + query_text = ( |
| 132 | + generated_query.content |
| 133 | + if isinstance(generated_query, AIMessage) |
| 134 | + else str(generated_query) |
| 135 | + ) |
| 136 | + |
| 137 | + try: |
| 138 | + sql = LLMResponseParser.extract_sql(query_text) |
| 139 | + st.markdown("**생성된 SQL 쿼리:**") |
| 140 | + st.code(sql, language="sql") |
| 141 | + except ValueError: |
| 142 | + st.warning("SQL 블록을 추출할 수 없습니다.") |
| 143 | + st.text(query_text) |
| 144 | + |
| 145 | + interpretation = LLMResponseParser.extract_interpretation(query_text) |
| 146 | + if interpretation: |
| 147 | + st.markdown("**결과 해석:**") |
| 148 | + st.code(interpretation) |
| 149 | + |
| 150 | + if should_show("show_result_description"): |
| 151 | + st.markdown("**결과 설명:**") |
| 152 | + result_message = res["messages"][-1].content |
| 153 | + |
| 154 | + try: |
| 155 | + sql = LLMResponseParser.extract_sql(result_message) |
| 156 | + st.code(sql, language="sql") |
| 157 | + except ValueError: |
| 158 | + st.warning("SQL 블록을 추출할 수 없습니다.") |
| 159 | + st.text(result_message) |
| 160 | + |
| 161 | + interpretation = LLMResponseParser.extract_interpretation(result_message) |
| 162 | + if interpretation: |
| 163 | + st.code(interpretation, language="plaintext") |
| 164 | + |
| 165 | + if should_show("show_question_reinterpreted_by_ai"): |
| 166 | + st.markdown("**AI가 재해석한 사용자 질문:**") |
| 167 | + st.code(res["refined_input"].content) |
| 168 | + |
| 169 | + if should_show("show_referenced_tables"): |
| 170 | + st.markdown("**참고한 테이블 목록:**") |
| 171 | + st.write(res.get("searched_tables", [])) |
| 172 | + |
| 173 | + if should_show("show_table"): |
| 174 | + try: |
| 175 | + sql_raw = ( |
| 176 | + res["generated_query"].content |
| 177 | + if isinstance(res["generated_query"], AIMessage) |
| 178 | + else str(res["generated_query"]) |
| 179 | + ) |
| 180 | + sql = LLMResponseParser.extract_sql(sql_raw) |
| 181 | + df = database.run_sql(sql) |
| 182 | + st.dataframe(df.head(10) if len(df) > 10 else df) |
| 183 | + except Exception as e: |
| 184 | + st.error(f"쿼리 실행 중 오류 발생: {e}") |
| 185 | + if should_show("show_chart"): |
123 | 186 | df = database.run_sql(sql) |
124 | | - st.dataframe(df.head(10) if len(df) > 10 else df) |
| 187 | + st.markdown("**쿼리 결과 시각화:**") |
| 188 | + display_code = DisplayChart( |
| 189 | + question=res["refined_input"].content, |
| 190 | + sql=sql, |
| 191 | + df_metadata=f"Running df.dtypes gives:\n{df.dtypes}", |
| 192 | + ) |
| 193 | + # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다 |
| 194 | + fig = display_code.get_plotly_figure( |
| 195 | + plotly_code=display_code.generate_plotly_code(), df=df |
| 196 | + ) |
| 197 | + st.plotly_chart(fig) |
125 | 198 |
|
126 | 199 |
|
127 | 200 | db = ConnectDB() |
|
0 commit comments