Skip to content

Commit 16bf2dc

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents d1ab16c + d1b5b05 commit 16bf2dc

File tree

28 files changed

+480
-99
lines changed

28 files changed

+480
-99
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""028_ds_oid
2+
3+
Revision ID: e96b16d3daab
4+
Revises: b049c9f8ca5b
5+
Create Date: 2025-07-17 14:40:48.522033
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = 'e96b16d3daab'
13+
down_revision = 'b049c9f8ca5b'
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade():
19+
# ### commands auto generated by Alembic - please adjust! ###
20+
op.add_column('core_datasource', sa.Column('oid', sa.BigInteger(), nullable=True))
21+
op.execute('update core_datasource set oid = 1')
22+
# ### end Alembic commands ###
23+
24+
25+
def downgrade():
26+
# ### commands auto generated by Alembic - please adjust! ###
27+
op.drop_column('core_datasource', 'oid')
28+
# ### end Alembic commands ###
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""029_modify_chat
2+
3+
Revision ID: 77d4c39ec22f
4+
Revises: e96b16d3daab
5+
Create Date: 2025-07-17 17:05:13.392973
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '77d4c39ec22f'
15+
down_revision = 'e96b16d3daab'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.add_column('chat', sa.Column('oid', sa.BigInteger(), nullable=True))
23+
op.execute('update chat set oid = 1')
24+
op.alter_column('chat', 'create_time',
25+
existing_type=postgresql.TIMESTAMP(timezone=True),
26+
type_=sa.DateTime(),
27+
existing_nullable=True)
28+
op.alter_column('chat_record', 'create_time',
29+
existing_type=postgresql.TIMESTAMP(timezone=True),
30+
type_=sa.DateTime(),
31+
existing_nullable=True)
32+
op.alter_column('chat_record', 'finish_time',
33+
existing_type=postgresql.TIMESTAMP(timezone=True),
34+
type_=sa.DateTime(),
35+
existing_nullable=True)
36+
# ### end Alembic commands ###
37+
38+
39+
def downgrade():
40+
# ### commands auto generated by Alembic - please adjust! ###
41+
op.alter_column('chat_record', 'finish_time',
42+
existing_type=sa.DateTime(),
43+
type_=postgresql.TIMESTAMP(timezone=True),
44+
existing_nullable=True)
45+
op.alter_column('chat_record', 'create_time',
46+
existing_type=sa.DateTime(),
47+
type_=postgresql.TIMESTAMP(timezone=True),
48+
existing_nullable=True)
49+
op.alter_column('chat', 'create_time',
50+
existing_type=sa.DateTime(),
51+
type_=postgresql.TIMESTAMP(timezone=True),
52+
existing_nullable=True)
53+
op.drop_column('chat', 'oid')
54+
# ### end Alembic commands ###

backend/apps/chat/models/chat_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
class Chat(SQLModel, table=True):
1818
__tablename__ = "chat"
1919
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
20-
create_time: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
20+
oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1))
21+
create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
2122
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
2223
brief: str = Field(max_length=64, nullable=True)
2324
chat_type: str = Field(max_length=20, default="chat") # chat, datasource
@@ -31,8 +32,8 @@ class ChatRecord(SQLModel, table=True):
3132
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
3233
ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
3334
first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
34-
create_time: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
35-
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
35+
create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
36+
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
3637
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
3738
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
3839
engine_type: str = Field(max_length=64)
@@ -66,6 +67,7 @@ class ChatRecord(SQLModel, table=True):
6667
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
6768
predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
6869

70+
6971
class CreateChat(BaseModel):
7072
id: int = None
7173
question: str = None

backend/apps/datasource/api/datasource.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020

2121
@router.get("/list")
22-
async def datasource_list(session: SessionDep):
23-
return get_datasource_list(session=session)
22+
async def datasource_list(session: SessionDep, user: CurrentUser):
23+
return get_datasource_list(session=session, user=user)
2424

2525

2626
@router.post("/get/{id}")
@@ -44,8 +44,8 @@ async def choose_tables(session: SessionDep, id: int, tables: List[CoreTable]):
4444

4545

