1
- import logging
2
- import base64
3
- import urllib .parse
1
+ '''
2
+ Copyright (c) 2020 Yuankui Lee
3
+
4
+ Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ of this software and associated documentation files (the "Software"), to deal
6
+ in the Software without restriction, including without limitation the rights
7
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
+ copies of the Software, and to permit persons to whom the Software is
9
+ furnished to do so, subject to the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be included in all
12
+ copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20
+ SOFTWARE.
21
+ '''
4
22
23
+ import logging
5
24
from typing import Union , Callable
6
25
from concurrent .futures import ThreadPoolExecutor
7
26
27
+ from .encoding import *
8
28
9
29
__all__ = [
10
- 'base64_encode' , 'base64_decode' ,
11
- 'urlencode' , 'urldecode' ,
12
30
'padding_oracle' ,
13
31
'remove_padding'
14
32
]
15
33
16
34
17
- def _to_bytes (data : Union [str , bytes ]):
18
- if isinstance (data , str ):
19
- data = data .encode ()
20
- assert isinstance (data , bytes )
21
- return data
22
-
23
- def _to_str (data ):
24
- if isinstance (data , bytes ):
25
- data = data .decode ()
26
- elif isinstance (data , str ):
27
- pass
28
- else :
29
- data = str (data )
30
- return data
31
-
32
-
33
- def base64_decode (data : Union [str , bytes ]) -> bytes :
34
- data = _to_bytes (data )
35
- return base64 .b64decode (data )
36
-
37
- def base64_encode (data : Union [str , bytes ]) -> str :
38
- data = _to_bytes (data )
39
- return base64 .b64encode (data ).decode ()
40
-
41
- def urlencode (data : Union [str , bytes ]) -> str :
42
- data = _to_bytes (data )
43
- return urllib .parse .quote (data )
44
-
45
- def urldecode (data : str ) -> bytes :
46
- data = _to_str (data )
47
- return urllib .parse .unquote_plus (data )
48
-
49
35
def remove_padding (data : Union [str , bytes ]):
50
36
data = _to_bytes (data )
51
37
return data [:- data [- 1 ]]
52
38
53
39
54
-
55
-
56
40
def _dummy_oracle (cipher : bytes ) -> bool :
57
41
raise NotImplementedError ('You must implement the oracle function' )
58
42
59
43
60
44
def padding_oracle (cipher : bytes ,
61
45
block_size : int ,
62
- oracle : Callable [[bytes ], bool ]= _dummy_oracle ,
63
- num_threads : int = 1 ,
64
- log_level : int = logging .INFO ,
65
- null : bytes = b' ' ) -> bytes :
46
+ oracle : Callable [[bytes ], bool ] = _dummy_oracle ,
47
+ num_threads : int = 1 ,
48
+ log_level : int = logging .INFO ,
49
+ null : bytes = b' ' ) -> bytes :
66
50
# Check the oracle function
67
51
assert callable (oracle ), 'the oracle function should be callable'
68
52
assert oracle .__code__ .co_argcount == 1 , 'expect oracle function with only 1 argument'
69
53
assert len (cipher ) % block_size == 0 , 'cipher length should be multiple of block size'
70
-
54
+
71
55
logger = logging .getLogger ('padding_oracle' )
72
56
logger .setLevel (log_level )
73
57
formatter = logging .Formatter ('[%(asctime)s][%(levelname)s] %(message)s' )
@@ -84,52 +68,54 @@ def _update_plaintext(i: int, c: bytes):
84
68
logger .info ('plaintext: {}' .format (b'' .join (plaintext )))
85
69
86
70
oracle_executor = ThreadPoolExecutor (max_workers = num_threads )
87
-
71
+
88
72
def _block_decrypt_task (i , prev : bytes , block : bytes ):
89
73
logger .debug ('task={} prev={} block={}' .format (i , prev , block ))
90
74
guess_list = list (prev )
91
-
75
+
92
76
for j in range (1 , block_size + 1 ):
93
77
oracle_hits = []
94
78
oracle_futures = {}
95
-
79
+
96
80
for k in range (256 ):
97
81
if i == len (blocks ) - 1 and j == 1 and k == prev [- j ]:
98
82
# skip the last padding byte if it is identical to the original cipher
99
83
continue
100
-
84
+
101
85
test_list = guess_list .copy ()
102
86
test_list [- j ] = k
103
- oracle_futures [k ] = oracle_executor .submit (oracle , bytes (test_list ) + block )
104
-
87
+ oracle_futures [k ] = oracle_executor .submit (
88
+ oracle , bytes (test_list ) + block )
89
+
105
90
for k , future in oracle_futures .items ():
106
91
if future .result ():
107
92
oracle_hits .append (k )
108
-
109
- logger .debug ('oracles at block[{}][{}] -> {}' .format (i , block_size - j , oracle_hits ))
110
-
93
+
94
+ logger .debug (
95
+ 'oracles at block[{}][{}] -> {}' .format (i , block_size - j , oracle_hits ))
96
+
111
97
if len (oracle_hits ) != 1 :
112
98
logfmt = 'at block[{}][{}]: expect only one positive result, got {}. (skipped)'
113
99
logger .error (logfmt .format (i , block_size - j , len (oracle_hits )))
114
100
return
115
-
101
+
116
102
guess_list [- j ] = oracle_hits [0 ]
117
-
103
+
118
104
p = guess_list [- j ] ^ j ^ prev [- j ]
119
105
_update_plaintext (i * block_size - j , bytes ([p ]))
120
-
106
+
121
107
for n in range (j ):
122
108
guess_list [- n - 1 ] ^= j
123
109
guess_list [- n - 1 ] ^= j + 1
124
-
110
+
125
111
blocks = []
126
-
112
+
127
113
for i in range (0 , len (cipher ), block_size ):
128
114
j = i + block_size
129
115
blocks .append (cipher [i :j ])
130
-
116
+
131
117
logger .debug ('blocks: {}' .format (blocks ))
132
-
118
+
133
119
with ThreadPoolExecutor () as executor :
134
120
futures = []
135
121
for i in reversed (range (1 , len (blocks ))):
@@ -138,7 +124,7 @@ def _block_decrypt_task(i, prev: bytes, block: bytes):
138
124
futures .append (executor .submit (_block_decrypt_task , i , prev , block ))
139
125
for future in futures :
140
126
future .result ()
141
-
127
+
142
128
oracle_executor .shutdown ()
143
-
129
+
144
130
return b'' .join (plaintext )
0 commit comments