Skip to content

Commit 49a9f55

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents 0f3a9c0 + 1fe5fcb commit 49a9f55

File tree

24 files changed

+323
-118
lines changed

24 files changed

+323
-118
lines changed

backend/alembic/versions/010_upgrade_user_language.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
def upgrade():
2121
# ### commands auto generated by Alembic - please adjust! ###
2222

23-
op.add_column('sys_user', sa.Column('language', sa.VARCHAR(length=255), server_default=sa.text("'zh-CN'::character varying"), nullable=False))
23+
op.add_column('sys_user', sa.Column('language', sa.VARCHAR(length=255), server_default=sa.text("'zh-CN'"), nullable=False))
2424

2525
# ### end Alembic commands ###
2626

backend/alembic/versions/018_modify_chat.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def upgrade():
2323
existing_type=sa.INTEGER(),
2424
type_=sa.BigInteger(),
2525
existing_nullable=False,
26-
autoincrement=True,
27-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
26+
autoincrement=True
27+
)
2828
op.alter_column('chat', 'datasource',
2929
existing_type=sa.INTEGER(),
3030
type_=sa.BigInteger(),
@@ -33,8 +33,7 @@ def upgrade():
3333
existing_type=sa.INTEGER(),
3434
type_=sa.BigInteger(),
3535
existing_nullable=False,
36-
autoincrement=True,
37-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
36+
autoincrement=True)
3837
op.alter_column('chat_record', 'chat_id',
3938
existing_type=sa.INTEGER(),
4039
type_=sa.BigInteger(),
@@ -68,8 +67,7 @@ def downgrade():
6867
existing_type=sa.BigInteger(),
6968
type_=sa.INTEGER(),
7069
existing_nullable=False,
71-
autoincrement=True,
72-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
70+
autoincrement=True)
7371
op.alter_column('chat', 'datasource',
7472
existing_type=sa.BigInteger(),
7573
type_=sa.INTEGER(),
@@ -78,6 +76,5 @@ def downgrade():
7876
existing_type=sa.BigInteger(),
7977
type_=sa.INTEGER(),
8078
existing_nullable=False,
81-
autoincrement=True,
82-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
79+
autoincrement=True)
8380
# ### end Alembic commands ###

backend/alembic/versions/030_permission_oid.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ def upgrade():
2424
existing_type=sa.INTEGER(),
2525
type_=sa.BigInteger(),
2626
existing_nullable=False,
27-
autoincrement=True,
28-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1,
29-
maxvalue=2147483647, cycle=False, cache=1))
27+
autoincrement=True)
3028
# ### end Alembic commands ###
3129

3230

@@ -36,8 +34,6 @@ def downgrade():
3634
existing_type=sa.BigInteger(),
3735
type_=sa.INTEGER(),
3836
existing_nullable=False,
39-
autoincrement=True,
40-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1,
41-
maxvalue=2147483647, cycle=False, cache=1))
37+
autoincrement=True)
4238
op.drop_column('ds_rules', 'oid')
4339
# ### end Alembic commands ###

backend/apps/chat/api/chat.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import datetime
12
import traceback
23

4+
import numpy as np
5+
import pandas as pd
36
from fastapi import APIRouter, HTTPException
47
from fastapi.responses import StreamingResponse
58

69
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
710
delete_chat, get_chat_chart_data, get_chat_predict_data
8-
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
11+
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData
912
from apps.chat.task.llm import LLMService
13+
from common.core.config import settings
1014
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
1115

1216
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@@ -183,3 +187,27 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
183187
)
184188

185189
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
190+
191+
192+
@router.post("/excel/export")
193+
async def export_excel(excel_data: ExcelData):
194+
_fields_list = []
195+
data = []
196+
for _data in excel_data.data:
197+
_row = []
198+
for field in excel_data.axis:
199+
_row.append(_data.get(field.value))
200+
data.append(_row)
201+
for field in excel_data.axis:
202+
_fields_list.append(field.name)
203+
df = pd.DataFrame(np.array(data), columns=_fields_list)
204+
205+
file_name = f"{excel_data.name}-{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.xlsx"
206+
207+
file_path = f'{(settings.EXCEL_PATH if settings.EXCEL_PATH[-1] == "/" else (settings.EXCEL_PATH + "/"))}{file_name}'
208+
209+
file_download_path = f'{(settings.SERVER_EXCEL_HOST if settings.SERVER_EXCEL_HOST[-1] == "/" else (settings.SERVER_EXCEL_HOST + "/"))}{file_name}'
210+
211+
df.to_excel(file_path, index=False)
212+
213+
return file_download_path

backend/apps/chat/models/chat_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,15 @@ class McpQuestion(BaseModel):
180180
question: str = Body(description='用户提问')
181181
chat_id: int = Body(description='会话ID')
182182
token: str = Body(description='token')
183+
184+
185+
class AxisObj(BaseModel):
186+
name: str = ''
187+
value: str = ''
188+
type: str | None = None
189+
190+
191+
class ExcelData(BaseModel):
192+
axis: list[AxisObj] = []
193+
data: list[dict] = []
194+
name: str = 'Excel'

