Skip to content

Commit 8439708

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents 4a0104c + 42cacff commit 8439708

File tree

22 files changed

+1142
-481
lines changed

22 files changed

+1142
-481
lines changed

backend/apps/datasource/api/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def get_tables_by_conf(session: SessionDep, ds: CoreDatasource):
6464

6565

6666
@router.post("/getFields/{id}/{table_name}")
67-
async def get_tables(session: SessionDep, id: int, table_name: str):
67+
async def get_fields(session: SessionDep, id: int, table_name: str):
6868
return getFields(session, id, table_name)
6969

7070

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def fieldEnum(session: SessionDep, id: int):
263263

264264

265265
def updateNum(session: SessionDep, ds: CoreDatasource):
266-
all_tables = get_tables(ds)
266+
all_tables = get_tables(ds) if ds.type != 'excel' else json.loads(aes_decrypt(ds.configuration)).get('sheets')
267267
selected_tables = get_tables_by_ds_id(session, ds.id)
268268
num = f'{len(selected_tables)}/{len(all_tables)}'
269269

backend/apps/system/api/login.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Annotated
22
from fastapi import APIRouter, Depends, HTTPException
33
from fastapi.security import OAuth2PasswordRequestForm
4+
from apps.system.schemas.system_schema import BaseUserDTO
45
from common.core.deps import SessionDep
56
from ..crud.user import authenticate
67
from common.core.security import create_access_token
@@ -14,9 +15,12 @@ def local_login(
1415
session: SessionDep,
1516
form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
1617
) -> Token:
17-
user = authenticate(session=session, account=form_data.username, password=form_data.password)
18+
user: BaseUserDTO = authenticate(session=session, account=form_data.username, password=form_data.password)
1819
if not user:
1920
raise HTTPException(status_code=400, detail="Incorrect account or password")
21+
22+
if not user.oid or user.oid == 0:
23+
raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator")
2024
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
2125
user_dict = user.to_dict()
2226
return Token(access_token=create_access_token(

backend/apps/system/api/user.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional
2-
from fastapi import APIRouter, Query
3-
from sqlmodel import func, or_, select
4-
from apps.system.crud.user import get_db_user, user_ws_options
5-
from apps.system.models.system_model import UserWsModel, WorkspaceModel
2+
from fastapi import APIRouter, HTTPException, Query
3+
from sqlmodel import func, or_, select, delete as sqlmodel_delete
4+
from apps.system.crud.user import get_db_user, single_delete, user_ws_options
5+
from apps.system.models.system_model import UserWsModel
66
from apps.system.models.user import UserModel
77
from apps.system.schemas.auth import CacheName, CacheNamespace
88
from apps.system.schemas.system_schema import PwdEditor, UserCreator, UserEditor, UserGrid, UserLanguage, UserWs
@@ -34,14 +34,21 @@ async def pager(
3434
stmt = (
3535
select(
3636
UserModel,
37-
func.coalesce(func.string_agg(WorkspaceModel.name, ','), '').label("space_name")
37+
func.coalesce(
38+
func.array_remove(
39+
func.array_agg(UserWsModel.oid),
40+
None
41+
),
42+
[]
43+
).label("oid_list")
44+
#func.coalesce(func.string_agg(WorkspaceModel.name, ','), '').label("space_name")
3845
)
3946
.join(UserWsModel, UserModel.id == UserWsModel.uid, isouter=True)
40-
.join(WorkspaceModel, UserWsModel.oid == WorkspaceModel.id, isouter=True)
47+
#.join(WorkspaceModel, UserWsModel.oid == WorkspaceModel.id, isouter=True)
48+
.where(UserModel.id != 1)
4149
.group_by(UserModel.id)
4250
.order_by(UserModel.create_time)
4351
)
44-
4552
if status is not None:
4653
stmt = stmt.where(UserModel.status == status)
4754

@@ -67,10 +74,21 @@ async def pager(
6774
)
6875
)
6976

70-
return await paginator.get_paginated_response(
77+
user_page = await paginator.get_paginated_response(
7178
stmt=stmt,
7279
pagination=pagination,
7380
**filters)
81+
82+
""" for item in user_page.items:
83+
space_name: str = item['space_name']
84+
if space_name and 'i18n_default_workspace' in space_name:
85+
parts = list(map(
86+
lambda x: trans(x) if x == "i18n_default_workspace" else x,
87+
space_name.split(',')
88+
))
89+
output_str = ','.join(parts)
90+
item['space_name'] = output_str """
91+
return user_page
7492

7593
@router.get("/ws")
7694
async def ws_options(session: SessionDep, current_user: CurrentUser, trans: Trans) -> list[UserWs]:
@@ -81,46 +99,78 @@ async def ws_options(session: SessionDep, current_user: CurrentUser, trans: Tran
8199
async def ws_change(session: SessionDep, current_user: CurrentUser, oid: int):
82100
ws_list: list[UserWs] = await user_ws_options(session, current_user.id)
83101
if not any(x.id == oid for x in ws_list):
84-
raise RuntimeError(f"oid [{oid}] is invalid!")
102+
raise HTTPException(f"oid [{oid}] is invalid!")
85103
user_model: UserModel = get_db_user(session = session, user_id = current_user.id)
86104
user_model.oid = oid
87105
session.add(user_model)
88106
session.commit()
89107

90108
@router.get("/{id}", response_model=UserEditor)
91-
async def query(session: SessionDep, id: int) -> UserEditor:
109+
async def query(session: SessionDep, trans: Trans, id: int) -> UserEditor:
92110
db_user: UserModel = get_db_user(session = session, user_id = id)
93-
return db_user
111+
u_ws_options = await user_ws_options(session, id, trans)
112+
result = UserEditor.model_validate(db_user.model_dump())
113+
if u_ws_options:
114+
result.oid_list = [item.id for item in u_ws_options]
115+
return result
94116

95117
@router.post("")
96118
async def create(session: SessionDep, creator: UserCreator):
97119
data = creator.model_dump(exclude_unset=True)
98120
user_model = UserModel.model_validate(data)
99121
#user_model.create_time = get_timestamp()
100122
user_model.language = "zh-CN"
123+
user_model.oid = 0
124+
if creator.oid_list:
125+
# need to validate oid_list
126+
db_model_list = [
127+
UserWsModel.model_validate({
128+
"oid": oid,
129+
"uid": user_model.id,
130+
"weight": 0
131+
})
132+
for oid in creator.oid_list
133+
]
134+
session.add_all(db_model_list)
135+
user_model.oid = creator.oid_list[0]
101136
session.add(user_model)
102137
session.commit()
103138

104139
@router.put("")
105140
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="editor.id")
106141
async def update(session: SessionDep, editor: UserEditor):
107142
user_model: UserModel = get_db_user(session = session, user_id = editor.id)
143+
origin_oid: int = user_model.oid
144+
del_stmt = sqlmodel_delete(UserWsModel).where(UserWsModel.uid == editor.id)
145+
session.exec(del_stmt)
146+
108147
data = editor.model_dump(exclude_unset=True)
109148
user_model.sqlmodel_update(data)
149+
150+
user_model.oid = 0
151+
if editor.oid_list:
152+
# need to validate oid_list
153+
db_model_list = [
154+
UserWsModel.model_validate({
155+
"oid": oid,
156+
"uid": user_model.id,
157+
"weight": 0
158+
})
159+
for oid in editor.oid_list
160+
]
161+
session.add_all(db_model_list)
162+
user_model.oid = origin_oid if origin_oid in editor.oid_list else editor.oid_list[0]
110163
session.add(user_model)
111164
session.commit()
112165

113166
@router.delete("/{id}")
114-
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")
115167
async def delete(session: SessionDep, id: int):
116-
user_model: UserModel = get_db_user(session = session, user_id = id)
117-
session.delete(user_model)
118-
session.commit()
168+
await single_delete(session, id)
119169

120170
@router.delete("")
121171
async def batch_del(session: SessionDep, id_list: list[int]):
122172
for id in id_list:
123-
delete(session, id)
173+
await single_delete(session, id)
124174

125175
@router.put("/language")
126176
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
@@ -137,7 +187,7 @@ async def langChange(session: SessionDep, current_user: CurrentUser, language: U
137187
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")
138188
async def pwdReset(session: SessionDep, current_user: CurrentUser, id: int):
139189
if not current_user.isAdmin:
140-
raise RuntimeError('only for admin')
190+
raise HTTPException('only for admin')
141191
db_user: UserModel = get_db_user(session=session, user_id=id)
142192
db_user.password = default_md5_pwd()
143193
session.add(db_user)
@@ -148,7 +198,7 @@ async def pwdReset(session: SessionDep, current_user: CurrentUser, id: int):
148198
async def pwdUpdate(session: SessionDep, current_user: CurrentUser, editor: PwdEditor):
149199
db_user: UserModel = get_db_user(session=session, user_id=current_user.id)
150200
if not verify_md5pwd(editor.pwd, db_user.password):
151-
raise RuntimeError("pwd error")
201+
raise HTTPException("pwd error")
152202
db_user.password = md5pwd(editor.new_pwd)
153203
session.add(db_user)
154204
session.commit()

backend/apps/system/crud/user.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11

22
from typing import Optional
3-
from sqlmodel import Session, select
3+
from sqlmodel import Session, select, delete as sqlmodel_delete
44

55
from apps.system.models.system_model import UserWsModel, WorkspaceModel
66
from apps.system.schemas.auth import CacheName, CacheNamespace
77
from apps.system.schemas.system_schema import BaseUserDTO, UserInfoDTO, UserWs
8-
from common.core.sqlbot_cache import cache
8+
from common.core.deps import SessionDep
9+
from common.core.sqlbot_cache import cache, clear_cache
910
from common.utils.locale import I18n
1011
from ..models.user import UserModel
1112
from common.core.security import verify_md5pwd
@@ -55,4 +56,12 @@ async def user_ws_options(session: Session, uid: int, trans: Optional[I18n] = No
5556
return [
5657
UserWs(id = id, name = trans(name) if name.startswith('i18n') else name)
5758
for id, name in result.all()
58-
]
59+
]
60+
61+
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")
62+
async def single_delete(session: SessionDep, id: int):
63+
user_model: UserModel = get_db_user(session = session, user_id = id)
64+
del_stmt = sqlmodel_delete(UserWsModel).where(UserWsModel.uid == id)
65+
session.exec(del_stmt)
66+
session.delete(user_model)
67+
session.commit()

backend/apps/system/middleware/auth.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
from typing import Optional
3-
from fastapi import Request
3+
from fastapi import HTTPException, Request
44
from fastapi.responses import JSONResponse
55
import jwt
66
from sqlmodel import Session
@@ -35,15 +35,15 @@ async def dispatch(self, request, call_next):
3535
request.state.current_user = validator[1]
3636
request.state.assistant = validator[2]
3737
return await call_next(request)
38-
return JSONResponse({"error": f"Unauthorized:[{validator[1]}]"}, status_code=401)
38+
return JSONResponse({"msg": f"Unauthorized:[{validator[1]}]"}, status_code=401)
3939
#validate pass
4040
tokenkey = settings.TOKEN_KEY
4141
token = request.headers.get(tokenkey)
4242
validate_pass, data = await self.validateToken(token)
4343
if validate_pass:
4444
request.state.current_user = data
4545
return await call_next(request)
46-
return JSONResponse({"error": f"Unauthorized:[{data}]"}, status_code=401)
46+
return JSONResponse({"msg": f"Unauthorized:[{data}]"}, status_code=401)
4747

4848
def is_options(self, request: Request):
4949
return request.method == "OPTIONS"
@@ -62,6 +62,12 @@ async def validateToken(self, token: Optional[str]):
6262
with Session(engine) as session:
6363
session_user = await get_user_info(session = session, user_id = token_data.id)
6464
session_user = UserInfoDTO.model_validate(session_user)
65+
session_user = UserInfoDTO.model_validate(session_user)
66+
""" if token_data.oid != session_user.oid:
67+
raise HTTPException(
68+
status_code=401,
69+
detail="Default space has been changed, please login again!"
70+
) """
6571
return True, session_user
6672
except Exception as e:
6773
return False, e

backend/apps/system/models/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class BaseUserPO(SQLModel):
1212
account: str = Field(max_length=255, unique=True)
13-
oid: int = Field(nullable=False, sa_type=BigInteger())
13+
oid: int = Field(nullable=False, sa_type=BigInteger(), default=0)
1414
name: str = Field(max_length=255, unique=True)
1515
password: str = Field(default_factory=default_md5_pwd, max_length=255)
1616
email: str = Field(max_length=255)

backend/apps/system/schemas/system_schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ class UserCreator(BaseUser):
3030
name: str
3131
email: str
3232
status: int = 1
33+
oid_list: Optional[list[int]] = None
3334

3435
class UserEditor(UserCreator, BaseCreatorDTO):
3536
pass
3637

3738
class UserGrid(UserEditor):
3839
create_time: int
3940
language: str = "zh-CN"
40-
space_name: Optional[str] = None
41+
#space_name: Optional[str] = None
4142
origin: str = ''
4243

4344

frontend/src/api/chat.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ export interface ChatMessage {
2424
isTyping?: boolean
2525
first_chat?: boolean
2626
recommended_question?: string
27+
index: number
2728
}
2829

2930
export class ChatRecord {
Lines changed: 3 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)