Skip to content

Commit 9598835

Browse files
authored
Merge pull request #24 from kevinlin09/feat/support_encrypt
feat: support encryption
2 parents 1f336e0 + bf30707 commit 9598835

File tree

9 files changed

+343
-4
lines changed

9 files changed

+343
-4
lines changed

dashscope/api_entities/api_request_factory.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
32
from urllib.parse import urlencode
43

54
import dashscope
@@ -10,8 +9,9 @@
109
SERVICE_API_PATH, ApiProtocol,
1110
HTTPMethod)
1211
from dashscope.common.error import InputDataRequired, UnsupportedApiProtocol
12+
from dashscope.common.logging import logger
1313
from dashscope.protocol.websocket import WebsocketStreamingMode
14-
14+
from dashscope.api_entities.encryption import Encryption
1515

1616
def _get_protocol_params(kwargs):
1717
api_protocol = kwargs.pop('api_protocol', ApiProtocol.HTTPS)
@@ -49,6 +49,9 @@ def _build_api_request(model: str,
4949
base_address, flattened_output,
5050
extra_url_parameters) = _get_protocol_params(kwargs)
5151
task_id = kwargs.pop('task_id', None)
52+
enable_encryption = kwargs.pop('enable_encryption', False)
53+
encryption = None
54+
5255
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
5356
if base_address is None:
5457
base_address = dashscope.base_http_api_url
@@ -69,6 +72,12 @@ def _build_api_request(model: str,
6972
if extra_url_parameters is not None and extra_url_parameters:
7073
http_url += '?' + urlencode(extra_url_parameters)
7174

75+
if enable_encryption is True:
76+
encryption = Encryption()
77+
encryption.initialize()
78+
if encryption.is_valid():
79+
logger.debug('encryption enabled')
80+
7281
request = HttpRequest(url=http_url,
7382
api_key=api_key,
7483
http_method=http_method,
@@ -77,7 +86,8 @@ def _build_api_request(model: str,
7786
query=query,
7887
timeout=request_timeout,
7988
task_id=task_id,
80-
flattened_output=flattened_output)
89+
flattened_output=flattened_output,
90+
encryption=encryption)
8191
elif api_protocol == ApiProtocol.WEBSOCKET:
8292
if base_address is not None:
8393
websocket_url = base_address
@@ -103,6 +113,9 @@ def _build_api_request(model: str,
103113
if input is None and form is None:
104114
raise InputDataRequired('There is no input data and form data')
105115

116+
if encryption and encryption.is_valid():
117+
input = encryption.encrypt(input)
118+
106119
request_data = ApiRequestData(model,
107120
task_group=task_group,
108121
task=task,
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import base64
3+
import json
4+
from dataclasses import dataclass
5+
import os
6+
from typing import Optional
7+
8+
import requests
9+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
10+
from cryptography.hazmat.primitives import serialization, hashes
11+
from cryptography.hazmat.primitives.asymmetric import padding
12+
from cryptography.hazmat.backends import default_backend
13+
14+
import dashscope
15+
from dashscope.common.constants import ENCRYPTION_AES_SECRET_KEY_BYTES, ENCRYPTION_AES_IV_LENGTH
16+
from dashscope.common.logging import logger
17+
18+
19+
class Encryption:
20+
def __init__(self):
21+
self.pub_key_id: str = ''
22+
self.pub_key_str: str = ''
23+
self.aes_key_bytes: bytes = b''
24+
self.encrypted_aes_key_str: str = ''
25+
self.iv_bytes: bytes = b''
26+
self.base64_iv_str: str = ''
27+
self.valid: bool = False
28+
29+
def initialize(self):
30+
public_keys = self._get_public_keys()
31+
if not public_keys:
32+
return
33+
34+
public_key_str = public_keys.get('public_key')
35+
public_key_id = public_keys.get('public_key_id')
36+
if not public_key_str or not public_key_id:
37+
logger.error("public keys data not valid")
38+
return
39+
40+
aes_key_bytes = self._generate_aes_secret_key()
41+
iv_bytes = self._generate_iv()
42+
43+
encrypted_aes_key_str = self._encrypt_aes_key_with_rsa(aes_key_bytes, public_key_str)
44+
base64_iv_str = base64.b64encode(iv_bytes).decode('utf-8')
45+
46+
self.pub_key_id = public_key_id
47+
self.pub_key_str = public_key_str
48+
self.aes_key_bytes = aes_key_bytes
49+
self.encrypted_aes_key_str = encrypted_aes_key_str
50+
self.iv_bytes = iv_bytes
51+
self.base64_iv_str = base64_iv_str
52+
53+
self.valid = True
54+
55+
def encrypt(self, dict_plaintext):
56+
return self._encrypt_text_with_aes(json.dumps(dict_plaintext, ensure_ascii=False),
57+
self.aes_key_bytes, self.iv_bytes)
58+
59+
def decrypt(self, base64_ciphertext):
60+
return self._decrypt_text_with_aes(base64_ciphertext, self.aes_key_bytes, self.iv_bytes)
61+
62+
def is_valid(self):
63+
return self.valid
64+
65+
def get_pub_key_id(self):
66+
return self.pub_key_id
67+
68+
def get_encrypted_aes_key_str(self):
69+
return self.encrypted_aes_key_str
70+
71+
def get_base64_iv_str(self):
72+
return self.base64_iv_str
73+
74+
@staticmethod
75+
def _get_public_keys():
76+
url = dashscope.base_http_api_url + '/public-keys/latest'
77+
headers = {
78+
"Authorization": f"Bearer {dashscope.api_key}"
79+
}
80+
81+
response = requests.get(url, headers=headers)
82+
if response.status_code != 200:
83+
logger.error("exceptional public key response: %s" % response)
84+
return None
85+
86+
json_resp = response.json()
87+
response_data = json_resp.get('data')
88+
89+
if not response_data:
90+
logger.error("no valid data in public key response")
91+
return None
92+
93+
return response_data
94+
95+
@staticmethod
96+
def _generate_aes_secret_key():
97+
return os.urandom(ENCRYPTION_AES_SECRET_KEY_BYTES)
98+
99+
@staticmethod
100+
def _generate_iv():
101+
return os.urandom(ENCRYPTION_AES_IV_LENGTH)
102+
103+
@staticmethod
104+
def _encrypt_text_with_aes(plaintext, key, iv):
105+
"""使用AES-GCM加密数据"""
106+
107+
# 创建AES-GCM加密器
108+
aes_gcm = Cipher(
109+
algorithms.AES(key),
110+
modes.GCM(iv, tag=None),
111+
backend=default_backend()
112+
).encryptor()
113+
114+
# 关联数据设为空(根据需求可调整)
115+
aes_gcm.authenticate_additional_data(b'')
116+
117+
# 加密数据
118+
ciphertext = aes_gcm.update(plaintext.encode('utf-8')) + aes_gcm.finalize()
119+
120+
# 获取认证标签
121+
tag = aes_gcm.tag
122+
123+
# 组合密文和标签
124+
encrypted_data = ciphertext + tag
125+
126+
# 返回Base64编码结果
127+
return base64.b64encode(encrypted_data).decode('utf-8')
128+
129+
@staticmethod
130+
def _decrypt_text_with_aes(base64_ciphertext, aes_key, iv):
131+
"""使用AES-GCM解密响应"""
132+
133+
# 解码Base64数据
134+
encrypted_data = base64.b64decode(base64_ciphertext)
135+
136+
# 分离密文和标签(标签长度16字节)
137+
ciphertext = encrypted_data[:-16]
138+
tag = encrypted_data[-16:]
139+
140+
# 创建AES-GCM解密器
141+
aes_gcm = Cipher(
142+
algorithms.AES(aes_key),
143+
modes.GCM(iv, tag),
144+
backend=default_backend()
145+
).decryptor()
146+
147+
# 验证关联数据(与加密时一致)
148+
aes_gcm.authenticate_additional_data(b'')
149+
150+
# 解密数据
151+
decrypted_bytes = aes_gcm.update(ciphertext) + aes_gcm.finalize()
152+
153+
# 明文
154+
plaintext = decrypted_bytes.decode('utf-8')
155+
156+
return json.loads(plaintext)
157+
158+
@staticmethod
159+
def _encrypt_aes_key_with_rsa(aes_key, public_key_str):
160+
"""使用RSA公钥加密AES密钥"""
161+
162+
# 解码Base64格式的公钥
163+
public_key_bytes = base64.b64decode(public_key_str)
164+
165+
# 加载公钥
166+
public_key = serialization.load_der_public_key(
167+
public_key_bytes,
168+
backend=default_backend()
169+
)
170+
171+
base64_aes_key = base64.b64encode(aes_key).decode('utf-8')
172+
173+
# 使用RSA加密
174+
encrypted_bytes = public_key.encrypt(
175+
base64_aes_key.encode('utf-8'),
176+
padding.PKCS1v15()
177+
)
178+
179+
return base64.b64encode(encrypted_bytes).decode('utf-8')

dashscope/api_entities/http_request.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
from http import HTTPStatus
5+
from typing import Optional
56

67
import aiohttp
78
import requests
@@ -16,6 +17,7 @@
1617
_handle_aiohttp_failed_response,
1718
_handle_http_failed_response,
1819
_handle_stream)
20+
from dashscope.api_entities.encryption import Encryption
1921

2022

2123
class HttpRequest(AioBaseRequest):
@@ -28,7 +30,8 @@ def __init__(self,
2830
query: bool = False,
2931
timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS,
3032
task_id: str = None,
31-
flattened_output: bool = False) -> None:
33+
flattened_output: bool = False,
34+
encryption: Optional[Encryption] = None) -> None:
3235
"""HttpSSERequest, processing http server sent event stream.
3336
3437
Args:
@@ -44,11 +47,23 @@ def __init__(self,
4447
self.url = url
4548
self.flattened_output = flattened_output
4649
self.async_request = async_request
50+
self.encryption = encryption
4751
self.headers = {
4852
'Accept': 'application/json',
4953
'Authorization': 'Bearer %s' % api_key,
5054
**self.headers,
5155
}
56+
57+
if encryption and encryption.is_valid():
58+
self.headers = {
59+
"X-DashScope-EncryptionKey": json.dumps({
60+
"public_key_id": encryption.get_pub_key_id(),
61+
"encrypt_key": encryption.get_encrypted_aes_key_str(),
62+
"iv": encryption.get_base64_iv_str()
63+
}),
64+
**self.headers,
65+
}
66+
5267
self.query = query
5368
if self.async_request and self.query is False:
5469
self.headers = {
@@ -168,6 +183,8 @@ async def _handle_aio_response(self, response: aiohttp.ClientResponse):
168183
code=msg['code'],
169184
message=msg['message'])
170185
else:
186+
if self.encryption and self.encryption.is_valid():
187+
output = self.encryption.decrypt(output)
171188
yield DashScopeAPIResponse(request_id=request_id,
172189
status_code=HTTPStatus.OK,
173190
output=output,
@@ -183,6 +200,8 @@ async def _handle_aio_response(self, response: aiohttp.ClientResponse):
183200
output[part.name] = await part.read()
184201
if 'request_id' in output:
185202
request_id = output['request_id']
203+
if self.encryption and self.encryption.is_valid():
204+
output = self.encryption.decrypt(output)
186205
yield DashScopeAPIResponse(request_id=request_id,
187206
status_code=HTTPStatus.OK,
188207
output=output)
@@ -196,6 +215,8 @@ async def _handle_aio_response(self, response: aiohttp.ClientResponse):
196215
usage = json_content['usage']
197216
if 'request_id' in json_content:
198217
request_id = json_content['request_id']
218+
if self.encryption and self.encryption.is_valid():
219+
output = self.encryption.decrypt(output)
199220
yield DashScopeAPIResponse(request_id=request_id,
200221
status_code=HTTPStatus.OK,
201222
output=output,
@@ -243,6 +264,8 @@ def _handle_response(self, response: requests.Response):
243264
if self.flattened_output:
244265
yield msg
245266
else:
267+
if self.encryption and self.encryption.is_valid():
268+
output = self.encryption.decrypt(output)
246269
yield DashScopeAPIResponse(request_id=request_id,
247270
status_code=HTTPStatus.OK,
248271
output=output,
@@ -263,6 +286,8 @@ def _handle_response(self, response: requests.Response):
263286
if self.flattened_output:
264287
yield json_content
265288
else:
289+
if self.encryption and self.encryption.is_valid():
290+
output = self.encryption.decrypt(output)
266291
yield DashScopeAPIResponse(request_id=request_id,
267292
status_code=HTTPStatus.OK,
268293
output=output,

dashscope/common/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
REQUEST_CONTENT_AUDIO = 'audio'
3838
FILE_PATH_SCHEMA = 'file://'
3939

40+
ENCRYPTION_AES_SECRET_KEY_BYTES = 32
41+
ENCRYPTION_AES_IV_LENGTH = 12
42+
4043
REPEATABLE_STATUS = [
4144
HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.GATEWAY_TIMEOUT
4245
]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
aiohttp
22
requests
33
websocket-client
4+
cryptography

samples/test_generation.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
from dashscope import Generation
3+
4+
messages = [
5+
{"role": "system", "content": "You are a helpful assistant."},
6+
{"role": "user", "content": "你是谁?"},
7+
]
8+
response = Generation.call(
9+
api_key=os.getenv("DASHSCOPE_API_KEY"),
10+
model="qwen-plus",
11+
messages=messages,
12+
result_format="message",
13+
enable_encryption=True,
14+
stream=True,
15+
)
16+
17+
for chunk in response:
18+
print(chunk.output.choices[0].message.content)
19+
20+
# if response.status_code == 200:
21+
# print(response.output.choices[0].message.content)
22+
# else:
23+
# print(f"HTTP返回码:{response.status_code}")
24+
# print(f"错误码:{response.code}")
25+
# print(f"错误信息:{response.message}")
26+
# print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")

0 commit comments

Comments
 (0)