backend/apps/datasource/crud/datasource.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
from fastapi import HTTPException
66
from sqlalchemy import and_, text, func
7-
87
from sqlmodel import select
98

10-
9+
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
1110
from apps.datasource.utils.utils import aes_decrypt
1211
from apps.db.constant import DB
1312
from apps.db.db import get_engine, get_tables, get_fields, exec_sql
@@ -20,7 +19,7 @@
2019
from ..crud.table import delete_table_by_ds_id, update_table
2120
from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema, TableObj, \
2221
DatasourceConf, TableAndFields
23-
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
22+
2423

2524
def get_datasource_list(session: SessionDep, user: CurrentUser, oid: Optional[int] = None) -> List[CoreDatasource]:
2625
current_oid = user.oid if user.oid is not None else 1
@@ -249,11 +248,13 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
249248
f_list = [f for f in data.fields if f.checked]
250249
if is_normal_user(current_user):
251250
# column is checked, and, column permission for data.fields
252-
f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table, fields=f_list) or f_list
251+
f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table,
252+
fields=f_list) or f_list
253253

254254
# row permission tree
255255
where_str = ''
256-
filter_mapping = get_row_permission_filters(session=session, current_user=current_user, ds=ds, tables=None, single_table=data.table)
256+
filter_mapping = get_row_permission_filters(session=session, current_user=current_user, ds=ds, tables=None,
257+
single_table=data.table)
257258
if filter_mapping:
258259
mapping_dict = filter_mapping[0]
259260
where_str = mapping_dict.get('filter')
@@ -325,13 +326,12 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
325326
fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all()
326327

327328
# do column permissions, filter fields
328-
fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields) or fields
329+
fields = get_column_permission_fields(session=session, current_user=current_user, table=table,
330+
fields=fields) or fields
329331
_list.append(TableAndFields(schema=schema, table=table, fields=fields))
330332
return _list
331333

332334

333-
334-
335335
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str:
336336
schema_str = ""
337337
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
@@ -360,4 +360,5 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
360360
field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
361361
schema_str += ",\n".join(field_list)
362362
schema_str += '\n]\n'
363+
# todo 外键
363364
return schema_str

backend/apps/system/api/user.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections import defaultdict
12
from typing import Optional
23
from fastapi import APIRouter, HTTPException, Query
3-
from sqlmodel import func, or_, select, delete as sqlmodel_delete
4+
from sqlmodel import SQLModel, func, or_, select, delete as sqlmodel_delete
45
from apps.system.crud.user import check_account_exists, check_email_exists, check_email_format, check_pwd_format, get_db_user, single_delete, user_ws_options
56
from apps.system.models.system_model import UserWsModel, WorkspaceModel
67
from apps.system.models.user import UserModel
@@ -11,6 +12,7 @@
1112
from common.core.schemas import PaginatedResponse, PaginationParams
1213
from common.core.security import default_md5_pwd, md5pwd, verify_md5pwd
1314
from common.core.sqlbot_cache import clear_cache
15+
1416
router = APIRouter(tags=["user"], prefix="/user")
1517

