88"""
99import base64
1010import threading
11+ from functools import lru_cache
1112
1213from Crypto .Cipher import PKCS1_v1_5 as PKCS1_cipher
1314from Crypto .PublicKey import RSA
@@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
7071 """
7172 if public_key is None :
7273 public_key = get_key_pair ().get ('key' )
73- cipher = PKCS1_cipher . new ( RSA . importKey ( public_key ) )
74+ cipher = _get_encrypt_cipher ( public_key )
7475 encrypt_msg = cipher .encrypt (msg .encode ("utf-8" ))
7576 return base64 .b64encode (encrypt_msg ).decode ()
7677
@@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
8485 """
8586 if pri_key is None :
8687 pri_key = get_key_pair ().get ('value' )
87- cipher = PKCS1_cipher . new ( RSA . importKey ( pri_key , passphrase = secret_code ) )
88+ cipher = _get_cipher ( pri_key )
8889 decrypt_data = cipher .decrypt (base64 .b64decode (msg ), 0 )
8990 return decrypt_data .decode ("utf-8" )
9091
9192
93+
94+ @lru_cache (maxsize = 2 )
95+ def _get_encrypt_cipher (public_key : str ):
96+ """缓存加密 cipher 对象"""
97+ return PKCS1_cipher .new (RSA .importKey (extern_key = public_key , passphrase = secret_code ))
98+
99+
92100def rsa_long_encrypt (message , public_key : str | None = None , length = 200 ):
93101 """
94102 超长文本加密
95103
96104 :param message: 需要加密的字符串
97105 :param public_key 公钥
98- :param length: 1024bit的证书用100, 2048bit的证书用 200
106+ :param length: 1024bit的证书用100, 2048bit的证书用 200
99107 :return: 加密后的数据
100108 """
101- # 读取公钥
102109 if public_key is None :
103110 public_key = get_key_pair ().get ('key' )
104- cipher = PKCS1_cipher . new ( RSA . importKey ( extern_key = public_key ,
105- passphrase = secret_code ) )
106- # 处理:Plaintext is too long. 分段加密
111+
112+ cipher = _get_encrypt_cipher ( public_key )
113+
107114 if len (message ) <= length :
108- # 对编码的数据进行加密,并通过base64进行编码
109115 result = base64 .b64encode (cipher .encrypt (message .encode ('utf-8' )))
110116 else :
111117 rsa_text = []
112- # 对编码后的数据进行切片,原因:加密长度不能过长
113118 for i in range (0 , len (message ), length ):
114119 cont = message [i :i + length ]
115- # 对切片后的数据进行加密,并新增到text后面
116120 rsa_text .append (cipher .encrypt (cont .encode ('utf-8' )))
117- # 加密完进行拼接
118121 cipher_text = b'' .join (rsa_text )
119- # base64进行编码
120122 result = base64 .b64encode (cipher_text )
123+
121124 return result .decode ()
122125
123126
127+ @lru_cache (maxsize = 2 )
128+ def _get_cipher (pri_key : str ):
129+ """缓存 cipher 对象,避免重复创建"""
130+ return PKCS1_cipher .new (RSA .importKey (pri_key , passphrase = secret_code ))
131+
132+
124133def rsa_long_decrypt (message , pri_key : str | None = None , length = 256 ):
125134 """
126- 超长文本解密,默认不加密
135+ 超长文本解密,优化内存使用
127136 :param message: 需要解密的数据
128137 :param pri_key: 秘钥
129- :param length : 1024bit的证书用128, 2048bit证书用256位
138+ :param length : 1024bit的证书用128, 2048bit证书用256位
130139 :return: 解密后的数据
131140 """
132141 if pri_key is None :
133142 pri_key = get_key_pair ().get ('value' )
134- cipher = PKCS1_cipher .new (RSA .importKey (pri_key , passphrase = secret_code ))
143+
144+ cipher = _get_cipher (pri_key )
135145 base64_de = base64 .b64decode (message )
136- res = []
146+
147+ # 使用 bytearray 减少内存分配
148+ result = bytearray ()
137149 for i in range (0 , len (base64_de ), length ):
138- res .append (cipher .decrypt (base64_de [i :i + length ], 0 ))
139- return b"" .join (res ).decode ()
150+ result .extend (cipher .decrypt (base64_de [i :i + length ], 0 ))
151+
152+ return result .decode ()
153+
0 commit comments