Skip to content

Commit f556e1d

Browse files
committed
feat: save token usage
1 parent 682c8b4 commit f556e1d

File tree

4 files changed

+137
-13
lines changed

4 files changed

+137
-13
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""023_modify_chat_record
2+
3+
Revision ID: f535d09946f6
4+
Revises: e6b20ae73606
5+
Create Date: 2025-07-11 15:36:18.473133
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = 'f535d09946f6'
15+
down_revision = 'e6b20ae73606'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
    """Widen the per-step token columns from INTEGER to VARCHAR(256).

    The application now stores serialized token-usage JSON (input/output/total
    token counts) in these columns instead of a single integer, so the column
    type must become a string type.
    """
    # PostgreSQL has no automatic (assignment) cast from integer to varchar
    # during ALTER COLUMN ... TYPE, so an explicit USING cast is required;
    # non-PostgreSQL dialects simply ignore the postgresql_using keyword.
    token_columns = (
        'token_sql',
        'token_chart',
        'token_analysis',
        'token_predict',
        'token_recommended_question',
        'token_select_datasource_question',
    )
    for column in token_columns:
        op.alter_column('chat_record', column,
                        existing_type=sa.INTEGER(),
                        type_=sqlmodel.sql.sqltypes.AutoString(length=256),
                        existing_nullable=True,
                        postgresql_using=f'{column}::varchar')
47+
48+
49+
def downgrade():
    """Revert the per-step token columns from VARCHAR(256) back to INTEGER.

    NOTE(review): rows written after the upgrade hold JSON strings that will
    not cast to integer — this downgrade only succeeds while the columns
    still contain numeric text or NULL.
    """
    # Reverse order of upgrade(). PostgreSQL cannot cast varchar -> integer
    # automatically either, so supply the USING cast explicitly; other
    # dialects ignore the postgresql_using keyword.
    token_columns = (
        'token_select_datasource_question',
        'token_recommended_question',
        'token_predict',
        'token_analysis',
        'token_chart',
        'token_sql',
    )
    for column in token_columns:
        op.alter_column('chat_record', column,
                        existing_type=sqlmodel.sql.sqltypes.AutoString(length=256),
                        type_=sa.INTEGER(),
                        existing_nullable=True,
                        postgresql_using=f'{column}::integer')

backend/apps/chat/curd/chat.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,17 @@ def save_full_sql_message(session: SessionDep, record_id: int, full_message: str
229229
return save_full_sql_message_and_answer(session=session, record_id=record_id, full_message=full_message, answer='')
230230

231231

232-
def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer: str, full_message: str) -> ChatRecord:
232+
def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer: str, full_message: str,
233+
token_usage: dict = None) -> ChatRecord:
233234
if not record_id:
234235
raise Exception("Record id cannot be None")
235236
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
236237
record.full_sql_message = full_message
237238
record.sql_answer = answer
238239

240+
if token_usage:
241+
record.token_sql = orjson.dumps(token_usage).decode()
242+
239243
result = ChatRecord(**record.model_dump())
240244

241245
session.add(record)
@@ -248,13 +252,16 @@ def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer
248252

249253

250254
def save_full_analysis_message_and_answer(session: SessionDep, record_id: int, answer: str,
251-
full_message: str) -> ChatRecord:
255+
full_message: str, token_usage: dict = None) -> ChatRecord:
252256
if not record_id:
253257
raise Exception("Record id cannot be None")
254258
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
255259
record.full_analysis_message = full_message
256260
record.analysis = answer
257261

262+
if token_usage:
263+
record.token_analysis = orjson.dumps(token_usage).decode()
264+
258265
result = ChatRecord(**record.model_dump())
259266

260267
session.add(record)
@@ -267,14 +274,17 @@ def save_full_analysis_message_and_answer(session: SessionDep, record_id: int, a
267274

268275

269276
def save_full_predict_message_and_answer(session: SessionDep, record_id: int, answer: str,
270-
full_message: str, data: str) -> ChatRecord:
277+
full_message: str, data: str, token_usage: dict = None) -> ChatRecord:
271278
if not record_id:
272279
raise Exception("Record id cannot be None")
273280
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
274281
record.full_predict_message = full_message
275282
record.predict = answer
276283
record.predict_data = data
277284

285+
if token_usage:
286+
record.token_predict = orjson.dumps(token_usage).decode()
287+
278288
result = ChatRecord(**record.model_dump())
279289

280290
session.add(record)
@@ -288,7 +298,7 @@ def save_full_predict_message_and_answer(session: SessionDep, record_id: int, an
288298

289299
def save_full_select_datasource_message_and_answer(session: SessionDep, record_id: int, answer: str,
290300
full_message: str, datasource: int = None,
291-
engine_type: str = None) -> ChatRecord:
301+
engine_type: str = None, token_usage: dict = None) -> ChatRecord:
292302
if not record_id:
293303
raise Exception("Record id cannot be None")
294304
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
@@ -299,6 +309,9 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i
299309
record.datasource = datasource
300310
record.engine_type = engine_type
301311

