22import warnings
33from typing import Any , List , Union , Dict
44
5+ import numpy as np
56import orjson
67import pandas as pd
78from langchain_community .utilities import SQLDatabase
@@ -543,16 +544,16 @@ def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
543544 raise RuntimeError (error_msg )
544545
545546
546- def run_task (llm_service : LLMService , session : SessionDep , stream : bool = True ):
547+ def run_task (llm_service : LLMService , session : SessionDep , in_chat : bool = True ):
547548 try :
548549 # return id
549- if stream :
550+ if in_chat :
550551 yield orjson .dumps ({'type' : 'id' , 'id' : llm_service .get_record ().id }).decode () + '\n \n '
551552
552553 # select datasource if datasource is none
553554 if not llm_service .ds :
554555 ds_res = llm_service .select_datasource ()
555- if stream :
556+ if in_chat :
556557 for chunk in ds_res :
557558 yield orjson .dumps ({'content' : chunk , 'type' : 'datasource-result' }).decode () + '\n \n '
558559 yield orjson .dumps ({'id' : llm_service .ds .id , 'datasource_name' : llm_service .ds .name ,
@@ -565,63 +566,81 @@ def run_task(llm_service: LLMService, session: SessionDep, stream: bool = True):
565566 full_sql_text = ''
566567 for chunk in sql_res :
567568 full_sql_text += chunk
568- if stream :
569+ if in_chat :
569570 yield orjson .dumps ({'content' : chunk , 'type' : 'sql-result' }).decode () + '\n \n '
570- if stream :
571+ if in_chat :
571572 yield orjson .dumps ({'type' : 'info' , 'msg' : 'sql generated' }).decode () + '\n \n '
572573
573574 # filter sql
574575 print (full_sql_text )
575576 sql = llm_service .check_save_sql (res = full_sql_text )
576577 print (sql )
577- if stream :
578+ if in_chat :
578579 yield orjson .dumps ({'content' : sql , 'type' : 'sql' }).decode () + '\n \n '
580+ else :
581+ yield f'```sql\n { sql } \n ```\n \n '
579582
580583 # execute sql
581584 result = llm_service .execute_sql (sql = sql )
582585 llm_service .save_sql_data (data_obj = result )
583- if stream :
586+ if in_chat :
584587 yield orjson .dumps ({'content' : orjson .dumps (result ).decode (), 'type' : 'sql-data' }).decode () + '\n \n '
585588
586589 # generate chart
587590 chart_res = llm_service .generate_chart ()
588591 full_chart_text = ''
589592 for chunk in chart_res :
590593 full_chart_text += chunk
591- if stream :
594+ if in_chat :
592595 yield orjson .dumps ({'content' : chunk , 'type' : 'chart-result' }).decode () + '\n \n '
593- if stream :
596+ if in_chat :
594597 yield orjson .dumps ({'type' : 'info' , 'msg' : 'chart generated' }).decode () + '\n \n '
595598
596599 # filter chart
597600 print (full_chart_text )
598601 chart = llm_service .check_save_chart (res = full_chart_text )
599602 print (chart )
600- if stream :
603+ if in_chat :
601604 yield orjson .dumps ({'content' : orjson .dumps (chart ).decode (), 'type' : 'chart' }).decode () + '\n \n '
605+ else :
606+ data = []
607+ _fields = {}
608+ if chart .get ('columns' ):
609+ for _column in chart .get ('columns' ):
610+ if _column :
611+ _fields [_column .get ('value' )] = _column .get ('name' )
612+ if chart .get ('axis' ):
613+ if chart .get ('axis' ).get ('x' ):
614+ _fields [chart .get ('axis' ).get ('x' ).get ('value' )] = chart .get ('axis' ).get ('x' ).get ('name' )
615+ if chart .get ('axis' ).get ('y' ):
616+ _fields [chart .get ('axis' ).get ('y' ).get ('value' )] = chart .get ('axis' ).get ('y' ).get ('name' )
617+ if chart .get ('axis' ).get ('series' ):
618+ _fields [chart .get ('axis' ).get ('series' ).get ('value' )] = chart .get ('axis' ).get ('series' ).get ('name' )
619+ _fields_list = []
620+ _fields_skip = False
621+ for _data in result .get ('data' ):
622+ _row = []
623+ for field in result .get ('fields' ):
624+ _row .append (_data .get (field ))
625+ if not _fields_skip :
626+ _fields_list .append (field if not _fields .get (field ) else _fields .get (field ))
627+ data .append (_row )
628+ _fields_skip = True
629+ df = pd .DataFrame (np .array (data ), columns = _fields_list )
630+ markdown_table = df .to_markdown (index = False )
631+ yield markdown_table + '\n \n '
602632
603633 record = llm_service .finish ()
604- if stream :
634+ if in_chat :
605635 yield orjson .dumps ({'type' : 'finish' }).decode () + '\n \n '
606636 else :
607- md_str = f'```sql\n { sql } \n ```\n \n '
608637 # todo generate picture
609- if chart ['type' ] == 'table' :
610- data = {}
611- for _data in result ['data' ]:
612- for field in result ['fields' ]:
613- if not data [field ]:
614- data [field ] = []
615- data [field ].append (_data [field ])
616- df = pd .DataFrame (data , columns = result ['fields' ])
617- markdown_table = df .to_markdown (index = False )
618- md_str += markdown_table
619- else :
620- md_str += ''
638+ if chart ['type' ] != 'table' :
639+ yield '# todo generate chart picture'
621640
622641 except Exception as e :
623642 llm_service .save_error (message = str (e ))
624- if stream :
643+ if in_chat :
625644 yield orjson .dumps ({'content' : str (e ), 'type' : 'error' }).decode () + '\n \n '
626645 else :
627646 raise e
0 commit comments