|
1 | 1 | import asyncio |
| 2 | +import hashlib |
2 | 3 | import io |
| 4 | +import os |
| 5 | +import uuid |
| 6 | +from http.client import HTTPException |
3 | 7 | from typing import Optional |
4 | 8 |
|
5 | 9 | import pandas as pd |
6 | | -from fastapi import APIRouter, Query |
7 | | -from fastapi.responses import StreamingResponse |
| 10 | +from fastapi import APIRouter, File, UploadFile, Query |
| 11 | +from fastapi.responses import StreamingResponse, FileResponse |
8 | 12 |
|
9 | 13 | from apps.chat.models.chat_model import AxisObj |
10 | 14 | from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \ |
11 | | - enable_training, get_all_data_training |
| 15 | + enable_training, get_all_data_training, batch_create_training |
12 | 16 | from apps.data_training.models.data_training_model import DataTrainingInfo |
| 17 | +from common.core.config import settings |
13 | 18 | from common.core.deps import SessionDep, CurrentUser, Trans |
14 | 19 | from common.utils.data_format import DataFormat |
15 | 20 |
|
@@ -90,3 +95,146 @@ def inner(): |
90 | 95 |
|
91 | 96 | result = await asyncio.to_thread(inner) |
92 | 97 | return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") |
| 98 | + |
| 99 | + |
# Database plumbing used by the Excel import endpoint defined below.
from sqlalchemy.orm import sessionmaker, scoped_session
from common.core.db import engine
from sqlmodel import Session

# Directory that holds uploaded workbooks and the generated error reports.
path = settings.EXCEL_PATH

