1- import re
2- from langchain_openai import ChatOpenAI
3- from langchain_core .messages import HumanMessage , SystemMessage
4- import pandas as pd
1+ """
2+ SQL 쿼리 결과를 Plotly로 시각화하는 모듈
3+
4+ 이 모듈은 Lang2SQL 실행 결과를 다양한 형태의 차트로 시각화하는 기능을 제공합니다.
5+ LLM을 활용하여 적절한 Plotly 코드를 생성하고 실행합니다.
6+ """
7+
58import os
9+ import re
10+ from typing import Optional
611
12+ import pandas as pd
713import plotly
814import plotly .express as px
915import plotly .graph_objects as go
16+ from langchain_core .messages import HumanMessage , SystemMessage
17+ from langchain_openai import ChatOpenAI
1018
1119
1220class DisplayChart :
@@ -17,12 +25,29 @@ class DisplayChart:
1725 plotly코드를 출력하여 excute한 결과를 fig 객체로 반환합니다.
1826 """
1927
20- def __init__ (self , question , sql , df_metadata ):
28+ def __init__ (self , question : str , sql : str , df_metadata : str ):
29+ """
30+ DisplayChart 인스턴스를 초기화합니다.
31+
32+ Args:
33+ question (str): 사용자 질문
34+ sql (str): 실행된 SQL 쿼리
35+ df_metadata (str): 데이터프레임 메타데이터
36+ """
2137 self .question = question
2238 self .sql = sql
2339 self .df_metadata = df_metadata
2440
25- def llm_model_for_chart (self , message_log ):
41+ def llm_model_for_chart (self , message_log ) -> Optional [str ]:
42+ """
43+ LLM 모델을 사용하여 차트 생성 코드를 생성합니다.
44+
45+ Args:
46+ message_log: LLM에 전달할 메시지 목록
47+
48+ Returns:
49+ Optional[str]: 생성된 차트 코드 또는 None
50+ """
2651 provider = os .getenv ("LLM_PROVIDER" )
2752 if provider == "openai" :
2853 llm = ChatOpenAI (
@@ -31,18 +56,29 @@ def llm_model_for_chart(self, message_log):
3156 )
3257 result = llm .invoke (message_log )
3358 return result
59+ return None
3460
3561 def _extract_python_code (self , markdown_string : str ) -> str :
62+ """
63+ 마크다운 문자열에서 Python 코드 블록을 추출합니다.
64+
65+ Args:
66+ markdown_string: 마크다운 형식의 문자열
67+
68+ Returns:
69+ str: 추출된 Python 코드
70+ """
3671 # Strip whitespace to avoid indentation errors in LLM-generated code
37- markdown_string = markdown_string .content .split ("```" )[1 ][6 :].strip ()
72+ if hasattr (markdown_string , "content" ):
73+ markdown_string = markdown_string .content .split ("```" )[1 ][6 :].strip ()
74+ else :
75+ markdown_string = str (markdown_string )
3876
3977 # Regex pattern to match Python code blocks
40- pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" # 여러 문자와 공백 뒤에 python이 나오고, 줄바꿈 이후의 모든 내용
78+ pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"
4179
4280 # Find all matches in the markdown string
43- matches = re .findall (
44- pattern , markdown_string , re .IGNORECASE
45- ) # 대소문자 구분 안함
81+ matches = re .findall (pattern , markdown_string , re .IGNORECASE )
4682
4783 # Extract the Python code from the matches
4884 python_code = []
@@ -55,13 +91,27 @@ def _extract_python_code(self, markdown_string: str) -> str:
5591
5692 return python_code [0 ]
5793
58- def _sanitize_plotly_code (self , raw_plotly_code ):
94+ def _sanitize_plotly_code (self , raw_plotly_code : str ) -> str :
95+ """
96+ Plotly 코드에서 불필요한 부분을 제거합니다.
97+
98+ Args:
99+ raw_plotly_code: 원본 Plotly 코드
100+
101+ Returns:
102+ str: 정리된 Plotly 코드
103+ """
59104 # Remove the fig.show() statement from the plotly code
60105 plotly_code = raw_plotly_code .replace ("fig.show()" , "" )
61-
62106 return plotly_code
63107
64108 def generate_plotly_code (self ) -> str :
109+ """
110+ LLM을 사용하여 Plotly 코드를 생성합니다.
111+
112+ Returns:
113+ str: 생성된 Plotly 코드
114+ """
65115 if self .question is not None :
66116 system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{ self .question } '"
67117 else :
@@ -82,20 +132,33 @@ def generate_plotly_code(self) -> str:
82132 ]
83133
84134 plotly_code = self .llm_model_for_chart (message_log )
135+ if plotly_code is None :
136+ return ""
85137
86138 return self ._sanitize_plotly_code (self ._extract_python_code (plotly_code ))
87139
88140 def get_plotly_figure (
89141 self , plotly_code : str , df : pd .DataFrame , dark_mode : bool = True
90- ) -> plotly .graph_objs .Figure :
91-
142+ ) -> Optional [plotly .graph_objs .Figure ]:
143+ """
144+ Plotly 코드를 실행하여 Figure 객체를 생성합니다.
145+
146+ Args:
147+ plotly_code: 실행할 Plotly 코드
148+ df: 데이터프레임
149+ dark_mode: 다크 모드 사용 여부
150+
151+ Returns:
152+ Optional[plotly.graph_objs.Figure]: 생성된 Figure 객체 또는 None
153+ """
92154 ldict = {"df" : df , "px" : px , "go" : go }
155+ fig = None
156+
93157 try :
94- exec (plotly_code , globals (), ldict )
158+ exec (plotly_code , globals (), ldict ) # noqa: S102
95159 fig = ldict .get ("fig" , None )
96160
97- except Exception as e :
98-
161+ except Exception :
99162 # Inspect data types
100163 numeric_cols = df .select_dtypes (include = ["number" ]).columns .tolist ()
101164 categorical_cols = df .select_dtypes (
0 commit comments