4646
@router.post("/update", response_model=CoreDatasource)
47-
async def update(session: SessionDep, ds: CoreDatasource):
48-
return update_ds(session, ds)
47+
async def update(session: SessionDep,user: CurrentUser, ds: CoreDatasource):
48+
return update_ds(session,user, ds)
4949

5050

5151
@router.post("/delete/{id}", response_model=CoreDatasource)

backend/apps/datasource/crud/datasource.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
DatasourceConf, TableAndFields
2323

2424

25-
def get_datasource_list(session: SessionDep):
26-
statement = select(CoreDatasource).order_by(CoreDatasource.create_time.desc())
27-
datasource_list = session.exec(statement).fetchall()
28-
return datasource_list
25+
def get_datasource_list(session: SessionDep, user: CurrentUser):
26+
oid = user.oid if user.oid is not None else 1
27+
return session.query(CoreDatasource).filter(CoreDatasource.oid == oid).order_by(
28+
CoreDatasource.create_time.desc()).all()
2929

3030

3131
def get_ds(session: SessionDep, id: int):
@@ -45,25 +45,27 @@ def check_status(session: SessionDep, ds: CoreDatasource):
4545
return False
4646

4747

48-
def check_name(session: SessionDep, ds: CoreDatasource):
48+
def check_name(session: SessionDep, user: CurrentUser, ds: CoreDatasource):
4949
if ds.id is not None:
5050
ds_list = session.query(CoreDatasource).filter(
51-
and_(CoreDatasource.name == ds.name, CoreDatasource.id != ds.id)).all()
51+
and_(CoreDatasource.name == ds.name, CoreDatasource.id != ds.id, CoreDatasource.oid == user.oid)).all()
5252
if ds_list is not None and len(ds_list) > 0:
5353
raise 'Name exist'
5454
else:
55-
ds_list = session.query(CoreDatasource).filter(CoreDatasource.name == ds.name).all()
55+
ds_list = session.query(CoreDatasource).filter(
56+
and_(CoreDatasource.name == ds.name, CoreDatasource.oid == user.oid)).all()
5657
if ds_list is not None and len(ds_list) > 0:
5758
raise 'Name exist'
5859

5960

6061
def create_ds(session: SessionDep, user: CurrentUser, create_ds: CreateDatasource):
6162
ds = CoreDatasource()
6263
deepcopy_ignore_extra(create_ds, ds)
63-
check_name(session, ds)
64+
check_name(session, user, ds)
6465
ds.create_time = datetime.datetime.now()
6566
# status = check_status(session, ds)
6667
ds.create_by = user.id
68+
ds.oid = user.oid if user.oid is not None else 1
6769
ds.status = "Success"
6870
ds.type_name = db_type_relation()[ds.type]
6971
record = CoreDatasource(**ds.model_dump())
@@ -85,9 +87,9 @@ def chooseTables(session: SessionDep, id: int, tables: List[CoreTable]):
8587
updateNum(session, ds)
8688

8789

88-
def update_ds(session: SessionDep, ds: CoreDatasource):
90+
def update_ds(session: SessionDep, user: CurrentUser, ds: CoreDatasource):
8991
ds.id = int(ds.id)
90-
check_name(session, ds)
92+
check_name(session, user, ds)
9193
status = check_status(session, ds)
9294
ds.status = "Success" if status is True else "Fail"
9395
record = session.exec(select(CoreDatasource).where(CoreDatasource.id == ds.id)).first()

backend/apps/datasource/models/datasource.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class CoreDatasource(SQLModel, table=True):
1818
create_by: int = Field(sa_column=Column(BigInteger()))
1919
status: str = Field(max_length=64, nullable=True)
2020
num: str = Field(max_length=256, nullable=True)
21+
oid: int = Field(sa_column=Column(BigInteger()))
2122

2223

2324
class CoreTable(SQLModel, table=True):
@@ -53,6 +54,7 @@ class CreateDatasource(BaseModel):
5354
create_by: int = 0
5455
status: str = ''
5556
num: str = ''
57+
oid: int = 1
5658
tables: List[CoreTable] = []
5759

5860

