Skip to content

Commit 5b436c7

Browse files
committed
feat: import Sample SQL
1 parent 0674066 commit 5b436c7

File tree

10 files changed

+528
-19
lines changed

10 files changed

+528
-19
lines changed

backend/apps/data_training/api/data_training.py

Lines changed: 151 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import asyncio
2+
import hashlib
23
import io
4+
import os
5+
import uuid
6+
from http.client import HTTPException
37
from typing import Optional
48

59
import pandas as pd
6-
from fastapi import APIRouter, Query
7-
from fastapi.responses import StreamingResponse
10+
from fastapi import APIRouter, File, UploadFile, Query
11+
from fastapi.responses import StreamingResponse, FileResponse
812

913
from apps.chat.models.chat_model import AxisObj
1014
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
11-
enable_training, get_all_data_training
15+
enable_training, get_all_data_training, batch_create_training
1216
from apps.data_training.models.data_training_model import DataTrainingInfo
17+
from common.core.config import settings
1318
from common.core.deps import SessionDep, CurrentUser, Trans
1419
from common.utils.data_format import DataFormat
1520

@@ -90,3 +95,146 @@ def inner():
9095

9196
result = await asyncio.to_thread(inner)
9297
return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
98+
99+
100+
path = settings.EXCEL_PATH
101+
102+
from sqlalchemy.orm import sessionmaker, scoped_session
103+
from common.core.db import engine
104+
from sqlmodel import Session
105+
106+
session_maker = scoped_session(sessionmaker(bind=engine, class_=Session))
107+
108+
109+
@router.post("/uploadExcel")
110+
async def upload_excel(trans: Trans, current_user: CurrentUser, file: UploadFile = File(...)):
111+
ALLOWED_EXTENSIONS = {"xlsx", "xls"}
112+
if not file.filename.lower().endswith(tuple(ALLOWED_EXTENSIONS)):
113+
raise HTTPException(400, "Only support .xlsx/.xls")
114+
115+
os.makedirs(path, exist_ok=True)
116+
base_filename = f"{file.filename.split('.')[0]}_{hashlib.sha256(uuid.uuid4().bytes).hexdigest()[:10]}"
117+
filename = f"{base_filename}.{file.filename.split('.')[1]}"
118+
save_path = os.path.join(path, filename)
119+
with open(save_path, "wb") as f:
120+
f.write(await file.read())
121+
122+
oid = current_user.oid
123+
124+
use_cols = [0, 1, 2] # 问题, 描述, 数据源名称
125+
# 根据oid确定要读取的列
126+
if oid == 1:
127+
use_cols = [0, 1, 2, 3] # 问题, 描述, 数据源名称, 高级应用名称
128+
129+
def inner():
130+
131+
session = session_maker()
132+
133+
sheet_names = pd.ExcelFile(save_path).sheet_names
134+
135+
import_data = []
136+
137+
for sheet_name in sheet_names:
138+
139+
df = pd.read_excel(
140+
save_path,
141+
sheet_name=sheet_name,
142+
engine='calamine',
143+
header=0,
144+
usecols=use_cols,
145+
dtype=str
146+
).fillna("")
147+
148+
for index, row in df.iterrows():
149+
# 跳过空行
150+
if row.isnull().all():
151+
continue
152+
153+
question = row[0].strip() if pd.notna(row[0]) and row[0].strip() else None
154+
description = row[1].strip() if pd.notna(row[1]) and row[1].strip() else None
155+
datasource_name = row[2].strip() if pd.notna(row[2]) and row[2].strip() else None
156+
157+
advanced_application_name = None
158+
if oid == 1 and len(row) > 3:
159+
advanced_application_name = row[3].strip() if pd.notna(row[3]) and row[3].strip() else None
160+
161+
if oid == 1:
162+
import_data.append(
163+
DataTrainingInfo(oid=oid, question=question, description=description,
164+
datasource_name=datasource_name,
165+
advanced_application_name=advanced_application_name))
166+
else:
167+
import_data.append(
168+
DataTrainingInfo(oid=oid, question=question, description=description,
169+
datasource_name=datasource_name))
170+
171+
res = batch_create_training(session, import_data, oid, trans)
172+
173+
failed_records = res['failed_records']
174+
175+
error_excel_filename = None
176+
177+
if len(failed_records) > 0:
178+
data_list = []
179+
for obj in failed_records:
180+
_data = {
181+
"question": obj['data'].question,
182+
"description": obj['data'].description,
183+
"datasource_name": obj['data'].datasource_name,
184+
"advanced_application_name": obj['data'].advanced_application_name,
185+
"errors": obj['errors']
186+
}
187+
data_list.append(_data)
188+
189+
fields = []
190+
fields.append(AxisObj(name=trans('i18n_data_training.problem_description'), value='question'))
191+
fields.append(AxisObj(name=trans('i18n_data_training.sample_sql'), value='description'))
192+
fields.append(AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name'))
193+
if current_user.oid == 1:
194+
fields.append(
195+
AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))
196+
fields.append(AxisObj(name=trans('i18n_data_training.error_info'), value='errors'))
197+
198+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)
199+
200+
df = pd.DataFrame(md_data, columns=_fields_list)
201+
error_excel_filename = f"{base_filename}_error.xlsx"
202+
save_error_path = os.path.join(path, error_excel_filename)
203+
# 保存 DataFrame 到 Excel
204+
df.to_excel(save_error_path, index=False)
205+
206+
return {
207+
'success_count': res['success_count'],
208+
'failed_count': len(failed_records),
209+
'duplicate_count': res['duplicate_count'],
210+
'original_count': res['original_count'],
211+
'error_excel_filename': error_excel_filename,
212+
}
213+
214+
return await asyncio.to_thread(inner)
215+
216+
217+
@router.get("/download-fail-info/{filename}")
218+
async def download_excel(filename: str, trans: Trans):
219+
"""
220+
根据文件路径下载 Excel 文件
221+
"""
222+
file_path = os.path.join(path, filename)
223+
224+
# 检查文件是否存在
225+
if not os.path.exists(file_path):
226+
raise HTTPException(404, "File Not Exists")
227+
228+
# 检查文件是否是 Excel 文件
229+
if not filename.endswith('_error.xlsx'):
230+
raise HTTPException(400, "Only support _error.xlsx")
231+
232+
# 获取文件名
233+
filename = os.path.basename(file_path)
234+
235+
# 返回文件
236+
return FileResponse(
237+
path=file_path,
238+
filename=filename,
239+
media_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
240+
)

0 commit comments

Comments
 (0)