|
| 1 | +import json |
1 | 2 | import logging |
2 | 3 | import traceback |
3 | 4 | import warnings |
|
10 | 11 | from langchain.chat_models.base import BaseChatModel |
11 | 12 | from langchain_community.utilities import SQLDatabase |
12 | 13 | from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk |
| 14 | +from sqlalchemy import and_, cast |
13 | 15 | from sqlalchemy import select |
| 16 | +from sqlalchemy.dialects.postgresql import JSONB |
14 | 17 | from sqlalchemy.orm import load_only |
| 18 | +from sqlbot_xpack.permissions.api.permission import transRecord2DTO |
| 19 | +from sqlbot_xpack.permissions.models.ds_permission import DsPermission, PermissionDTO |
| 20 | +from sqlbot_xpack.permissions.models.ds_rules import DsRules |
15 | 21 |
|
16 | 22 | from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config |
17 | 23 | from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \ |
|
21 | 27 | get_old_questions, save_analysis_predict_record, list_base_records, rename_chat |
22 | 28 | from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat |
23 | 29 | from apps.datasource.crud.datasource import get_table_schema |
24 | | -from apps.datasource.models.datasource import CoreDatasource |
| 30 | +from apps.datasource.crud.row_permission import transFilterTree |
| 31 | +from apps.datasource.models.datasource import CoreDatasource, CoreTable |
25 | 32 | from apps.db.db import exec_sql |
26 | 33 | from apps.system.crud.assistant import get_assistant_ds |
27 | 34 | from common.core.config import settings |
@@ -77,7 +84,7 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question |
77 | 84 |
|
78 | 85 | # get schema |
79 | 86 | if ds: |
80 | | - chat_question.db_schema = get_table_schema(session=self.session, ds=ds) |
| 87 | + chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds) |
81 | 88 |
|
82 | 89 | chat_question.lang = current_user.language |
83 | 90 |
|
@@ -467,6 +474,78 @@ def generate_sql(self): |
467 | 474 | [{'type': msg.type, 'content': msg.content} for msg in |
468 | 475 | self.sql_message]).decode()) |
469 | 476 |
|
| 477 | + def generate_filter(self, sql: str, tables: List): |
| 478 | + table_list = self.session.query(CoreTable).filter( |
| 479 | + and_(CoreTable.ds_id == self.ds.id, CoreTable.table_name.in_(tables)) |
| 480 | + ).all() |
| 481 | + |
| 482 | + filters = [] |
| 483 | + for table in table_list: |
| 484 | + row_permissions = self.session.query(DsPermission).filter( |
| 485 | + and_(DsPermission.table_id == table.id, DsPermission.type == 'row')).all() |
| 486 | + res: List[PermissionDTO] = [] |
| 487 | + if row_permissions is not None: |
| 488 | + for permission in row_permissions: |
| 489 | + # check permission and user in same rules |
| 490 | + obj = self.session.query(DsRules).filter( |
| 491 | + and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)), |
| 492 | + DsRules.user_list.op('@>')(cast([self.current_user.id], JSONB))) |
| 493 | + ).first() |
| 494 | + if obj is not None: |
| 495 | + res.append(transRecord2DTO(self.session, permission)) |
| 496 | + wheres = transFilterTree(self.session, res, self.ds) |
| 497 | + filters.append({"table": table.table_name, "filter": wheres}) |
| 498 | + |
| 499 | + filter = json.dumps(filters, ensure_ascii=False) |
| 500 | + # filter = f"""[{{"table":"{tables[0]}","filter":"省份 = '广东省' or 销售额(万元) > 10000"}}]""" # todo get filters |
| 501 | + self.chat_question.sql = sql |
| 502 | + self.chat_question.filter = filter |
| 503 | + msg: List[Union[BaseMessage, dict[str, Any]]] = [] |
| 504 | + msg.append(SystemMessage(content=self.chat_question.filter_sys_question())) |
| 505 | + msg.append(HumanMessage(content=self.chat_question.filter_user_question())) |
| 506 | + |
| 507 | + history_msg = [] |
| 508 | + # if self.record.full_analysis_message and self.record.full_analysis_message.strip() != '': |
| 509 | + # history_msg = orjson.loads(self.record.full_analysis_message) |
| 510 | + |
| 511 | + # self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, answer='', |
| 512 | + # full_message=orjson.dumps(history_msg + |
| 513 | + # [{'type': msg.type, |
| 514 | + # 'content': msg.content} for msg |
| 515 | + # in |
| 516 | + # msg]).decode()) |
| 517 | + full_thinking_text = '' |
| 518 | + full_filter_text = '' |
| 519 | + res = self.llm.stream(msg) |
| 520 | + token_usage = {} |
| 521 | + for chunk in res: |
| 522 | + print(chunk) |
| 523 | + reasoning_content_chunk = '' |
| 524 | + if 'reasoning_content' in chunk.additional_kwargs: |
| 525 | + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
| 526 | + # else: |
| 527 | + # reasoning_content_chunk = chunk.get('reasoning_content') |
| 528 | + if reasoning_content_chunk is None: |
| 529 | + reasoning_content_chunk = '' |
| 530 | + full_thinking_text += reasoning_content_chunk |
| 531 | + |
| 532 | + full_filter_text += chunk.content |
| 533 | + # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
| 534 | + get_token_usage(chunk, token_usage) |
| 535 | + |
| 536 | + msg.append(AIMessage(full_filter_text)) |
| 537 | + # self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, |
| 538 | + # token_usage=token_usage, |
| 539 | + # answer=orjson.dumps({'content': full_analysis_text, |
| 540 | + # 'reasoning_content': full_thinking_text}).decode(), |
| 541 | + # full_message=orjson.dumps(history_msg + |
| 542 | + # [{'type': msg.type, |
| 543 | + # 'content': msg.content} for msg |
| 544 | + # in |
| 545 | + # analysis_msg]).decode()) |
| 546 | + print(full_filter_text) |
| 547 | + return full_filter_text |
| 548 | + |
470 | 549 | def generate_chart(self): |
471 | 550 | # append current question |
472 | 551 | self.chart_message.append(HumanMessage(self.chat_question.chart_user_question())) |
@@ -663,7 +742,15 @@ def run_task(llm_service: LLMService, in_chat: bool = True): |
663 | 742 |
|
664 | 743 | # filter sql |
665 | 744 | print(full_sql_text) |
666 | | - sql = llm_service.check_save_sql(res=full_sql_text) |
| 745 | + |
| 746 | + # todo row permission |
| 747 | + sql_json_str = extract_nested_json(full_sql_text) |
| 748 | + data = orjson.loads(sql_json_str) |
| 749 | + sql_result = llm_service.generate_filter(data['sql'], data['tables']) |
| 750 | + print(sql_result) |
| 751 | + sql = llm_service.check_save_sql(res=sql_result) |
| 752 | + # sql = llm_service.check_save_sql(res=full_sql_text) |
| 753 | + |
667 | 754 | print(sql) |
668 | 755 | if in_chat: |
669 | 756 | yield orjson.dumps({'content': sql, 'type': 'sql'}).decode() + '\n\n' |
|
0 commit comments