1618
@router.get("/info")
@@ -30,43 +32,15 @@ async def pager(
3032
pagination = PaginationParams(page=pageNum, size=pageSize)
3133
paginator = Paginator(session)
3234
filters = {}
33-
34-
stmt = (
35-
select(
36-
UserModel,
37-
func.coalesce(
38-
func.array_remove(
39-
func.array_agg(UserWsModel.oid),
40-
None
41-
),
42-
[]
43-
).label("oid_list")
44-
#func.coalesce(func.string_agg(WorkspaceModel.name, ','), '').label("space_name")
45-
)
46-
.join(UserWsModel, UserModel.id == UserWsModel.uid, isouter=True)
47-
#.join(WorkspaceModel, UserWsModel.oid == WorkspaceModel.id, isouter=True)
48-
.where(UserModel.id != 1)
49-
.group_by(UserModel.id)
50-
.order_by(UserModel.create_time)
51-
)
52-
if status is not None:
53-
stmt = stmt.where(UserModel.status == status)
5435

36+
origin_stmt = select(UserModel.id).join(UserWsModel, UserModel.id == UserWsModel.uid).where(UserModel.id != 1).distinct()
5537
if oidlist:
56-
user_filter = (
57-
select(UserModel.id)
58-
.join(UserWsModel, UserModel.id == UserWsModel.uid)
59-
.where(UserWsModel.oid.in_(oidlist))
60-
.distinct()
61-
)
62-
stmt = stmt.where(UserModel.id.in_(user_filter))
63-
64-
""" if origins is not None:
65-
stmt = stmt.where(UserModel.origin == origins) """
66-
38+
origin_stmt = origin_stmt.where(UserWsModel.oid.in_(oidlist))
39+
if status is not None:
40+
origin_stmt = origin_stmt.where(UserModel.status == status)
6741
if keyword:
6842
keyword_pattern = f"%{keyword}%"
69-
stmt = stmt.where(
43+
origin_stmt = origin_stmt.where(
7044
or_(
7145
UserModel.account.ilike(keyword_pattern),
7246
UserModel.name.ilike(keyword_pattern),
@@ -75,21 +49,47 @@ async def pager(
7549
)
7650

7751
user_page = await paginator.get_paginated_response(
78-
stmt=stmt,
52+
stmt=origin_stmt,
7953
pagination=pagination,
8054
**filters)
81-
82-
""" for item in user_page.items:
83-
space_name: str = item['space_name']
84-
if space_name and 'i18n_default_workspace' in space_name:
85-
parts = list(map(
86-
lambda x: trans(x) if x == "i18n_default_workspace" else x,
87-
space_name.split(',')
88-
))
89-
output_str = ','.join(parts)
90-
item['space_name'] = output_str """
55+
uid_list = [item.get('id') for item in user_page.items]
56+
if not uid_list:
57+
return user_page
58+
stmt = (
59+
select(UserModel, UserWsModel.oid.label('ws_oid'))
60+
.join(UserWsModel, UserModel.id == UserWsModel.uid, isouter=True)
61+
.where(UserModel.id.in_(uid_list))
62+
.order_by(UserModel.create_time)
63+
)
64+
user_workspaces = session.exec(stmt).all()
65+
merged = defaultdict(list)
66+
extra_attrs = {}
67+
68+
for (user, ws_oid) in user_workspaces:
69+
item = {}
70+
item.update(user.model_dump())
71+
user_id = item['id']
72+
merged[user_id].append(ws_oid)
73+
if user_id not in extra_attrs:
74+
extra_attrs[user_id] = {k: v for k, v in item.items() if k != "ws_oid"}
75+
76+
# 组合结果
77+
result = [
78+
{**extra_attrs[user_id], "oid_list": oid_list}
79+
for user_id, oid_list in merged.items()
80+
]
81+
user_page.items = result
9182
return user_page
9283

84+
def format_user_dict(row) -> dict:
85+
result_dict = {}
86+
for item, key in zip(row, row._fields):
87+
if isinstance(item, SQLModel):
88+
result_dict.update(item.model_dump())
89+
else:
90+
result_dict[key] = item
91+
92+
return result_dict
9393
@router.get("/ws")
9494
async def ws_options(session: SessionDep, current_user: CurrentUser, trans: Trans) -> list[UserWs]:
9595
return await user_ws_options(session, current_user.id, trans)

backend/common/core/config.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ def all_cors_origins(self) -> list[str]:
4444
self.FRONTEND_HOST
4545
]
4646

47-
POSTGRES_SERVER: str
47+
POSTGRES_SERVER: str = 'localhost'
4848
POSTGRES_PORT: int = 5432
49-
POSTGRES_USER: str
50-
POSTGRES_PASSWORD: str = ""
51-
POSTGRES_DB: str = ""
49+
POSTGRES_USER: str = 'root'
50+
POSTGRES_PASSWORD: str = "123456"
51+
POSTGRES_DB: str = "sqlbot"
52+
SQLBOT_DB_URL: str = ''
53+
#SQLBOT_DB_URL: str = 'mysql+pymysql://root:Password123%[email protected]:3306/sqlbot'
5254

5355
TOKEN_KEY: str = "X-SQLBOT-TOKEN"
5456
DEFAULT_PWD: str = "SQLBot@123456"
@@ -64,7 +66,9 @@ def all_cors_origins(self) -> list[str]:
6466

6567
@computed_field # type: ignore[prop-decorator]
6668
@property
67-
def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
69+
def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
70+
if self.SQLBOT_DB_URL:
71+
return self.SQLBOT_DB_URL
6872
return MultiHostUrl.build(
6973
scheme="postgresql+psycopg",
7074
username=self.POSTGRES_USER,
@@ -75,8 +79,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
7579
)
7680

7781
MCP_IMAGE_PATH: str = '/opt/sqlbot/images'
82+
EXCEL_PATH: str
7883
MCP_IMAGE_HOST: str = 'http://localhost:3000'
7984
SERVER_IMAGE_HOST: str
85+
SERVER_EXCEL_HOST: str
8086

8187
settings = Settings() # type: ignore
8288
print(settings)

backend/common/core/pagination.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def __init__(self, session: Session):
1212
self.session = session
1313
def _process_result_row(self, row: Row) -> Dict[str, Any]:
1414
result_dict = {}
15+
if isinstance(row, int):
16+
return {'id': row}
1517
for item, key in zip(row, row._fields):
1618
if isinstance(item, SQLModel):
1719
result_dict.update(item.model_dump())

backend/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def custom_generate_unique_id(route: APIRoute) -> str:
4444
lifespan=lifespan
4545
)
4646

47+
app.mount("/excel", StaticFiles(directory=settings.EXCEL_PATH), name="excel")
48+
49+
4750
mcp_app = FastAPI()
4851
# mcp server, images path
4952
mcp_app.mount("/images", StaticFiles(directory=settings.MCP_IMAGE_PATH), name="images")

0 commit comments

Comments
 (0)