312+
if token_usage:
313+
record.token_select_datasource_question = orjson.dumps(token_usage).decode()
314+
302315
result = ChatRecord(**record.model_dump())
303316

304317
session.add(record)
@@ -311,7 +324,7 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i
311324

312325

313326
def save_full_recommend_question_message_and_answer(session: SessionDep, record_id: int, answer: dict = None,
314-
full_message: str = '[]') -> ChatRecord:
327+
full_message: str = '[]', token_usage: dict = None) -> ChatRecord:
315328
if not record_id:
316329
raise Exception("Record id cannot be None")
317330
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
@@ -329,6 +342,9 @@ def save_full_recommend_question_message_and_answer(session: SessionDep, record_
329342
pass
330343
record.recommended_question = json_str
331344

345+
if token_usage:
346+
record.token_recommended_question = orjson.dumps(token_usage).decode()
347+
332348
result = ChatRecord(**record.model_dump())
333349

334350
session.add(record)
@@ -363,13 +379,16 @@ def save_full_chart_message(session: SessionDep, record_id: int, full_message: s
363379

364380

365381
def save_full_chart_message_and_answer(session: SessionDep, record_id: int, answer: str,
366-
full_message: str) -> ChatRecord:
382+
full_message: str, token_usage: dict = None) -> ChatRecord:
367383
if not record_id:
368384
raise Exception("Record id cannot be None")
369385
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
370386
record.full_chart_message = full_message
371387
record.chart_answer = answer
372388

389+
if token_usage:
390+
record.token_chart = orjson.dumps(token_usage).decode()
391+
373392
result = ChatRecord(**record.model_dump())
374393

375394
session.add(record)

backend/apps/chat/models/chat_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ class ChatRecord(SQLModel, table=True):
4949
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
5050
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
5151
full_sql_message: str = Field(sa_column=Column(Text, nullable=True))
52-
token_sql: int = Field(default=0, nullable=True)
52+
token_sql: str = Field(max_length=256, nullable=True)
5353
full_chart_message: str = Field(sa_column=Column(Text, nullable=True))
54-
token_chart: int = Field(default=0, nullable=True)
54+
token_chart: str = Field(max_length=256, nullable=True)
5555
full_analysis_message: str = Field(sa_column=Column(Text, nullable=True))
56-
token_analysis: int = Field(default=0, nullable=True)
56+
token_analysis: str = Field(max_length=256, nullable=True)
5757
full_predict_message: str = Field(sa_column=Column(Text, nullable=True))
58-
token_predict: int = Field(default=0, nullable=True)
58+
token_predict: str = Field(max_length=256, nullable=True)
5959
full_recommended_question_message: str = Field(sa_column=Column(Text, nullable=True))
60-
token_recommended_question: int = Field(default=0, nullable=True)
60+
token_recommended_question: str = Field(max_length=256, nullable=True)
6161
full_select_datasource_message: str = Field(sa_column=Column(Text, nullable=True))
62-
token_select_datasource_question: int = Field(default=0, nullable=True)
62+
token_select_datasource_question: str = Field(max_length=256, nullable=True)
6363
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
6464
error: str = Field(sa_column=Column(Text, nullable=True))
6565
run_time: float = Field(default=0)

backend/apps/chat/task/llm.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import requests
1010
from langchain.chat_models.base import BaseChatModel
1111
from langchain_community.utilities import SQLDatabase
12-
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
12+
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk
1313
from sqlalchemy import select
1414
from sqlalchemy.orm import load_only
1515

@@ -198,6 +198,7 @@ def generate_analysis(self):
198198
full_thinking_text = ''
199199
full_analysis_text = ''
200200
res = self.llm.stream(analysis_msg)
201+
token_usage = {}
201202
for chunk in res:
202203
print(chunk)
203204
reasoning_content_chunk = ''
@@ -211,9 +212,11 @@ def generate_analysis(self):
211212

212213
full_analysis_text += chunk.content
213214
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
215+
get_token_usage(chunk, token_usage)
214216

215217
analysis_msg.append(AIMessage(full_analysis_text))
216218
self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id,
219+
token_usage=token_usage,
217220
answer=orjson.dumps({'content': full_analysis_text,
218221
'reasoning_content': full_thinking_text}).decode(),
219222
full_message=orjson.dumps(history_msg +
@@ -245,6 +248,7 @@ def generate_predict(self):
245248
full_thinking_text = ''
246249
full_predict_text = ''
247250
res = self.llm.stream(predict_msg)
251+
token_usage = {}
248252
for chunk in res:
249253
print(chunk)
250254
reasoning_content_chunk = ''
@@ -258,9 +262,11 @@ def generate_predict(self):
258262

259263
full_predict_text += chunk.content
260264
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
265+
get_token_usage(chunk, token_usage)
261266

262267
predict_msg.append(AIMessage(full_predict_text))
263268
self.record = save_full_predict_message_and_answer(session=self.session, record_id=self.record.id,
269+
token_usage=token_usage,
264270
answer=orjson.dumps({'content': full_predict_text,
265271
'reasoning_content': full_thinking_text}).decode(),
266272
data='',
@@ -291,6 +297,7 @@ def generate_recommend_questions_task(self):
291297
guess_msg]).decode())
292298
full_thinking_text = ''
293299
full_guess_text = ''
300+
token_usage = {}
294301
res = self.llm.stream(guess_msg)
295302
for chunk in res:
296303
print(chunk)
@@ -305,9 +312,11 @@ def generate_recommend_questions_task(self):
305312

