Skip to content

Commit 37d886e

Browse files
committed
refactor: optimize RSA encryption and decryption functions with caching
1 parent 23147e5 commit 37d886e

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

apps/common/utils/rsa_util.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99
import base64
1010
import threading
11+
from functools import lru_cache
1112

1213
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
1314
from Crypto.PublicKey import RSA
@@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
7071
"""
7172
if public_key is None:
7273
public_key = get_key_pair().get('key')
73-
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
74+
cipher = _get_encrypt_cipher(public_key)
7475
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
7576
return base64.b64encode(encrypt_msg).decode()
7677

@@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
8485
"""
8586
if pri_key is None:
8687
pri_key = get_key_pair().get('value')
87-
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
88+
cipher = _get_cipher(pri_key)
8889
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
8990
return decrypt_data.decode("utf-8")
9091

9192

93+
94+
@lru_cache(maxsize=2)
95+
def _get_encrypt_cipher(public_key: str):
96+
"""缓存加密 cipher 对象"""
97+
return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
98+
99+
92100
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
93101
"""
94102
超长文本加密
95103
96104
:param message: 需要加密的字符串
97105
:param public_key 公钥
98-
:param length: 1024bit的证书用100 2048bit的证书用 200
106+
:param length: 1024bit的证书用100, 2048bit的证书用 200
99107
:return: 加密后的数据
100108
"""
101-
# 读取公钥
102109
if public_key is None:
103110
public_key = get_key_pair().get('key')
104-
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
105-
passphrase=secret_code))
106-
# 处理:Plaintext is too long. 分段加密
111+
112+
cipher = _get_encrypt_cipher(public_key)
113+
107114
if len(message) <= length:
108-
# 对编码的数据进行加密,并通过base64进行编码
109115
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
110116
else:
111117
rsa_text = []
112-
# 对编码后的数据进行切片,原因:加密长度不能过长
113118
for i in range(0, len(message), length):
114119
cont = message[i:i + length]
115-
# 对切片后的数据进行加密,并新增到text后面
116120
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
117-
# 加密完进行拼接
118121
cipher_text = b''.join(rsa_text)
119-
# base64进行编码
120122
result = base64.b64encode(cipher_text)
123+
121124
return result.decode()
122125

123126

127+
@lru_cache(maxsize=2)
128+
def _get_cipher(pri_key: str):
129+
"""缓存 cipher 对象,避免重复创建"""
130+
return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
131+
132+
124133
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
125134
"""
126-
超长文本解密,默认不加密
135+
超长文本解密,优化内存使用
127136
:param message: 需要解密的数据
128137
:param pri_key: 秘钥
129-
:param length : 1024bit的证书用1282048bit证书用256位
138+
:param length : 1024bit的证书用128,2048bit证书用256位
130139
:return: 解密后的数据
131140
"""
132141
if pri_key is None:
133142
pri_key = get_key_pair().get('value')
134-
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
143+
144+
cipher = _get_cipher(pri_key)
135145
base64_de = base64.b64decode(message)
136-
res = []
146+
147+
# 使用 bytearray 减少内存分配
148+
result = bytearray()
137149
for i in range(0, len(base64_de), length):
138-
res.append(cipher.decrypt(base64_de[i:i + length], 0))
139-
return b"".join(res).decode()
150+
result.extend(cipher.decrypt(base64_de[i:i + length], 0))
151+
152+
return result.decode()
153+

0 commit comments

Comments
 (0)