Skip to content

Commit 9c48299

Browse files
Merge pull request #117 from CausalInferenceLab/feature/116-refactor-extract-token-utils
LLM 토큰 사용량 집계 로직 TokenUtils 클래스로 분리
2 parents 6787f03 + ead12d5 commit 9c48299

File tree

2 files changed

+117
-26
lines changed

2 files changed

+117
-26
lines changed

interface/lang2sql.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
from langchain_core.messages import AIMessage, HumanMessage
1111

1212
from llm_utils.connect_db import ConnectDB
13-
from llm_utils.graph import builder
14-
from llm_utils.enriched_graph import builder as enriched_builder
1513
from llm_utils.display_chart import DisplayChart
14+
from llm_utils.enriched_graph import builder as enriched_builder
15+
from llm_utils.graph import builder
1616
from llm_utils.llm_response_parser import LLMResponseParser
17+
from llm_utils.token_utils import TokenUtils
1718

19+
TITLE = "Lang2SQL"
1820
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1921
SIDEBAR_OPTIONS = {
20-
"show_total_token_usage": "Show Total Token Usage",
22+
"show_token_usage": "Show Token Usage",
2123
"show_result_description": "Show Result Description",
2224
"show_sql": "Show SQL",
2325
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
@@ -27,24 +29,6 @@
2729
}
2830

2931

