|
28 | 28 | get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \ |
29 | 29 | get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \ |
30 | 30 | get_last_execute_sql_error |
31 | | -from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum |
| 31 | +from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ |
| 32 | + ChatFinishStep |
32 | 33 | from apps.data_training.curd.data_training import get_training_template |
33 | 34 | from apps.datasource.crud.datasource import get_table_schema |
34 | 35 | from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user |
@@ -934,17 +935,20 @@ def await_result(self): |
934 | 935 | break |
935 | 936 | yield chunk |
936 | 937 |
|
937 | | - def run_task_async(self, in_chat: bool = True, stream: bool = True): |
| 938 | + def run_task_async(self, in_chat: bool = True, stream: bool = True, |
| 939 | + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): |
938 | 940 | if in_chat: |
939 | 941 | stream = True |
940 | | - self.future = executor.submit(self.run_task_cache, in_chat, stream) |
| 942 | + self.future = executor.submit(self.run_task_cache, in_chat, stream, finish_step) |
941 | 943 |
|
942 | | - def run_task_cache(self, in_chat: bool = True, stream: bool = True): |
943 | | - for chunk in self.run_task(in_chat, stream): |
| 944 | + def run_task_cache(self, in_chat: bool = True, stream: bool = True, |
| 945 | + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): |
| 946 | + for chunk in self.run_task(in_chat, stream, finish_step): |
944 | 947 | self.chunk_list.append(chunk) |
945 | 948 |
|
946 | | - def run_task(self, in_chat: bool = True, stream: bool = True): |
947 | | - json_result = {'success': True} |
| 949 | + def run_task(self, in_chat: bool = True, stream: bool = True, |
| 950 | + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): |
| 951 | + json_result: Dict[str, Any] = {'success': True} |
948 | 952 | try: |
949 | 953 | if self.ds: |
950 | 954 | oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 |
@@ -1066,13 +1070,47 @@ def run_task(self, in_chat: bool = True, stream: bool = True): |
1066 | 1070 | subsql) |
1067 | 1071 | real_execute_sql = assistant_dynamic_sql |
1068 | 1072 |
|
| 1073 | + if finish_step.value <= ChatFinishStep.GENERATE_SQL.value: |
| 1074 | + if in_chat: |
| 1075 | + yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' |
| 1076 | + if not stream: |
| 1077 | + yield json_result |
| 1078 | + return |
| 1079 | + |
1069 | 1080 | result = self.execute_sql(sql=real_execute_sql) |
1070 | 1081 | self.save_sql_data(data_obj=result) |
1071 | 1082 | if in_chat: |
1072 | 1083 | yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n' |
1073 | 1084 | if not stream: |
1074 | 1085 | json_result['data'] = result.get('data') |
1075 | 1086 |
|
| 1087 | + if finish_step.value <= ChatFinishStep.QUERY_DATA.value: |
| 1088 | + if stream: |
| 1089 | + if in_chat: |
| 1090 | + yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' |
| 1091 | + else: |
| 1092 | + data = [] |
| 1093 | + _fields_list = [] |
| 1094 | + _fields_skip = False |
| 1095 | + for _data in result.get('data'): |
| 1096 | + _row = [] |
| 1097 | + for field in result.get('fields'): |
| 1098 | + _row.append(_data.get(field)) |
| 1099 | + if not _fields_skip: |
| 1100 | + _fields_list.append(field) |
| 1101 | + data.append(_row) |
| 1102 | + _fields_skip = True |
| 1103 | + |
| 1104 | + if not data or not _fields_list: |
| 1105 | + yield 'The SQL execution result is empty.\n\n' |
| 1106 | + else: |
| 1107 | + df = pd.DataFrame(np.array(data), columns=_fields_list) |
| 1108 | + markdown_table = df.to_markdown(index=False) |
| 1109 | + yield markdown_table + '\n\n' |
| 1110 | + else: |
| 1111 | + yield json_result |
| 1112 | + return |
| 1113 | + |
1076 | 1114 | # generate chart |
1077 | 1115 | chart_res = self.generate_chart(chart_type) |
1078 | 1116 | full_chart_text = '' |
|
0 commit comments