Skip to content

Commit ca24908

Browse files
committed
docstrings & param checks & oracle exception handle
1 parent b018871 commit ca24908

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

padding_oracle/padding_oracle.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
'''
2222

2323
import logging
24+
import traceback
2425
from typing import Union, Callable
2526
from concurrent.futures import ThreadPoolExecutor
2627

@@ -33,37 +34,80 @@
3334

3435

3536
def remove_padding(data: Union[str, bytes]):
37+
'''
38+
Remove PKCS#7 padding bytes.
39+
40+
Args:
41+
data (str | bytes)
42+
43+
Returns:
44+
data with padding removed (bytes)
45+
'''
3646
data = _to_bytes(data)
3747
return data[:-data[-1]]
3848

3949

50+
def _get_logger():
51+
logger = logging.getLogger('padding_oracle')
52+
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
53+
handler = logging.StreamHandler()
54+
handler.setFormatter(formatter)
55+
logger.addHandler(handler)
56+
return logger
57+
58+
4059
def padding_oracle(cipher: bytes,
4160
block_size: int,
4261
oracle: Callable[[bytes], bool],
4362
num_threads: int = 1,
4463
log_level: int = logging.INFO,
4564
null: bytes = b' ') -> bytes:
65+
'''
66+
Run padding oracle attack to decrypt cipher given a function to check wether the cipher
67+
can be decrypted successfully.
68+
69+
Args:
70+
cipher (bytes) the cipher you want to decrypt
71+
block_size (int) block size (the cipher length should be multiple of this)
72+
oracle (function) a function: oracle(cipher: bytes) -> bool
73+
num_threads (int) how many oracle functions will be run in parallel (default: 1)
74+
log_level (int) log level (default: logging.INFO)
75+
null (bytes) the null byte if the (default: b' ')
76+
77+
Returns:
78+
plaintext (bytes) the decrypted plaintext
79+
'''
80+
4681
# Check the oracle function
4782
assert callable(oracle), 'the oracle function should be callable'
4883
assert oracle.__code__.co_argcount == 1, 'expect oracle function with only 1 argument'
4984
assert len(cipher) % block_size == 0, 'cipher length should be multiple of block size'
85+
assert isinstance(null, bytes), 'expect null with type bytes'
86+
assert len(null) == 1, 'null byte should have length of 1'
5087

51-
logger = logging.getLogger('padding_oracle')
88+
logger = _get_logger()
5289
logger.setLevel(log_level)
53-
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
54-
handler = logging.StreamHandler()
55-
handler.setFormatter(formatter)
56-
logger.addHandler(handler)
90+
91+
# Wrapper to handle exception from the oracle function
92+
def _oracle_wrapper(i: int, j: int, cipher: bytes):
93+
try:
94+
return oracle(cipher)
95+
except Exception as e:
96+
logger.error('unhandled error at block[{}][{}]: ', i, j, e)
97+
logger.debug('error details at block[{}][{}]: ', i, j, traceback.format_exc())
98+
return False
5799

58100
# The plaintext bytes list to save the decrypted data
59101
plaintext = [null] * (len(cipher) - block_size)
60102

103+
# Update the decrypted plaintext list
61104
def _update_plaintext(i: int, c: bytes):
62105
plaintext[i] = c
63106
logger.info('plaintext: {}'.format(b''.join(plaintext)))
64107

65108
oracle_executor = ThreadPoolExecutor(max_workers=num_threads)
66109

110+
# Block decrypting task to be run in parallel
67111
def _block_decrypt_task(i, prev: bytes, block: bytes):
68112
logger.debug('task={} prev={} block={}'.format(i, prev, block))
69113
guess_list = list(prev)
@@ -80,7 +124,7 @@ def _block_decrypt_task(i, prev: bytes, block: bytes):
80124
test_list = guess_list.copy()
81125
test_list[-j] = k
82126
oracle_futures[k] = oracle_executor.submit(
83-
oracle, bytes(test_list) + block)
127+
_oracle_wrapper, i, j, bytes(test_list) + block)
84128

85129
for k, future in oracle_futures.items():
86130
if future.result():
@@ -89,6 +133,7 @@ def _block_decrypt_task(i, prev: bytes, block: bytes):
89133
logger.debug(
90134
'oracles at block[{}][{}] -> {}'.format(i, block_size - j, oracle_hits))
91135

136+
# Number of oracle hits should be 1, or we just ignore this block
92137
if len(oracle_hits) != 1:
93138
logfmt = 'at block[{}][{}]: expect only one hit, got {}. (skipped)'
94139
logger.error(logfmt.format(i, block_size-j, len(oracle_hits)))

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name='padding_oracle',
8-
version='0.1.3',
8+
version='0.1.4',
99
author='Yuankui Lee',
1010
author_email='[email protected]',
1111
description='Threaded padding oracle automation.',

0 commit comments

Comments
 (0)