30-
def summarize_total_tokens(data: list) -> int:
31-
"""
32-
메시지 데이터에서 총 토큰 사용량을 집계합니다.
33-
34-
Args:
35-
data (list): usage_metadata를 포함하는 객체들의 리스트.
36-
37-
Returns:
38-
int: 총 토큰 사용량 합계.
39-
"""
40-
41-
total_tokens = 0
42-
for item in data:
43-
token_usage = getattr(item, "usage_metadata", {})
44-
total_tokens += token_usage.get("total_tokens", 0)
45-
return total_tokens
46-
47-
4832
def execute_query(
4933
*,
5034
query: str,
@@ -119,14 +103,22 @@ def display_result(
119103
"""
120104

121105
def should_show(_key: str) -> bool:
122-
st.markdown("---")
123106
return st.session_state.get(_key, True)
124107

125-
if should_show("show_total_token_usage"):
126-
total_tokens = summarize_total_tokens(res["messages"])
127-
st.write("**총 토큰 사용량:**", total_tokens)
108+
if should_show("show_token_usage"):
109+
st.markdown("---")
110+
token_summary = TokenUtils.get_token_usage_summary(data=res["messages"])
111+
st.write("**토큰 사용량:**")
112+
st.markdown(
113+
f"""
114+
- Input tokens: `{token_summary['input_tokens']}`
115+
- Output tokens: `{token_summary['output_tokens']}`
116+
- Total tokens: `{token_summary['total_tokens']}`
117+
"""
118+
)
128119

129120
if should_show("show_sql"):
121+
st.markdown("---")
130122
generated_query = res.get("generated_query")
131123
query_text = (
132124
generated_query.content
@@ -148,6 +140,7 @@ def should_show(_key: str) -> bool:
148140
st.code(interpretation)
149141

150142
if should_show("show_result_description"):
143+
st.markdown("---")
151144
st.markdown("**결과 설명:**")
152145
result_message = res["messages"][-1].content
153146

@@ -163,14 +156,17 @@ def should_show(_key: str) -> bool:
163156
st.code(interpretation, language="plaintext")
164157

165158
if should_show("show_question_reinterpreted_by_ai"):
159+
st.markdown("---")
166160
st.markdown("**AI가 재해석한 사용자 질문:**")
167161
st.code(res["refined_input"].content)
168162

169163
if should_show("show_referenced_tables"):
164+
st.markdown("---")
170165
st.markdown("**참고한 테이블 목록:**")
171166
st.write(res.get("searched_tables", []))
172167

173168
if should_show("show_table"):
169+
st.markdown("---")
174170
try:
175171
sql_raw = (
176172
res["generated_query"].content
@@ -182,7 +178,9 @@ def should_show(_key: str) -> bool:
182178
st.dataframe(df.head(10) if len(df) > 10 else df)
183179
except Exception as e:
184180
st.error(f"쿼리 실행 중 오류 발생: {e}")
181+
185182
if should_show("show_chart"):
183+
st.markdown("---")
186184
df = database.run_sql(sql)
187185
st.markdown("**쿼리 결과 시각화:**")
188186
display_code = DisplayChart(
@@ -199,7 +197,7 @@ def should_show(_key: str) -> bool:
199197

200198
db = ConnectDB()
201199

202-
st.title("Lang2SQL")
200+
st.title(TITLE)
203201

204202
# 워크플로우 선택(UI)
205203
use_enriched = st.sidebar.checkbox(

llm_utils/token_utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
token_utils.py
3+
4+
LLM 응답 메시지에서 토큰 사용량을 집계하기 위한 유틸리티 모듈입니다.
5+
6+
이 모듈은 LLM의 `usage_metadata` 필드를 기반으로 입력 토큰, 출력 토큰, 총 토큰 사용량을 계산하는 기능을 제공합니다.
7+
Streamlit, LangChain 등 LLM 응답을 다루는 애플리케이션에서 비용 분석, 사용량 추적 등에 활용할 수 있습니다.
8+
"""
9+
10+
import logging
11+
from typing import Any, List
12+
13+
logging.basicConfig(
14+
level=logging.INFO,
15+
format="%(asctime)s [%(levelname)s] %(message)s",
16+
datefmt="%Y-%m-%d %H:%M:%S",
17+
)
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class TokenUtils:
22+
"""
23+
LLM 토큰 사용량 집계 유틸리티 클래스입니다.
24+
25+
이 클래스는 LLM 응답 메시지 리스트에서 usage_metadata 필드를 추출하여
26+
input_tokens, output_tokens, total_tokens의 합계를 계산합니다.
27+
28+
예를 들어, LangChain 또는 OpenAI API 응답 메시지 객체의 토큰 사용 정보를 분석하고자 할 때
29+
활용할 수 있습니다.
30+
31+
사용 예:
32+
>>> from token_utils import TokenUtils
33+
>>> summary = TokenUtils.get_token_usage_summary(messages)
34+
>>> print(summary["total_tokens"])
35+
36+
반환 형식:
37+
{
38+
"input_tokens": int,
39+
"output_tokens": int,
40+
"total_tokens": int,
41+
}
42+
"""
43+
44+
@staticmethod
45+
def get_token_usage_summary(*, data: List[Any]) -> dict:
46+
"""
47+
메시지 데이터에서 input/output/total 토큰 사용량을 각각 집계합니다.
48+
49+
Args:
50+
data (List[Any]): 각 항목이 usage_metadata를 포함할 수 있는 객체 리스트.
51+
52+
Returns:
53+
dict: {
54+
"input_tokens": int,
55+
"output_tokens": int,
56+
"total_tokens": int
57+
}
58+
"""
59+
60+
input_tokens = 0
61+
output_tokens = 0
62+
total_tokens = 0
63+
64+
for idx, item in enumerate(data):
65+
token_usage = getattr(item, "usage_metadata", {})
66+
in_tok = token_usage.get("input_tokens", 0)
67+
out_tok = token_usage.get("output_tokens", 0)
68+
total_tok = token_usage.get("total_tokens", 0)
69+
70+
logger.debug(
71+
"Message[%d] → input=%d, output=%d, total=%d",
72+
idx,
73+
in_tok,
74+
out_tok,
75+
total_tok,
76+
)
77+
78+
input_tokens += in_tok
79+
output_tokens += out_tok
80+
total_tokens += total_tok
81+
82+
logger.info(
83+
"Token usage summary → input: %d, output: %d, total: %d",
84+
input_tokens,
85+
output_tokens,
86+
total_tokens,
87+
)
88+
89+
return {
90+
"input_tokens": input_tokens,
91+
"output_tokens": output_tokens,
92+
"total_tokens": total_tokens,
93+
}

0 commit comments

Comments
 (0)