33
33
'padding_oracle' ,
34
34
]
35
35
36
-
37
- def padding_oracle (ciphertext : Union [bytes , str ],
36
+ def padding_oracle (payload : Union [bytes , str ],
38
37
block_size : int ,
39
38
oracle : OracleFunc ,
40
39
num_threads : int = 1 ,
41
40
log_level : int = logging .INFO ,
42
41
null_byte : bytes = b' ' ,
43
42
return_raw : bool = False ,
43
+ mode : Union [bool , str ] = 'encrypt' ,
44
44
) -> Union [bytes , List [int ]]:
45
45
'''
46
46
Run padding oracle attack to decrypt ciphertext given a function to check
47
47
wether the ciphertext can be decrypted successfully.
48
48
49
49
Args:
50
- ciphertext (bytes|str) the ciphertext you want to decrypt
50
+ payload (bytes|str) the payload you want to encrypt/ decrypt
51
51
block_size (int) block size (the ciphertext length should be
52
52
multiple of this)
53
53
oracle (function) a function: oracle(ciphertext: bytes) -> bool
@@ -58,33 +58,48 @@ def padding_oracle(ciphertext: Union[bytes, str],
58
58
set (default: None)
59
59
return_raw (bool) do not convert plaintext into bytes and
60
60
unpad (default: False)
61
+ mode (bool|str) encrypt the payload (defaut: False/'decrypt')
62
+
61
63
62
64
Returns:
63
- plaintext (bytes|List[int]) the decrypted plaintext
65
+ result (bytes|List[int]) the processed payload
64
66
'''
65
67
66
68
# Check args
67
69
if not callable (oracle ):
68
70
raise TypeError ('the oracle function should be callable' )
69
- if not isinstance (ciphertext , (bytes , str )):
70
- raise TypeError ('ciphertext should have type bytes' )
71
+ if not isinstance (payload , (bytes , str )):
72
+ raise TypeError ('payload should have type bytes' )
71
73
if not isinstance (block_size , int ):
72
74
raise TypeError ('block_size should have type int' )
73
- if not len (ciphertext ) % block_size == 0 :
74
- raise ValueError ('ciphertext length should be multiple of block size' )
75
+ if not len (payload ) % block_size == 0 :
76
+ raise ValueError ('payload length should be multiple of block size' )
75
77
if not 1 <= num_threads <= 1000 :
76
78
raise ValueError ('num_threads should be in [1, 1000]' )
77
79
if not isinstance (null_byte , (bytes , str )):
78
80
raise TypeError ('expect null with type bytes or str' )
79
81
if not len (null_byte ) == 1 :
80
82
raise ValueError ('null byte should have length of 1' )
83
+ if not isinstance (mode , (bool , str )):
84
+ raise TypeError ('expect mode with type bool or str' )
85
+ if isinstance (mode , str ) and mode not in ('encrypt' , 'decrypt' ):
86
+ raise ValueError ('mode must be either encrypt or decrypt' )
81
87
82
88
logger = get_logger ()
83
89
logger .setLevel (log_level )
84
90
85
- ciphertext = to_bytes (ciphertext )
91
+ payload = to_bytes (payload )
86
92
null_byte = to_bytes (null_byte )
87
93
94
+
95
+ # encryption routine
96
+ if mode == 'encrypt' or mode :
97
+ return encrypt (payload , block_size , oracle , num_threads , null_byte , return_raw , logger )
98
+
99
+ # otherwise continue with decryption as normal
100
+ return decrypt (payload , block_size , oracle , num_threads , null_byte , return_raw , logger ):
101
+
102
+ def encrypt (payload , block_size , oracle , num_threads , null_byte , return_raw , logger ):
88
103
# Wrapper to handle exceptions from the oracle function
89
104
def wrapped_oracle (ciphertext : bytes ):
90
105
try :
@@ -105,15 +120,51 @@ def plaintext_callback(plaintext: bytes):
105
120
plaintext = convert_to_bytes (plaintext , null_byte )
106
121
logger .info (f'plaintext: { plaintext } ' )
107
122
108
- plaintext = solve (ciphertext , block_size , wrapped_oracle , num_threads ,
123
+ plaintext = solve (payload , block_size , wrapped_oracle , num_threads ,
109
124
result_callback , plaintext_callback )
110
125
111
126
if not return_raw :
112
127
plaintext = convert_to_bytes (plaintext , null_byte )
113
128
plaintext = remove_padding (plaintext )
114
129
115
- return plaintext
116
130
131
+ def decrypt (payload , block_size , oracle , num_threads , null_byte , return_raw , logger ):
132
+ # Wrapper to handle exceptions from the oracle function
133
+ def wrapped_oracle (ciphertext : bytes ):
134
+ try :
135
+ return oracle (ciphertext )
136
+ except Exception as e :
137
+ logger .error (f'error in oracle with { ciphertext !r} , { e } ' )
138
+ logger .debug ('error details: {}' .format (traceback .format_exc ()))
139
+ return False
140
+
141
+ def result_callback (result : ResultType ):
142
+ if isinstance (result , Fail ):
143
+ if result .is_critical :
144
+ logger .critical (result .message )
145
+ else :
146
+ logger .error (result .message )
147
+
148
+ def plaintext_callback (plaintext : bytes ):
149
+ plaintext = convert_to_bytes (plaintext , null_byte )
150
+ logger .info (f'plaintext: { plaintext } ' )
151
+
152
+ def blocks (data : bytes ):
153
+ return [data [i :(i + block_size )] for i in range (0 , len (data ), block_size )]
154
+
155
+ def bytes_xor (byte_string_1 : bytes , byte_string_2 : bytes ):
156
+ return bytes ([_a ^ _b for _a , _b in zip (byte_string_1 , byte_string_2 )])
157
+
158
+ plaintext_blocks = blocks (payload )
159
+ ciphertext_blocks = [null_byte * block_size for i in range (len (plaintext_blocks )+ 1 )]
160
+
161
+ for index in range (len (plaintext_blocks )- 1 , - 1 , - 1 ):
162
+ plaintext = solve (b'\x00 ' * block_size + ciphertext_blocks [index + 1 ], block_size , wrapped_oracle ,
163
+ num_threads , result_callback , plaintext_callback )
164
+ ciphertext_blocks [i ] = bytes_xor (plaintext_blocks [index ], plaintext )
165
+
166
+ ciphertext = b'' .join (ciphertext_blocks )
167
+ return ciphertext
117
168
118
169
def get_logger ():
119
170
logger = logging .getLogger ('padding_oracle' )
0 commit comments