306313
full_guess_text += chunk.content
307314
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
315+
get_token_usage(chunk, token_usage)
308316

309317
guess_msg.append(AIMessage(full_guess_text))
310318
self.record = save_full_recommend_question_message_and_answer(session=self.session, record_id=self.record.id,
319+
token_usage=token_usage,
311320
answer={'content': full_guess_text,
312321
'reasoning_content': full_thinking_text},
313322
full_message=orjson.dumps([{'type': msg.type,
@@ -342,6 +351,7 @@ def select_datasource(self):
342351
datasource_msg]).decode())
343352
full_thinking_text = ''
344353
full_text = ''
354+
token_usage = {}
345355
res = self.llm.stream(datasource_msg)
346356
for chunk in res:
347357
print(chunk)
@@ -356,6 +366,7 @@ def select_datasource(self):
356366

357367
full_text += chunk.content
358368
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
369+
get_token_usage(chunk, token_usage)
359370
datasource_msg.append(AIMessage(full_text))
360371

361372
json_str = extract_nested_json(full_text)
@@ -418,6 +429,7 @@ def generate_sql(self):
418429
self.sql_message]).decode())
419430
full_thinking_text = ''
420431
full_sql_text = ''
432+
token_usage = {}
421433
res = self.llm.stream(self.sql_message)
422434
for chunk in res:
423435
print(chunk)
@@ -432,9 +444,11 @@ def generate_sql(self):
432444

433445
full_sql_text += chunk.content
434446
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
447+
get_token_usage(chunk, token_usage)
435448

436449
self.sql_message.append(AIMessage(full_sql_text))
437450
self.record = save_full_sql_message_and_answer(session=self.session, record_id=self.record.id,
451+
token_usage=token_usage,
438452
answer=orjson.dumps({'content': full_sql_text,
439453
'reasoning_content': full_thinking_text}).decode(),
440454
full_message=orjson.dumps(
@@ -450,6 +464,7 @@ def generate_chart(self):
450464
self.chart_message]).decode())
451465
full_thinking_text = ''
452466
full_chart_text = ''
467+
token_usage = {}
453468
res = self.llm.stream(self.chart_message)
454469
for chunk in res:
455470
print(chunk)
@@ -464,9 +479,11 @@ def generate_chart(self):
464479

465480
full_chart_text += chunk.content
466481
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
482+
get_token_usage(chunk, token_usage)
467483

468484
self.chart_message.append(AIMessage(full_chart_text))
469485
self.record = save_full_chart_message_and_answer(session=self.session, record_id=self.record.id,
486+
token_usage=token_usage,
470487
answer=orjson.dumps({'content': full_chart_text,
471488
'reasoning_content': full_thinking_text}).decode(),
472489
full_message=orjson.dumps(
@@ -740,6 +757,9 @@ def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
740757
traceback.print_exc()
741758
# llm_service.save_error(session=session, message=str(e))
742759
yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
760+
finally:
761+
# end
762+
pass
743763

744764

745765
def run_recommend_questions_task(llm_service: LLMService):
@@ -788,3 +808,13 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
788808
requests.post(url=settings.MCP_IMAGE_HOST, json=request_obj)
789809

790810
return f'{(settings.SERVER_IMAGE_HOST if settings.SERVER_IMAGE_HOST[-1] == "/" else (settings.SERVER_IMAGE_HOST + "/"))}{file_name}.png'
811+
812+
813+
def get_token_usage(chunk: "BaseMessageChunk", token_usage: dict = None) -> dict:
    """Collect token accounting from a streamed LLM chunk into *token_usage*.

    Providers typically attach ``usage_metadata`` only to the final chunk of
    a stream, so callers invoke this once per chunk and the last chunk that
    carries metadata wins.

    :param chunk: a streamed message chunk; only its ``usage_metadata``
        attribute (if present and truthy) is read.
    :param token_usage: dict updated in place with ``input_tokens``,
        ``output_tokens`` and ``total_tokens``; a fresh dict is created when
        omitted (fixes the shared mutable-default ``={}`` pitfall).
    :return: the (possibly newly created) ``token_usage`` dict.
    """
    if token_usage is None:
        token_usage = {}
    # getattr guards chunk types lacking the attribute, replacing the original
    # blanket try/except-pass that could also swallow unrelated errors.
    usage = getattr(chunk, 'usage_metadata', None)
    if usage:
        token_usage['input_tokens'] = usage.get('input_tokens')
        token_usage['output_tokens'] = usage.get('output_tokens')
        token_usage['total_tokens'] = usage.get('total_tokens')
    return token_usage

0 commit comments

Comments
 (0)