99import requests
1010from langchain .chat_models .base import BaseChatModel
1111from langchain_community .utilities import SQLDatabase
12- from langchain_core .messages import BaseMessage , SystemMessage , HumanMessage , AIMessage
12+ from langchain_core .messages import BaseMessage , SystemMessage , HumanMessage , AIMessage , BaseMessageChunk
1313from sqlalchemy import select
1414from sqlalchemy .orm import load_only
1515
@@ -198,6 +198,7 @@ def generate_analysis(self):
198198 full_thinking_text = ''
199199 full_analysis_text = ''
200200 res = self .llm .stream (analysis_msg )
201+ token_usage = {}
201202 for chunk in res :
202203 print (chunk )
203204 reasoning_content_chunk = ''
@@ -211,9 +212,11 @@ def generate_analysis(self):
211212
212213 full_analysis_text += chunk .content
213214 yield {'content' : chunk .content , 'reasoning_content' : reasoning_content_chunk }
215+ get_token_usage (chunk , token_usage )
214216
215217 analysis_msg .append (AIMessage (full_analysis_text ))
216218 self .record = save_full_analysis_message_and_answer (session = self .session , record_id = self .record .id ,
219+ token_usage = token_usage ,
217220 answer = orjson .dumps ({'content' : full_analysis_text ,
218221 'reasoning_content' : full_thinking_text }).decode (),
219222 full_message = orjson .dumps (history_msg +
@@ -245,6 +248,7 @@ def generate_predict(self):
245248 full_thinking_text = ''
246249 full_predict_text = ''
247250 res = self .llm .stream (predict_msg )
251+ token_usage = {}
248252 for chunk in res :
249253 print (chunk )
250254 reasoning_content_chunk = ''
@@ -258,9 +262,11 @@ def generate_predict(self):
258262
259263 full_predict_text += chunk .content
260264 yield {'content' : chunk .content , 'reasoning_content' : reasoning_content_chunk }
265+ get_token_usage (chunk , token_usage )
261266
262267 predict_msg .append (AIMessage (full_predict_text ))
263268 self .record = save_full_predict_message_and_answer (session = self .session , record_id = self .record .id ,
269+ token_usage = token_usage ,
264270 answer = orjson .dumps ({'content' : full_predict_text ,
265271 'reasoning_content' : full_thinking_text }).decode (),
266272 data = '' ,
@@ -291,6 +297,7 @@ def generate_recommend_questions_task(self):
291297 guess_msg ]).decode ())
292298 full_thinking_text = ''
293299 full_guess_text = ''
300+ token_usage = {}
294301 res = self .llm .stream (guess_msg )
295302 for chunk in res :
296303 print (chunk )
@@ -305,9 +312,11 @@ def generate_recommend_questions_task(self):
305312
306313 full_guess_text += chunk .content
307314 yield {'content' : chunk .content , 'reasoning_content' : reasoning_content_chunk }
315+ get_token_usage (chunk , token_usage )
308316
309317 guess_msg .append (AIMessage (full_guess_text ))
310318 self .record = save_full_recommend_question_message_and_answer (session = self .session , record_id = self .record .id ,
319+ token_usage = token_usage ,
311320 answer = {'content' : full_guess_text ,
312321 'reasoning_content' : full_thinking_text },
313322 full_message = orjson .dumps ([{'type' : msg .type ,
@@ -342,6 +351,7 @@ def select_datasource(self):
342351 datasource_msg ]).decode ())
343352 full_thinking_text = ''
344353 full_text = ''
354+ token_usage = {}
345355 res = self .llm .stream (datasource_msg )
346356 for chunk in res :
347357 print (chunk )
@@ -356,6 +366,7 @@ def select_datasource(self):
356366
357367 full_text += chunk .content
358368 yield {'content' : chunk .content , 'reasoning_content' : reasoning_content_chunk }
369+ get_token_usage (chunk , token_usage )
359370 datasource_msg .append (AIMessage (full_text ))
360371
361372 json_str = extract_nested_json (full_text )
@@ -418,6 +429,7 @@ def generate_sql(self):
418429 self .sql_message ]).decode ())
419430 full_thinking_text = ''
420431 full_sql_text = ''
432+ token_usage = {}
421433 res = self .llm .stream (self .sql_message )
422434 for chunk in res :
423435 print (chunk )
@@ -432,9 +444,11 @@ def generate_sql(self):
432444
433445 full_sql_text += chunk .content
434446 yield {'content' : chunk .content , 'reasoning_content' : reasoning_content_chunk }
447+ get_token_usage (chunk , token_usage )
435448
436449 self .sql_message .append (AIMessage (full_sql_text ))
437450 self .record = save_full_sql_message_and_answer (session = self .session , record_id = self .record .id ,
451+ token_usage = token_usage ,
438452 answer = orjson .dumps ({'content' : full_sql_text ,
439453 'reasoning_content' : full_thinking_text }).decode (),
440454 full_message = orjson .dumps (
@@ -450,6 +464,7 @@ def generate_chart(self):
450464 self .chart_message ]).decode ())
451465 full_thinking_text = ''
452466 full_chart_text = ''
467+ token_usage = {}
453468 res = self .llm .stream (self .chart_message )
454469 for chunk in res :
455470 print (chunk )
@@ -464,9 +479,11 @@ def generate_chart(self):
464479
465480 full_chart_text += chunk .content
466481 yield {'content' : chunk .content , 'reasoning_content' : reasoning_content_chunk }
482+ get_token_usage (chunk , token_usage )
467483
468484 self .chart_message .append (AIMessage (full_chart_text ))
469485 self .record = save_full_chart_message_and_answer (session = self .session , record_id = self .record .id ,
486+ token_usage = token_usage ,
470487 answer = orjson .dumps ({'content' : full_chart_text ,
471488 'reasoning_content' : full_thinking_text }).decode (),
472489 full_message = orjson .dumps (
@@ -740,6 +757,9 @@ def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
740757 traceback .print_exc ()
741758 # llm_service.save_error(session=session, message=str(e))
742759 yield orjson .dumps ({'content' : str (e ), 'type' : 'error' }).decode () + '\n \n '
760+ finally :
761+ # end
762+ pass
743763
744764
745765def run_recommend_questions_task (llm_service : LLMService ):
@@ -788,3 +808,13 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
788808 requests .post (url = settings .MCP_IMAGE_HOST , json = request_obj )
789809
790810 return f'{ (settings .SERVER_IMAGE_HOST if settings .SERVER_IMAGE_HOST [- 1 ] == "/" else (settings .SERVER_IMAGE_HOST + "/" ))} { file_name } .png'
811+
812+
def get_token_usage(chunk: "BaseMessageChunk", token_usage=None):
    """Copy token-usage counters from a streamed LLM chunk into a dict.

    Reads ``input_tokens``, ``output_tokens`` and ``total_tokens`` from the
    chunk's ``usage_metadata`` mapping and stores them in *token_usage*
    (mutated in place). Intended to be called once per streamed chunk;
    typically only the final chunk carries usage metadata — TODO confirm
    for the providers in use.

    :param chunk: a streamed message chunk from ``llm.stream(...)``.
    :param token_usage: dict updated in place with the usage counters.
        A fresh dict is created when omitted (the original used a shared
        mutable default, which leaks state across calls).
    :return: the (possibly newly created) *token_usage* dict.
    """
    if token_usage is None:
        token_usage = {}
    # Best-effort: some providers emit chunks without usage_metadata, or
    # with it set to None — treat both as "nothing to record" instead of
    # swallowing every exception with a bare `except Exception: pass`.
    usage = getattr(chunk, 'usage_metadata', None)
    if usage:
        token_usage['input_tokens'] = usage.get('input_tokens')
        token_usage['output_tokens'] = usage.get('output_tokens')
        token_usage['total_tokens'] = usage.get('total_tokens')
    return token_usage
0 commit comments