backend/apps/mcp/mcp.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,29 @@
33

44
from datetime import timedelta
55

6-
from fastapi import APIRouter, HTTPException
6+
import jwt
7+
from fastapi import HTTPException, status, APIRouter
78
from fastapi.responses import StreamingResponse
9+
# from fastapi.security import OAuth2PasswordBearer
10+
from jwt.exceptions import InvalidTokenError
11+
from pydantic import ValidationError
812

913
from apps.chat.api.chat import create_chat
1014
from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart
1115
from apps.chat.task.llm import LLMService, run_task
12-
from apps.datasource.crud.datasource import get_datasource_list
13-
from apps.system.crud.user import authenticate
14-
from apps.system.models.system_model import AiModelDetail
16+
from apps.system.crud.user import authenticate, get_user_info
17+
from apps.system.schemas.system_schema import BaseUserDTO
18+
from apps.system.schemas.system_schema import UserInfoDTO
19+
from common.core import security
1520
from common.core.config import settings
16-
from common.core.deps import SessionDep, get_current_user
17-
from common.core.schemas import Token
21+
from common.core.deps import SessionDep
22+
from common.core.schemas import TokenPayload, XOAuth2PasswordBearer, Token
1823
from common.core.security import create_access_token
1924

25+
reusable_oauth2 = XOAuth2PasswordBearer(
26+
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
27+
)
28+
2029
router = APIRouter(tags=["mcp"], prefix="/mcp")
2130

2231

@@ -35,21 +44,24 @@
3544
# ))
3645

3746

38-
@router.get("/ds_list", operation_id="get_datasource_list")
39-
async def datasource_list(session: SessionDep):
40-
return get_datasource_list(session=session)
41-
42-
43-
@router.get("/model_list", operation_id="get_model_list")
44-
async def get_model_list(session: SessionDep):
45-
return session.query(AiModelDetail).all()
47+
# @router.get("/ds_list", operation_id="get_datasource_list")
48+
# async def datasource_list(session: SessionDep):
49+
# return get_datasource_list(session=session)
50+
#
51+
#
52+
# @router.get("/model_list", operation_id="get_model_list")
53+
# async def get_model_list(session: SessionDep):
54+
# return session.query(AiModelDetail).all()
4655

4756

4857
@router.post("/mcp_start", operation_id="mcp_start")
4958
async def mcp_start(session: SessionDep, chat: ChatStart):
50-
user = authenticate(session=session, account=chat.username, password=chat.password)
59+
user: BaseUserDTO = authenticate(session=session, account=chat.username, password=chat.password)
5160
if not user:
5261
raise HTTPException(status_code=400, detail="Incorrect account or password")
62+
63+
if not user.oid or user.oid == 0:
64+
raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator")
5365
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
5466
user_dict = user.to_dict()
5567
t = Token(access_token=create_access_token(
@@ -61,9 +73,26 @@ async def mcp_start(session: SessionDep, chat: ChatStart):
6173

6274
@router.post("/mcp_question", operation_id="mcp_question")
6375
async def mcp_question(session: SessionDep, chat: ChatMcp):
64-
user = await get_current_user(session, chat.token)
65-
66-
llm_service = LLMService(session, user, chat)
76+
try:
77+
payload = jwt.decode(
78+
chat.token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
79+
)
80+
token_data = TokenPayload(**payload)
81+
except (InvalidTokenError, ValidationError):
82+
raise HTTPException(
83+
status_code=status.HTTP_403_FORBIDDEN,
84+
detail="Could not validate credentials",
85+
)
86+
session_user = await get_user_info(session=session, user_id=token_data.id)
87+
session_user = UserInfoDTO.model_validate(session_user)
88+
if not session_user:
89+
raise HTTPException(status_code=404, detail="User not found")
90+
91+
if session_user.status != 1:
92+
raise HTTPException(status_code=400, detail="Inactive user")
93+
94+
# ask
95+
llm_service = LLMService(session, session_user, chat)
6796
llm_service.init_record()
6897

6998
return StreamingResponse(run_task(llm_service, False), media_type="text/event-stream")

0 commit comments

Comments
 (0)