Skip to content

Commit d02edce

Browse files
authored
Merge branch 'master' into feature/77-add-isort-flake8
2 parents f92afd3 + 5a65240 commit d02edce

File tree

4 files changed

+298
-74
lines changed

4 files changed

+298
-74
lines changed

cli/__init__.py

Lines changed: 111 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Datahub GMS 서버 URL을 설정하고, 필요 시 Streamlit 인터페이스를 실행하는 CLI 프로그램입니다.
3+
"""
14
import subprocess
25

36
import click
@@ -9,29 +12,123 @@
912
@click.version_option(version="0.1.4")
1013
@click.pass_context
1114
@click.option(
12-
"--datahub_server", default="http://localhost:8080", help="Datahub GMS 서버 URL"
15+
"--datahub_server",
16+
default="http://localhost:8080",
17+
help=(
18+
"Datahub GMS 서버의 URL을 설정합니다. "
19+
"기본값은 'http://localhost:8080'이며, "
20+
"운영 환경 또는 테스트 환경에 맞게 변경할 수 있습니다."
21+
),
1322
)
14-
@click.option("--run-streamlit", is_flag=True, help="Run the Streamlit app.")
15-
@click.option("-p", "--port", type=int, default=8501, help="Streamlit port")
16-
def cli(ctx, datahub_server, run_streamlit, port):
23+
@click.option(
24+
"--run-streamlit",
25+
is_flag=True,
26+
help=(
27+
"이 옵션을 지정하면 CLI 실행 시 Streamlit 애플리케이션을 바로 실행합니다. "
28+
"별도의 명령어 입력 없이 웹 인터페이스를 띄우고 싶을 때 사용합니다."
29+
),
30+
)
31+
@click.option(
32+
"-p",
33+
"--port",
34+
type=int,
35+
default=8501,
36+
help=(
37+
"Streamlit 서버가 바인딩될 포트 번호를 지정합니다. "
38+
"기본 포트는 8501이며, 포트 충돌을 피하거나 여러 인스턴스를 실행할 때 변경할 수 있습니다."
39+
),
40+
)
41+
# pylint: disable=redefined-outer-name
42+
def cli(
43+
ctx: click.Context,
44+
datahub_server: str,
45+
run_streamlit: bool,
46+
port: int,
47+
) -> None:
48+
"""
49+
Datahub GMS 서버 URL을 설정하고, Streamlit 애플리케이션을 실행할 수 있는 CLI 명령 그룹입니다.
50+
51+
이 함수는 다음 역할을 수행합니다:
52+
- 전달받은 'datahub_server' URL을 바탕으로 GMS 서버 연결을 설정합니다.
53+
- 설정 과정 중 오류가 발생하면 오류 메시지를 출력하고 프로그램을 종료합니다.
54+
- '--run-streamlit' 옵션이 활성화된 경우, 지정된 포트에서 Streamlit 웹 앱을 즉시 실행합니다.
55+
56+
매개변수:
57+
ctx (click.Context): 명령어 실행 컨텍스트 객체입니다.
58+
datahub_server (str): 설정할 Datahub GMS 서버의 URL입니다.
59+
run_streamlit (bool): Streamlit 앱을 실행할지 여부를 나타내는 플래그입니다.
60+
port (int): Streamlit 서버가 바인딩될 포트 번호입니다.
61+
62+
주의:
63+
'set_gms_server' 함수에서 ValueError가 발생할 경우, 프로그램은 비정상 종료(exit code 1)합니다.
64+
"""
65+
1766
try:
1867
set_gms_server(datahub_server)
1968
except ValueError as e:
20-
click.echo(str(e))
69+
click.secho(f"GMS 서버 URL 설정 실패: {str(e)}", fg="red")
2170
ctx.exit(1)
2271
if run_streamlit:
2372
run_streamlit_command(port)
2473

2574

26-
def run_streamlit_command(port):
27-
"""Run the Streamlit app."""
28-
subprocess.run(
29-
["streamlit", "run", "interface/streamlit_app.py", "--server.port", str(port)]
30-
)
75+
def run_streamlit_command(port: int) -> None:
76+
"""
77+
지정된 포트에서 Streamlit 애플리케이션을 실행하는 함수입니다.
78+
79+
이 함수는 subprocess를 통해 'streamlit run' 명령어를 실행하여
80+
'interface/streamlit_app.py' 파일을 웹 서버 형태로 구동합니다.
81+
사용자가 지정한 포트 번호를 Streamlit 서버의 포트로 설정합니다.
82+
83+
매개변수:
84+
port (int): Streamlit 서버가 바인딩될 포트 번호입니다.
85+
86+
주의:
87+
- Streamlit이 시스템에 설치되어 있어야 정상 동작합니다.
88+
- subprocess 호출 실패 시 예외가 발생할 수 있습니다.
89+
"""
90+
91+
try:
92+
subprocess.run(
93+
[
94+
"streamlit",
95+
"run",
96+
"interface/streamlit_app.py",
97+
"--server.port",
98+
str(port),
99+
],
100+
check=True,
101+
)
102+
except subprocess.CalledProcessError as e:
103+
click.echo(f"Streamlit 실행 실패: {e}")
104+
raise
105+
106+
107+
@cli.command(name="run-streamlit")
108+
@click.option(
109+
"-p",
110+
"--port",
111+
type=int,
112+
default=8501,
113+
help=(
114+
"Streamlit 애플리케이션이 바인딩될 포트 번호를 지정합니다. "
115+
"기본 포트는 8501이며, 필요 시 포트 충돌을 피하거나 "
116+
"여러 인스턴스를 동시에 실행할 때 다른 포트 번호를 설정할 수 있습니다."
117+
),
118+
)
119+
def run_streamlit_cli_command(port: int) -> None:
120+
"""
121+
CLI 명령어를 통해 Streamlit 애플리케이션을 실행하는 함수입니다.
122+
123+
이 명령은 'interface/streamlit_app.py' 파일을 Streamlit 서버로 구동하며,
124+
사용자가 지정한 포트 번호를 바인딩하여 웹 인터페이스를 제공합니다.
125+
126+
매개변수:
127+
port (int): Streamlit 서버가 사용할 포트 번호입니다. 기본값은 8501입니다.
31128
129+
주의:
130+
- Streamlit이 시스템에 설치되어 있어야 정상적으로 실행됩니다.
131+
- Streamlit 실행에 실패할 경우 subprocess 호출에서 예외가 발생할 수 있습니다.
132+
"""
32133

33-
@cli.command()
34-
@click.option("-p", "--port", type=int, default=8501, help="Streamlit port")
35-
def run_streamlit(port):
36-
"""Run the Streamlit app."""
37134
run_streamlit_command(port)

interface/lang2sql.py

Lines changed: 110 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,103 @@
1-
import os
2-
from typing import Union
1+
"""
2+
Lang2SQL Streamlit 애플리케이션.
3+
4+
자연어로 입력된 질문을 SQL 쿼리로 변환하고,
5+
ClickHouse 데이터베이스에 실행한 결과를 출력합니다.
6+
"""
37

4-
import pandas as pd
58
import streamlit as st
6-
from clickhouse_driver import Client
7-
from dotenv import load_dotenv
89
from langchain.chains.sql_database.prompt import SQL_PROMPTS
910
from langchain_core.messages import HumanMessage
1011

1112
from llm_utils.connect_db import ConnectDB
1213
from llm_utils.graph import builder
1314

14-
# Clickhouse 연결
15-
db = ConnectDB()
16-
db.connect_to_clickhouse()
15+
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
16+
SIDEBAR_OPTIONS = {
17+
"show_total_token_usage": "Show Total Token Usage",
18+
"show_result_description": "Show Result Description",
19+
"show_sql": "Show SQL",
20+
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
21+
"show_referenced_tables": "Show List of Referenced Tables",
22+
"show_table": "Show Table",
23+
"show_chart": "Show Chart",
24+
}
25+
def summarize_total_tokens(data: list) -> int:
26+
"""
27+
메시지 데이터에서 총 토큰 사용량을 집계합니다.
1728
18-
# Streamlit 앱 제목
19-
st.title("Lang2SQL")
29+
Args:
30+
data (list): usage_metadata를 포함하는 객체들의 리스트.
2031
21-
# 사용자 입력 받기
22-
user_query = st.text_area(
23-
"쿼리를 입력하세요:",
24-
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
25-
)
32+
Returns:
33+
int: 총 토큰 사용량 합계.
34+
"""
2635

27-
user_database_env = st.selectbox(
28-
"db 환경정보를 입력하세요:",
29-
options=SQL_PROMPTS.keys(),
30-
index=0,
31-
)
32-
st.sidebar.title("Output Settings")
33-
st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
34-
st.sidebar.checkbox(
35-
"Show Result Description", value=True, key="show_result_description"
36-
)
37-
st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
38-
st.sidebar.checkbox(
39-
"Show User Question Reinterpreted by AI",
40-
value=True,
41-
key="show_question_reinterpreted_by_ai",
42-
)
43-
st.sidebar.checkbox(
44-
"Show List of Referenced Tables", value=True, key="show_referenced_tables"
45-
)
46-
st.sidebar.checkbox("Show Table", value=True, key="show_table")
47-
st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
48-
49-
50-
# Token usage 집계 함수 정의
51-
def summarize_total_tokens(data):
5236
total_tokens = 0
5337
for item in data:
5438
token_usage = getattr(item, "usage_metadata", {})
5539
total_tokens += token_usage.get("total_tokens", 0)
5640
return total_tokens
5741

5842

59-
# 버튼 클릭 시 실행
60-
if st.button("쿼리 실행"):
61-
# 그래프 컴파일 및 쿼리 실행
62-
graph = builder.compile()
43+
def execute_query(
44+
*,
45+
query: str,
46+
database_env: str,
47+
) -> dict:
48+
"""
49+
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
50+
51+
Args:
52+
query (str): 자연어로 작성된 사용자 쿼리.
53+
database_env (str): 사용할 데이터베이스 환경 설정 이름.
54+
55+
Returns:
56+
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
57+
"""
58+
# 세션 상태에서 그래프 가져오기
59+
graph = st.session_state.get("graph")
60+
if graph is None:
61+
graph = builder.compile()
62+
st.session_state["graph"] = graph
6363

6464
res = graph.invoke(
6565
input={
66-
"messages": [HumanMessage(content=user_query)],
67-
"user_database_env": user_database_env,
66+
"messages": [HumanMessage(content=query)],
67+
"user_database_env": database_env,
6868
"best_practice_query": "",
6969
}
7070
)
71+
72+
return res
73+
74+
75+
def display_result(
76+
*,
77+
res: dict,
78+
database: ConnectDB,
79+
) -> None:
80+
"""
81+
Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.
82+
83+
Args:
84+
res (dict): Lang2SQL 실행 결과 딕셔너리.
85+
database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.
86+
87+
출력 항목:
88+
- 총 토큰 사용량
89+
- 생성된 SQL 쿼리
90+
- 결과 설명
91+
- AI가 재해석한 사용자 질문
92+
- 참조된 테이블 목록
93+
- 쿼리 실행 결과 테이블
94+
"""
7195
total_tokens = summarize_total_tokens(res["messages"])
7296

73-
# 결과 출력
7497
if st.session_state.get("show_total_token_usage", True):
7598
st.write("총 토큰 사용량:", total_tokens)
7699
if st.session_state.get("show_sql", True):
77-
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
100+
st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```")
78101
if st.session_state.get("show_result_description", True):
79102
st.write("결과 설명:\n\n", res["messages"][-1].content)
80103
if st.session_state.get("show_question_reinterpreted_by_ai", True):
@@ -83,8 +106,42 @@ def summarize_total_tokens(data):
83106
st.write("참고한 테이블 목록:", res["searched_tables"])
84107
if st.session_state.get("show_table", True):
85108
sql = res["generated_query"]
86-
df = db.run_sql(sql)
87-
if len(df) > 10:
88-
st.dataframe(df.head(10))
89-
else:
90-
st.dataframe(df)
109+
df = database.run_sql(sql)
110+
st.dataframe(df.head(10) if len(df) > 10 else df)
111+
112+
113+
db = ConnectDB()
114+
db.connect_to_clickhouse()
115+
116+
st.title("Lang2SQL")
117+
118+
# 세션 상태 초기화
119+
if "graph" not in st.session_state:
120+
st.session_state["graph"] = builder.compile()
121+
st.info("Lang2SQL이 성공적으로 시작되었습니다.")
122+
123+
# 새로고침 버튼 추가
124+
if st.sidebar.button("Lang2SQL 새로고침"):
125+
st.session_state["graph"] = builder.compile()
126+
st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.")
127+
128+
user_query = st.text_area(
129+
"쿼리를 입력하세요:",
130+
value=DEFAULT_QUERY,
131+
)
132+
user_database_env = st.selectbox(
133+
"DB 환경정보를 입력하세요:",
134+
options=SQL_PROMPTS.keys(),
135+
index=0,
136+
)
137+
138+
st.sidebar.title("Output Settings")
139+
for key, label in SIDEBAR_OPTIONS.items():
140+
st.sidebar.checkbox(label, value=True, key=key)
141+
142+
if st.button("쿼리 실행"):
143+
result = execute_query(
144+
query=user_query,
145+
database_env=user_database_env,
146+
)
147+
display_result(res=result, database=db)

0 commit comments

Comments
 (0)