1010from langchain_core .messages import AIMessage , HumanMessage
1111
1212from llm_utils .connect_db import ConnectDB
13- from llm_utils .graph import builder
14- from llm_utils .enriched_graph import builder as enriched_builder
1513from llm_utils .display_chart import DisplayChart
14+ from llm_utils .enriched_graph import builder as enriched_builder
15+ from llm_utils .graph import builder
1616from llm_utils .llm_response_parser import LLMResponseParser
17+ from llm_utils .token_utils import TokenUtils
1718
19+ TITLE = "Lang2SQL"
1820DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1921SIDEBAR_OPTIONS = {
20- "show_total_token_usage " : "Show Total Token Usage" ,
22+ "show_token_usage " : "Show Token Usage" ,
2123 "show_result_description" : "Show Result Description" ,
2224 "show_sql" : "Show SQL" ,
2325 "show_question_reinterpreted_by_ai" : "Show User Question Reinterpreted by AI" ,
2729}
2830
2931
30- def summarize_total_tokens (data : list ) -> int :
31- """
32- 메시지 데이터에서 총 토큰 사용량을 집계합니다.
33-
34- Args:
35- data (list): usage_metadata를 포함하는 객체들의 리스트.
36-
37- Returns:
38- int: 총 토큰 사용량 합계.
39- """
40-
41- total_tokens = 0
42- for item in data :
43- token_usage = getattr (item , "usage_metadata" , {})
44- total_tokens += token_usage .get ("total_tokens" , 0 )
45- return total_tokens
46-
47-
4832def execute_query (
4933 * ,
5034 query : str ,
@@ -119,14 +103,22 @@ def display_result(
119103 """
120104
121105 def should_show (_key : str ) -> bool :
122- st .markdown ("---" )
123106 return st .session_state .get (_key , True )
124107
125- if should_show ("show_total_token_usage" ):
126- total_tokens = summarize_total_tokens (res ["messages" ])
127- st .write ("**총 토큰 사용량:**" , total_tokens )
108+ if should_show ("show_token_usage" ):
109+ st .markdown ("---" )
110+ token_summary = TokenUtils .get_token_usage_summary (data = res ["messages" ])
111+ st .write ("**토큰 사용량:**" )
112+ st .markdown (
113+ f"""
114+ - Input tokens: `{ token_summary ['input_tokens' ]} `
115+ - Output tokens: `{ token_summary ['output_tokens' ]} `
116+ - Total tokens: `{ token_summary ['total_tokens' ]} `
117+ """
118+ )
128119
129120 if should_show ("show_sql" ):
121+ st .markdown ("---" )
130122 generated_query = res .get ("generated_query" )
131123 query_text = (
132124 generated_query .content
@@ -148,6 +140,7 @@ def should_show(_key: str) -> bool:
148140 st .code (interpretation )
149141
150142 if should_show ("show_result_description" ):
143+ st .markdown ("---" )
151144 st .markdown ("**결과 설명:**" )
152145 result_message = res ["messages" ][- 1 ].content
153146
@@ -163,14 +156,17 @@ def should_show(_key: str) -> bool:
163156 st .code (interpretation , language = "plaintext" )
164157
165158 if should_show ("show_question_reinterpreted_by_ai" ):
159+ st .markdown ("---" )
166160 st .markdown ("**AI가 재해석한 사용자 질문:**" )
167161 st .code (res ["refined_input" ].content )
168162
169163 if should_show ("show_referenced_tables" ):
164+ st .markdown ("---" )
170165 st .markdown ("**참고한 테이블 목록:**" )
171166 st .write (res .get ("searched_tables" , []))
172167
173168 if should_show ("show_table" ):
169+ st .markdown ("---" )
174170 try :
175171 sql_raw = (
176172 res ["generated_query" ].content
@@ -182,7 +178,9 @@ def should_show(_key: str) -> bool:
182178 st .dataframe (df .head (10 ) if len (df ) > 10 else df )
183179 except Exception as e :
184180 st .error (f"쿼리 실행 중 오류 발생: { e } " )
181+
185182 if should_show ("show_chart" ):
183+ st .markdown ("---" )
186184 df = database .run_sql (sql )
187185 st .markdown ("**쿼리 결과 시각화:**" )
188186 display_code = DisplayChart (
@@ -199,7 +197,7 @@ def should_show(_key: str) -> bool:
199197
200198db = ConnectDB ()
201199
202- st .title ("Lang2SQL" )
200+ st .title (TITLE )
203201
204202# 워크플로우 선택(UI)
205203use_enriched = st .sidebar .checkbox (
0 commit comments