# Session registry bound to the shared engine; it hands out SQLModel sessions.
session_maker = scoped_session(
    sessionmaker(bind=engine, class_=Session)
)
| 107 | + |
| 108 | + |
@router.post("/uploadExcel")
async def upload_excel(trans: Trans, current_user: CurrentUser, file: UploadFile = File(...)):
    """Import data-training records from an uploaded Excel workbook.

    The upload is stored under ``path`` with a random suffix, every sheet is
    read, and the non-blank rows are handed to ``batch_create_training``.
    Records that fail to import are written to ``<base>_error.xlsx`` in the
    same directory.

    Args:
        trans: Translation callable used for the error-report column titles.
        current_user: Authenticated user; ``oid`` selects the imported columns.
        file: Uploaded ``.xlsx`` / ``.xls`` workbook.

    Returns:
        dict with ``success_count``, ``failed_count``, ``duplicate_count``,
        ``original_count`` and ``error_excel_filename`` (``None`` when no
        record failed).

    Raises:
        fastapi.HTTPException: 400 when the upload is not an Excel file.
    """
    # The module-level ``HTTPException`` is ``http.client.HTTPException``, which
    # FastAPI does not translate into an HTTP error response (the client would
    # get a 500). Use FastAPI's own exception type instead.
    from fastapi import HTTPException

    allowed_extensions = {".xlsx", ".xls"}

    # The file name is client-controlled: drop any directory components
    # (both separator styles) so it cannot escape the upload directory.
    client_name = os.path.basename((file.filename or "").replace("\\", "/"))
    stem, extension = os.path.splitext(client_name)
    extension = extension.lower()
    if not stem or extension not in allowed_extensions:
        raise HTTPException(400, "Only support .xlsx/.xls")

    os.makedirs(path, exist_ok=True)
    # A random suffix keeps concurrent uploads of the same file name apart.
    base_filename = f"{stem}_{hashlib.sha256(uuid.uuid4().bytes).hexdigest()[:10]}"
    save_path = os.path.join(path, f"{base_filename}{extension}")
    with open(save_path, "wb") as f:
        f.write(await file.read())

    oid = current_user.oid
    # oid 1 additionally imports the advanced-application column.
    with_advanced_application = oid == 1

    # Columns: question, description, datasource name[, advanced application name]
    use_cols = [0, 1, 2, 3] if with_advanced_application else [0, 1, 2]

    def clean_cell(value):
        """Return the stripped cell text, or None when the cell is blank."""
        if pd.isna(value):
            return None
        text = str(value).strip()
        return text or None

    def inner():
        session = session_maker()
        try:
            import_data = []

            # Open the workbook once and close it deterministically.
            with pd.ExcelFile(save_path, engine='calamine') as workbook:
                for sheet_name in workbook.sheet_names:
                    df = pd.read_excel(
                        workbook,
                        sheet_name=sheet_name,
                        header=0,
                        usecols=use_cols,
                        dtype=str
                    ).fillna("")

                    # Plain tuples give positional access without relying on
                    # the deprecated integer lookup on a label-indexed Series.
                    for cells in df.itertuples(index=False, name=None):
                        values = [clean_cell(cell) for cell in cells]

                        # Skip blank rows. NaN was already replaced by "", so
                        # the check has to look at the cleaned text.
                        if all(value is None for value in values):
                            continue

                        record = {
                            "oid": oid,
                            "question": values[0],
                            "description": values[1],
                            "datasource_name": values[2],
                        }
                        if with_advanced_application:
                            record["advanced_application_name"] = values[3] if len(values) > 3 else None
                        import_data.append(DataTrainingInfo(**record))

            res = batch_create_training(session, import_data, oid, trans)

            failed_records = res['failed_records']

            error_excel_filename = None

            if failed_records:
                data_list = [
                    {
                        "question": obj['data'].question,
                        "description": obj['data'].description,
                        "datasource_name": obj['data'].datasource_name,
                        "advanced_application_name": obj['data'].advanced_application_name,
                        "errors": obj['errors']
                    }
                    for obj in failed_records
                ]

                fields = [
                    AxisObj(name=trans('i18n_data_training.problem_description'), value='question'),
                    AxisObj(name=trans('i18n_data_training.sample_sql'), value='description'),
                    AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name'),
                ]
                if with_advanced_application:
                    fields.append(
                        AxisObj(name=trans('i18n_data_training.advanced_application'),
                                value='advanced_application_name'))
                fields.append(AxisObj(name=trans('i18n_data_training.error_info'), value='errors'))

                md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)

                error_df = pd.DataFrame(md_data, columns=_fields_list)
                error_excel_filename = f"{base_filename}_error.xlsx"
                save_error_path = os.path.join(path, error_excel_filename)
                # Persist the failed rows so the client can download them.
                error_df.to_excel(save_error_path, index=False)

            return {
                'success_count': res['success_count'],
                'failed_count': len(failed_records),
                'duplicate_count': res['duplicate_count'],
                'original_count': res['original_count'],
                'error_excel_filename': error_excel_filename,
            }
        finally:
            # scoped_session keeps its session registered for the current
            # thread; release it so the reused worker thread does not hold a
            # stale session and connection.
            session_maker.remove()

    return await asyncio.to_thread(inner)
| 215 | + |
| 216 | + |
@router.get("/download-fail-info/{filename}")
async def download_excel(filename: str, trans: Trans):
    """Download an import error report stored in the Excel directory.

    Args:
        filename: Name of the report; it must end with ``_error.xlsx``.
        trans: Translation dependency (not used by this handler).

    Returns:
        FileResponse streaming the requested workbook.

    Raises:
        fastapi.HTTPException: 400 when the name is not an error report,
            404 when the name contains directory parts or no such file exists.
    """
    # The module-level ``HTTPException`` is ``http.client.HTTPException``, which
    # FastAPI does not translate into an HTTP error response.
    from fastapi import HTTPException

    # Validate the name before touching the filesystem so the endpoint does
    # not reveal whether arbitrary files exist.
    if not filename.endswith('_error.xlsx'):
        raise HTTPException(400, "Only support _error.xlsx")

    # Reject names carrying directory components (either separator style):
    # only files located directly inside ``path`` may be served.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    if safe_name != filename:
        raise HTTPException(404, "File Not Exists")

    file_path = os.path.join(path, safe_name)

    # ``isfile`` also rejects a directory that happens to match the name.
    if not os.path.isfile(file_path):
        raise HTTPException(404, "File Not Exists")

    return FileResponse(
        path=file_path,
        filename=safe_name,
        media_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
    )
0 commit comments