Skip to content

Commit 7378438

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents 3fe287c + 8d66932 commit 7378438

File tree

2 files changed

+73
-63
lines changed

2 files changed

+73
-63
lines changed

backend/apps/chat/task/llm.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
get_old_questions, save_analysis_predict_record, list_base_records, rename_chat
2929
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat
3030
from apps.datasource.crud.datasource import get_table_schema
31+
from apps.datasource.crud.datasource import is_normal_user
3132
from apps.datasource.crud.row_permission import transFilterTree
3233
from apps.datasource.models.datasource import CoreDatasource, CoreTable
3334
from apps.db.db import exec_sql
@@ -791,26 +792,28 @@ def run_task(self, in_chat: bool = True):
791792
SQLBotLogUtil.info(full_sql_text)
792793

793794
# todo row permission
794-
sql_json_str = extract_nested_json(full_sql_text)
795-
data = orjson.loads(sql_json_str)
796-
797-
sql = ''
798-
message = ''
799-
error = False
800-
if data['success']:
801-
sql = data['sql']
795+
if is_normal_user(self.current_user):
796+
sql_json_str = extract_nested_json(full_sql_text)
797+
data = orjson.loads(sql_json_str)
798+
799+
sql = ''
800+
message = ''
801+
error = False
802+
if data['success']:
803+
sql = data['sql']
804+
else:
805+
message = data['message']
806+
error = True
807+
if error:
808+
raise Exception(message)
809+
if sql.strip() == '':
810+
raise Exception("SQL query is empty")
811+
812+
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
813+
SQLBotLogUtil.info(sql_result)
814+
sql = self.check_save_sql(res=sql_result)
802815
else:
803-
message = data['message']
804-
error = True
805-
if error:
806-
raise Exception(message)
807-
if sql.strip() == '':
808-
raise Exception("SQL query is empty")
809-
810-
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
811-
SQLBotLogUtil.info(sql_result)
812-
sql = self.check_save_sql(res=sql_result)
813-
# sql = llm_service.check_save_sql(res=full_sql_text)
816+
sql = self.check_save_sql(res=full_sql_text)
814817

815818
SQLBotLogUtil.info(sql)
816819
format_sql = sqlparse.format(sql, reindent=True)

backend/apps/datasource/crud/datasource.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -246,43 +246,45 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
246246
if data.fields is None or len(data.fields) == 0:
247247
return {"fields": [], "data": [], "sql": ''}
248248

249-
# column is checked, and, column permission for data.fields
249+
where = None
250250
f_list = [f for f in data.fields if f.checked]
251-
column_permissions = session.query(DsPermission).filter(
252-
and_(DsPermission.table_id == data.table.id, DsPermission.type == 'column')).all()
253-
if column_permissions is not None:
254-
for permission in column_permissions:
255-
# check permission and user in same rules
256-
obj = session.query(DsRules).filter(
257-
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
258-
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
259-
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
260-
).first()
261-
if obj is not None:
262-
permission_list = json.loads(permission.permissions)
263-
f_list = filter_list(f_list, permission_list)
251+
if is_normal_user(current_user):
252+
# column is checked, and, column permission for data.fields
253+
column_permissions = session.query(DsPermission).filter(
254+
and_(DsPermission.table_id == data.table.id, DsPermission.type == 'column')).all()
255+
if column_permissions is not None:
256+
for permission in column_permissions:
257+
# check permission and user in same rules
258+
obj = session.query(DsRules).filter(
259+
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
260+
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
261+
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
262+
).first()
263+
if obj is not None:
264+
permission_list = json.loads(permission.permissions)
265+
f_list = filter_list(f_list, permission_list)
266+
267+
# row permission tree
268+
row_permissions = session.query(DsPermission).filter(
269+
and_(DsPermission.table_id == data.table.id, DsPermission.type == 'row')).all()
270+
res: List[PermissionDTO] = []
271+
if row_permissions is not None:
272+
for permission in row_permissions:
273+
# check permission and user in same rules
274+
obj = session.query(DsRules).filter(
275+
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
276+
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
277+
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
278+
).first()
279+
if obj is not None:
280+
res.append(transRecord2DTO(session, permission))
281+
wheres = transFilterTree(session, res, ds)
282+
where = (' where ' + wheres) if wheres is not None and wheres != '' else ''
264283

265284
fields = [f.field_name for f in f_list]
266285
if fields is None or len(fields) == 0:
267286
return {"fields": [], "data": [], "sql": ''}
268287

269-
# row permission tree
270-
row_permissions = session.query(DsPermission).filter(
271-
and_(DsPermission.table_id == data.table.id, DsPermission.type == 'row')).all()
272-
res: List[PermissionDTO] = []
273-
if row_permissions is not None:
274-
for permission in row_permissions:
275-
# check permission and user in same rules
276-
obj = session.query(DsRules).filter(
277-
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
278-
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
279-
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
280-
).first()
281-
if obj is not None:
282-
res.append(transRecord2DTO(session, permission))
283-
wheres = transFilterTree(session, res, ds)
284-
where = (' where ' + wheres) if wheres is not None and wheres != '' else ''
285-
286288
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
287289
sql: str = ""
288290
if ds.type == "mysql":
@@ -346,19 +348,20 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
346348
fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all()
347349

348350
# do column permissions, filter fields
349-
column_permissions = session.query(DsPermission).filter(
350-
and_(DsPermission.table_id == table.id, DsPermission.type == 'column')).all()
351-
if column_permissions is not None:
352-
for permission in column_permissions:
353-
# check permission and user in same rules
354-
obj = session.query(DsRules).filter(
355-
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
356-
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
357-
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
358-
).first()
359-
if obj is not None:
360-
permission_list = json.loads(permission.permissions)
361-
fields = filter_list(fields, permission_list)
351+
if is_normal_user(current_user):
352+
column_permissions = session.query(DsPermission).filter(
353+
and_(DsPermission.table_id == table.id, DsPermission.type == 'column')).all()
354+
if column_permissions is not None:
355+
for permission in column_permissions:
356+
# check permission and user in same rules
357+
obj = session.query(DsRules).filter(
358+
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
359+
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
360+
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
361+
).first()
362+
if obj is not None:
363+
permission_list = json.loads(permission.permissions)
364+
fields = filter_list(fields, permission_list)
362365

363366
_list.append(TableAndFields(schema=schema, table=table, fields=fields))
364367
return _list
@@ -402,3 +405,7 @@ def filter_list(list_a, list_b):
402405
id_to_invalid[b['field_id']] = True
403406

404407
return [a for a in list_a if not id_to_invalid.get(a.id, False)]
408+
409+
410+
def is_normal_user(current_user: CurrentUser):
411+
return current_user.id != 1 and (current_user.weight is not None and current_user.weight != 1)

0 commit comments

Comments
 (0)