21
21
'''
22
22
23
23
import logging
24
+ import traceback
24
25
from typing import Union , Callable
25
26
from concurrent .futures import ThreadPoolExecutor
26
27
33
34
34
35
35
36
def remove_padding (data : Union [str , bytes ]):
37
+ '''
38
+ Remove PKCS#7 padding bytes.
39
+
40
+ Args:
41
+ data (str | bytes)
42
+
43
+ Returns:
44
+ data with padding removed (bytes)
45
+ '''
36
46
data = _to_bytes (data )
37
47
return data [:- data [- 1 ]]
38
48
39
49
50
+ def _get_logger ():
51
+ logger = logging .getLogger ('padding_oracle' )
52
+ formatter = logging .Formatter ('[%(asctime)s][%(levelname)s] %(message)s' )
53
+ handler = logging .StreamHandler ()
54
+ handler .setFormatter (formatter )
55
+ logger .addHandler (handler )
56
+ return logger
57
+
58
+
40
59
def padding_oracle (cipher : bytes ,
41
60
block_size : int ,
42
61
oracle : Callable [[bytes ], bool ],
43
62
num_threads : int = 1 ,
44
63
log_level : int = logging .INFO ,
45
64
null : bytes = b' ' ) -> bytes :
65
+ '''
66
+ Run padding oracle attack to decrypt cipher given a function to check wether the cipher
67
+ can be decrypted successfully.
68
+
69
+ Args:
70
+ cipher (bytes) the cipher you want to decrypt
71
+ block_size (int) block size (the cipher length should be multiple of this)
72
+ oracle (function) a function: oracle(cipher: bytes) -> bool
73
+ num_threads (int) how many oracle functions will be run in parallel (default: 1)
74
+ log_level (int) log level (default: logging.INFO)
75
+ null (bytes) the null byte if the (default: b' ')
76
+
77
+ Returns:
78
+ plaintext (bytes) the decrypted plaintext
79
+ '''
80
+
46
81
# Check the oracle function
47
82
assert callable (oracle ), 'the oracle function should be callable'
48
83
assert oracle .__code__ .co_argcount == 1 , 'expect oracle function with only 1 argument'
49
84
assert len (cipher ) % block_size == 0 , 'cipher length should be multiple of block size'
85
+ assert isinstance (null , bytes ), 'expect null with type bytes'
86
+ assert len (null ) == 1 , 'null byte should have length of 1'
50
87
51
- logger = logging . getLogger ( 'padding_oracle' )
88
+ logger = _get_logger ( )
52
89
logger .setLevel (log_level )
53
- formatter = logging .Formatter ('[%(asctime)s][%(levelname)s] %(message)s' )
54
- handler = logging .StreamHandler ()
55
- handler .setFormatter (formatter )
56
- logger .addHandler (handler )
90
+
91
+ # Wrapper to handle exception from the oracle function
92
+ def _oracle_wrapper (i : int , j : int , cipher : bytes ):
93
+ try :
94
+ return oracle (cipher )
95
+ except Exception as e :
96
+ logger .error ('unhandled error at block[{}][{}]: ' , i , j , e )
97
+ logger .debug ('error details at block[{}][{}]: ' , i , j , traceback .format_exc ())
98
+ return False
57
99
58
100
# The plaintext bytes list to save the decrypted data
59
101
plaintext = [null ] * (len (cipher ) - block_size )
60
102
103
+ # Update the decrypted plaintext list
61
104
def _update_plaintext (i : int , c : bytes ):
62
105
plaintext [i ] = c
63
106
logger .info ('plaintext: {}' .format (b'' .join (plaintext )))
64
107
65
108
oracle_executor = ThreadPoolExecutor (max_workers = num_threads )
66
109
110
+ # Block decrypting task to be run in parallel
67
111
def _block_decrypt_task (i , prev : bytes , block : bytes ):
68
112
logger .debug ('task={} prev={} block={}' .format (i , prev , block ))
69
113
guess_list = list (prev )
@@ -80,7 +124,7 @@ def _block_decrypt_task(i, prev: bytes, block: bytes):
80
124
test_list = guess_list .copy ()
81
125
test_list [- j ] = k
82
126
oracle_futures [k ] = oracle_executor .submit (
83
- oracle , bytes (test_list ) + block )
127
+ _oracle_wrapper , i , j , bytes (test_list ) + block )
84
128
85
129
for k , future in oracle_futures .items ():
86
130
if future .result ():
@@ -89,6 +133,7 @@ def _block_decrypt_task(i, prev: bytes, block: bytes):
89
133
logger .debug (
90
134
'oracles at block[{}][{}] -> {}' .format (i , block_size - j , oracle_hits ))
91
135
136
+ # Number of oracle hits should be 1, or we just ignore this block
92
137
if len (oracle_hits ) != 1 :
93
138
logfmt = 'at block[{}][{}]: expect only one hit, got {}. (skipped)'
94
139
logger .error (logfmt .format (i , block_size - j , len (oracle_hits )))
0 commit comments