Skip to content

Commit 430fdd9

Browse files
authored
Merge pull request #161 from jhongy1994/refactor/make-db-type-optional
refactor/make-db-type-optional - db_type 파라미터를 Optional하게 받도록 수정
2 parents 1275902 + 3100b17 commit 430fdd9

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

interface/lang2sql.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,12 @@ def execute_query(
7878
def display_result(
7979
*,
8080
res: dict,
81-
database: BaseConnector,
8281
) -> None:
8382
"""
8483
Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.
8584
8685
Args:
8786
res (dict): Lang2SQL 실행 결과 딕셔너리.
88-
database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.
8987
9088
출력 항목:
9189
- 총 토큰 사용량
@@ -240,8 +238,8 @@ def _as_float(value):
240238
if not has_query:
241239
st.info("QUERY_MAKER 없이 실행되었습니다. 검색된 테이블 정보만 표시합니다.")
242240

243-
if show_table_section:
244-
st.markdown("---")
241+
if show_table_section or show_chart_section:
242+
database = get_db_connector()
245243
try:
246244
sql_raw = (
247245
res["generated_query"].content
@@ -251,23 +249,24 @@ def _as_float(value):
251249
if isinstance(sql_raw, str):
252250
sql = LLMResponseParser.extract_sql(sql_raw)
253251
df = database.run_sql(sql)
254-
st.dataframe(df.head(10) if len(df) > 10 else df)
255252
else:
256253
st.error("SQL 원본이 문자열이 아닙니다.")
257254
except Exception as e:
255+
st.markdown("---")
258256
st.error(f"쿼리 실행 중 오류 발생: {e}")
257+
df = None
259258

260-
if show_chart_section:
261-
st.markdown("---")
262-
try:
263-
sql_raw = (
264-
res["generated_query"].content
265-
if isinstance(res["generated_query"], AIMessage)
266-
else str(res["generated_query"])
267-
)
268-
if isinstance(sql_raw, str):
269-
sql = LLMResponseParser.extract_sql(sql_raw)
270-
df = database.run_sql(sql)
259+
if df is not None and show_table_section:
260+
st.markdown("---")
261+
st.markdown("**쿼리 실행 결과:**")
262+
try:
263+
st.dataframe(df.head(10) if len(df) > 10 else df)
264+
except Exception as e:
265+
st.error(f"결과 테이블 생성 중 오류 발생: {e}")
266+
267+
if df is not None and show_chart_section:
268+
st.markdown("---")
269+
try:
271270
st.markdown("**쿼리 결과 시각화:**")
272271
try:
273272
if len(res["messages"]) > 1:
@@ -292,13 +291,9 @@ def _as_float(value):
292291
plotly_code=display_code.generate_plotly_code(), df=df
293292
)
294293
st.plotly_chart(fig)
295-
else:
296-
st.error("SQL 원본이 문자열이 아닙니다.")
297-
except Exception as e:
298-
st.error(f"차트 생성 중 오류 발생: {e}")
299-
294+
except Exception as e:
295+
st.error(f"차트 생성 중 오류 발생: {e}")
300296

301-
db = get_db_connector()
302297

303298
st.title(TITLE)
304299

@@ -401,4 +396,4 @@ def _as_float(value):
401396
top_n=user_top_n,
402397
device=device,
403398
)
404-
display_result(res=result, database=db)
399+
display_result(res=result)

0 commit comments

Comments
 (0)