Skip to content

Commit 4422780

Browse files
committed
Add more args check; fix typo in docstring
1 parent ca24908 commit 4422780

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

padding_oracle/padding_oracle.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,19 @@ def padding_oracle(cipher: bytes,
7272
oracle (function) a function: oracle(cipher: bytes) -> bool
7373
num_threads (int) how many oracle functions will be run in parallel (default: 1)
7474
log_level (int) log level (default: logging.INFO)
75-
null (bytes) the null byte if the (default: b' ')
75+
null (bytes) the default byte when plaintext are not set (default: b' ')
7676
7777
Returns:
7878
plaintext (bytes) the decrypted plaintext
7979
'''
8080

81-
# Check the oracle function
81+
# Check args
8282
assert callable(oracle), 'the oracle function should be callable'
8383
assert oracle.__code__.co_argcount == 1, 'expect oracle function with only 1 argument'
84+
assert isinstance(cipher, bytes), 'cipher should have type bytes'
85+
assert isinstance(block_size, int), 'block_size should have type int'
8486
assert len(cipher) % block_size == 0, 'cipher length should be multiple of block size'
87+
assert 1 <= num_threads <= 1000, 'num_threads should be in [1, 1000]'
8588
assert isinstance(null, bytes), 'expect null with type bytes'
8689
assert len(null) == 1, 'null byte should have length of 1'
8790

@@ -97,7 +100,7 @@ def _oracle_wrapper(i: int, j: int, cipher: bytes):
97100
logger.debug('error details at block[{}][{}]: ', i, j, traceback.format_exc())
98101
return False
99102

100-
# The plaintext bytes list to save the decrypted data
103+
# The plaintext bytes list to store the decrypted data
101104
plaintext = [null] * (len(cipher) - block_size)
102105

103106
# Update the decrypted plaintext list
@@ -145,14 +148,12 @@ def _block_decrypt_task(i, prev: bytes, block: bytes):
145148
_update_plaintext(i * block_size - j, bytes([p]))
146149

147150
for n in range(j):
148-
guess_list[-n-1] ^= j
149-
guess_list[-n-1] ^= j + 1
151+
guess_list[-n-1] ^= j ^ (j + 1)
150152

151153
blocks = []
152154

153155
for i in range(0, len(cipher), block_size):
154-
j = i + block_size
155-
blocks.append(cipher[i:j])
156+
blocks.append(cipher[i:i + block_size])
156157

157158
logger.debug('blocks: {}'.format(blocks))
158159

0 commit comments

Comments
 (0)