Skip to content

Commit f60db00

Browse files
committed
refactor: 💡 _ntt function
1 parent 96f7d5e commit f60db00

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

pyencrypt/encrypt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def encrypt_key(key: bytes) -> str:
3232
ascii_ls = [ord(x) for x in key.decode()]
3333
numbers = generate_rsa_number(2048)
3434
e, n = numbers['e'], numbers['n']
35+
# fill length to be a power of 2
36+
length = len(ascii_ls)
37+
if length & (length - 1) != 0:
38+
length = 1 << length.bit_length()
39+
ascii_ls = ascii_ls + [0] * (length - len(ascii_ls))
3540
cipher_ls = list()
3641
# ntt后再用RSA加密
3742
for num in ntt(ascii_ls):

pyencrypt/ntt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def bitreverse(x: list, length: int):
1616

1717

1818
def _ntt(arr: list, inverse=False):
19-
length = 1
20-
while length < len(arr):
21-
length *= 2
22-
x = arr + [0] * (length - len(arr))
19+
length = len(arr)
20+
if length & (length - 1) != 0:
21+
raise ValueError("The length of input must be a power of 2.")
22+
x = arr.copy()
2323
g = pow(G, (M - 1) // length, M)
2424
if inverse:
2525
g = pow(g, M - 2, M)

0 commit comments

Comments
 (0)