Skip to content

Commit dfc1c12

Browse files
perf: Embedded Data Source Interface Supports AES Encryption
1 parent 8a03ea8 commit dfc1c12

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

backend/apps/system/api/assistant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ async def info(request: Request, response: Response, session: SessionDep, trans:
3131
db_model = AssistantModel.model_validate(db_model)
3232
response.headers["Access-Control-Allow-Origin"] = db_model.domain
3333
origin = request.headers.get("origin") or get_origin_from_referer(request)
34+
if not origin:
35+
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or ''))
3436
origin = origin.rstrip('/')
3537
if origin != db_model.domain:
3638
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or ''))

backend/apps/system/crud/assistant.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from common.core.config import settings
1616
from common.core.db import engine
1717
from common.core.sqlbot_cache import cache
18+
from common.utils.aes_crypto import simple_aes_decrypt
1819
from common.utils.utils import string_to_numeric_hash
1920

20-
2121
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
2222
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
2323
db_model = session.get(AssistantModel, assistant_id)
@@ -117,13 +117,13 @@ def get_ds_from_api(self):
117117
res = requests.get(url=endpoint, params=param, headers=header, cookies=cookies, timeout=10)
118118
if res.status_code == 200:
119119
result_json: dict[any] = json.loads(res.text)
120-
if result_json.get('code') == 0:
120+
if result_json.get('code') == 0 or result_json.get('code') == 200:
121121
temp_list = result_json.get('data', [])
122-
self.ds_list = [
123-
self.convert2schema(item)
122+
temp_ds_list = [
123+
self.convert2schema(item, config)
124124
for item in temp_list
125125
]
126-
126+
self.ds_list = temp_ds_list
127127
return self.ds_list
128128
else:
129129
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
@@ -169,9 +169,19 @@ def get_ds(self, ds_id: int):
169169
raise Exception("Datasource list is not found.")
170170
raise Exception(f"Datasource with id {ds_id} not found.")
171171

172-
def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema:
172+
def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema:
173173
id_marker: str = ''
174174
attr_list = ['name', 'type', 'host', 'port', 'user', 'dataBase', 'schema']
175+
if config.get('encrypt', True):
176+
key = config.get('aes_key', None)
177+
iv = config.get('aes_iv', None)
178+
aes_attrs = ['host', 'user', 'password', 'dataBase', 'db_schema']
179+
for attr in aes_attrs:
180+
if attr in ds_dict and ds_dict[attr]:
181+
try:
182+
ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv)
183+
except Exception as e:
184+
raise Exception(f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}")
175185
for attr in attr_list:
176186
if attr in ds_dict:
177187
id_marker += str(ds_dict.get(attr, '')) + '--sqlbot--'
@@ -180,7 +190,6 @@ def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema:
180190
ds_dict.pop("schema", None)
181191
return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema})
182192

183-
184193
class AssistantOutDsFactory:
185194
@staticmethod
186195
def get_instance(assistant: AssistantHeader) -> AssistantOutDs:

backend/common/utils/aes_crypto.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Optional
2+
from common.core.config import settings
3+
from sqlbot_xpack.aes_utils import SecureEncryption
4+
5+
simple_aes_iv_text = 'sqlbot_em_aes_iv'
6+
def sqlbot_aes_encrypt(text: str, key: Optional[str] = None) -> str:
7+
return SecureEncryption.encrypt_to_single_string(text, key or settings.SECRET_KEY)
8+
9+
def sqlbot_aes_decrypt(text: str, key: Optional[str] = None) -> str:
10+
return SecureEncryption.decrypt_from_single_string(text, key or settings.SECRET_KEY)
11+
12+
def simple_aes_encrypt(text: str, key: Optional[str] = None, ivtext: Optional[str] = None) -> str:
13+
return SecureEncryption.simple_aes_encrypt(text, key or settings.SECRET_KEY[:32], ivtext or simple_aes_iv_text)
14+
15+
def simple_aes_decrypt(text: str, key: Optional[str] = None, ivtext: Optional[str] = None) -> str:
16+
return SecureEncryption.simple_aes_decrypt(text, key or settings.SECRET_KEY[:32], ivtext or simple_aes_iv_text)

0 commit comments

Comments
 (0)