Skip to content

Commit a49f3f0

Browse files
시각화 코드 작성
1 parent 8124066 commit a49f3f0

File tree

3 files changed

+150
-1
lines changed

3 files changed

+150
-1
lines changed

interface/lang2sql.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.graph import builder
14+
from llm_utils.display_chart import DisplayChart
1415

1516
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1617
SIDEBAR_OPTIONS = {
@@ -115,9 +116,22 @@ def display_result(
115116
if st.session_state.get("show_referenced_tables", True):
116117
st.write("참고한 테이블 목록:", res["searched_tables"])
117118
if st.session_state.get("show_table", True):
118-
sql = res["generated_query"]
119+
st.write("쿼리 실행 결과")
120+
sql = res["generated_query"].content.split("```")[1][
121+
3:
122+
] # 쿼리 앞쪽의 "sql " 제거
119123
df = database.run_sql(sql)
120124
st.dataframe(df.head(10) if len(df) > 10 else df)
125+
if st.session_state.get("show_chart", True):
126+
st.write("쿼리 결과 시각화")
127+
display_code = DisplayChart(
128+
question=res["refined_input"].content,
129+
sql=sql,
130+
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
131+
)
132+
plotly_code = display_code.generate_plotly_code()
133+
fig = display_code.get_plotly_figure(plotly_code=plotly_code, df=df)
134+
st.plotly_chart(fig)
121135

122136

123137
db = ConnectDB()

llm_utils/display_chart.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import re
2+
from llm_utils import llm_factory
3+
from dotenv import load_dotenv
4+
from langchain.chains.llm import LLMChain
5+
from langchain_openai import ChatOpenAI
6+
from langchain_core.prompts import PromptTemplate
7+
from langchain_core.messages import HumanMessage, SystemMessage
8+
import pandas as pd
9+
import os
10+
11+
import plotly
12+
import plotly.express as px
13+
import plotly.graph_objects as go
14+
15+
16+
# .env 파일 로딩
17+
load_dotenv()
18+
19+
20+
class DisplayChart:
21+
"""
22+
SQL쿼리가 실행된 결과를 그래프로 시각화하는 Class입니다.
23+
24+
쿼리 결과를 비롯한 유저 질문, sql를 prompt에 입력하여
25+
plotly코드를 출력하여 excute한 결과를 fig 객체로 반환합니다.
26+
"""
27+
28+
def __init__(self, question, sql, df_metadata):
29+
self.question = question
30+
self.sql = sql
31+
self.df_metadata = df_metadata
32+
33+
def llm_model_for_chart(self, message_log):
34+
provider = os.getenv("LLM_PROVIDER")
35+
if provider == "openai":
36+
llm = ChatOpenAI(
37+
model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"),
38+
api_key=os.getenv("OPEN_AI_KEY"),
39+
)
40+
result = llm.invoke(message_log)
41+
return result
42+
43+
def _extract_python_code(self, markdown_string: str) -> str:
44+
# Strip whitespace to avoid indentation errors in LLM-generated code
45+
markdown_string = markdown_string.content.split("```")[1][6:].strip()
46+
47+
# Regex pattern to match Python code blocks
48+
pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"
49+
50+
# Find all matches in the markdown string
51+
matches = re.findall(pattern, markdown_string, re.IGNORECASE)
52+
53+
# Extract the Python code from the matches
54+
python_code = []
55+
for match in matches:
56+
python = match[0] if match[0] else match[1]
57+
python_code.append(python.strip())
58+
59+
if len(python_code) == 0:
60+
return markdown_string
61+
62+
return python_code[0]
63+
64+
def _sanitize_plotly_code(self, raw_plotly_code):
65+
# Remove the fig.show() statement from the plotly code
66+
plotly_code = raw_plotly_code.replace("fig.show()", "")
67+
68+
return plotly_code
69+
70+
def generate_plotly_code(self) -> str:
71+
if self.question is not None:
72+
system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{self.question}'"
73+
else:
74+
system_msg = "The following is a pandas DataFrame "
75+
76+
if self.sql is not None:
77+
system_msg += (
78+
f"\n\nThe DataFrame was produced using this query: {self.sql}\n\n"
79+
)
80+
81+
system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{self.df_metadata}"
82+
83+
message_log = [
84+
SystemMessage(content=system_msg),
85+
HumanMessage(
86+
content="Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
87+
),
88+
]
89+
90+
plotly_code = self.llm_model_for_chart(message_log)
91+
92+
return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
93+
94+
def get_plotly_figure(
95+
self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
96+
) -> plotly.graph_objs.Figure:
97+
98+
ldict = {"df": df, "px": px, "go": go}
99+
try:
100+
exec(plotly_code, globals(), ldict)
101+
fig = ldict.get("fig", None)
102+
103+
except Exception as e:
104+
105+
# Inspect data types
106+
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
107+
categorical_cols = df.select_dtypes(
108+
include=["object", "category"]
109+
).columns.tolist()
110+
111+
# Decision-making for plot type
112+
if len(numeric_cols) >= 2:
113+
# Use the first two numeric columns for a scatter plot
114+
fig = px.scatter(df, x=numeric_cols[0], y=numeric_cols[1])
115+
elif len(numeric_cols) == 1 and len(categorical_cols) >= 1:
116+
# Use a bar plot if there's one numeric and one categorical column
117+
fig = px.bar(df, x=categorical_cols[0], y=numeric_cols[0])
118+
elif len(categorical_cols) >= 1 and df[categorical_cols[0]].nunique() < 10:
119+
# Use a pie chart for categorical data with fewer unique values
120+
fig = px.pie(df, names=categorical_cols[0])
121+
else:
122+
# Default to a simple line plot if above conditions are not met
123+
fig = px.line(df)
124+
125+
if fig is None:
126+
return None
127+
128+
if dark_mode:
129+
fig.update_layout(template="plotly_dark")
130+
131+
return fig

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ langchain-google-genai>=2.1.3,<3.0.0
1717
langchain-ollama>=0.3.2,<0.4.0
1818
langchain-huggingface>=0.1.2,<0.2.0
1919
clickhouse_driver
20+
plotly
21+
matplotlib
22+
ipython
23+
kaleido

0 commit comments

Comments
 (0)