22
22
23
23
import asyncio
24
24
from concurrent .futures import ThreadPoolExecutor
25
- from typing import Any , Callable , NamedTuple , Set , Union , List
25
+ from typing import Callable , NamedTuple , Set , Union , List
26
26
27
27
from .encoding import to_bytes
28
28
33
33
'remove_padding' ,
34
34
]
35
35
36
+
36
37
class Pass (NamedTuple ):
37
38
block_index : int
38
39
index : int
39
40
byte : int
40
41
42
+
41
43
class Fail (NamedTuple ):
42
44
block_index : int
43
45
message : str
44
46
is_critical : bool = False
45
47
48
+
46
49
class Done (NamedTuple ):
47
50
block_index : int
48
51
C0 : List [int ]
@@ -54,26 +57,28 @@ class Done(NamedTuple):
54
57
OracleFunc = Callable [[bytes ], bool ]
55
58
ResultCallback = Callable [[ResultType ], bool ]
56
59
PlainTextCallback = Callable [[List [int ]], bool ]
57
-
60
+
58
61
59
62
class Context (NamedTuple ):
60
63
block_size : int
61
64
oracle : OracleFunc
62
-
63
- executor : ThreadPoolExecutor
65
+
66
+ executor : ThreadPoolExecutor
64
67
loop : asyncio .AbstractEventLoop
65
-
68
+
66
69
tasks : Set [asyncio .Task [ResultType ]]
67
70
68
71
latest_plaintext : List [int ]
69
72
plaintext : List [int ]
70
-
73
+
71
74
result_callback : ResultCallback
72
75
plaintext_callback : PlainTextCallback
73
76
77
+
74
78
def dummy_callback (* a , ** ka ):
75
79
pass
76
80
81
+
77
82
def solve (ciphertext : bytes ,
78
83
block_size : int ,
79
84
oracle : OracleFunc ,
@@ -87,6 +92,7 @@ def solve(ciphertext: bytes,
87
92
result_callback , plaintext_callback )
88
93
return loop .run_until_complete (future )
89
94
95
+
90
96
async def solve_async (ciphertext : bytes ,
91
97
block_size : int ,
92
98
oracle : OracleFunc ,
@@ -96,43 +102,47 @@ async def solve_async(ciphertext: bytes,
96
102
) -> List [int ]:
97
103
98
104
ciphertext = list (ciphertext )
99
- assert len (ciphertext ) % block_size == 0 , \
100
- 'ciphertext length must be a multiple of block_size'
101
- assert len (ciphertext ) // block_size > 1 , \
102
- 'cannot solve with only one block'
105
+
106
+ if not len (ciphertext ) % block_size == 0 :
107
+ raise ValueError ('ciphertext length must be a multiple of block_size' )
108
+ if not len (ciphertext ) // block_size > 1 :
109
+ raise ValueError ('cannot solve with only one block' )
103
110
104
111
ctx = create_solve_context (ciphertext , block_size , oracle , parallel ,
105
112
result_callback , plaintext_callback )
106
113
107
114
while True :
108
- done_tasks , _ = await asyncio .wait (ctx .tasks , return_when = asyncio .FIRST_COMPLETED )
109
-
115
+ done_tasks , _ = await asyncio .wait (ctx .tasks ,
116
+ return_when = asyncio .FIRST_COMPLETED )
117
+
110
118
for task in done_tasks :
111
119
result = await task
112
-
120
+
113
121
ctx .result_callback (result )
114
122
ctx .tasks .remove (task )
115
-
123
+
116
124
if isinstance (result , Pass ):
117
- update_latest_plaintext (ctx , result .block_index , result .index , result .byte )
125
+ update_latest_plaintext (
126
+ ctx , result .block_index , result .index , result .byte )
118
127
if isinstance (result , Done ):
119
128
update_plaintext (ctx , result .block_index , result .C0 , result .X1 )
120
-
129
+
121
130
if len (ctx .tasks ) == 0 :
122
131
break
123
-
132
+
124
133
# Check if any block failed
125
134
error_block_indices = set ()
126
-
135
+
127
136
for i , byte in enumerate (ctx .plaintext ):
128
137
if byte is None :
129
138
error_block_indices .add (i // block_size + 1 )
130
-
139
+
131
140
for idx in error_block_indices :
132
141
result_callback (Fail (idx , f'cannot decrypt cipher block { idx } ' , True ))
133
-
142
+
134
143
return ctx .plaintext
135
144
145
+
136
146
def create_solve_context (ciphertext , block_size , oracle , parallel ,
137
147
result_callback , plaintext_callback ) -> Context :
138
148
tasks = set ()
@@ -143,30 +153,27 @@ def create_solve_context(ciphertext, block_size, oracle, parallel,
143
153
144
154
plaintext = [None ] * (len (cipher_blocks ) - 1 ) * block_size
145
155
latest_plaintext = plaintext .copy ()
146
-
156
+
147
157
executor = ThreadPoolExecutor (parallel )
148
158
loop = asyncio .get_running_loop ()
149
159
ctx = Context (block_size , oracle , executor , loop , tasks ,
150
160
latest_plaintext , plaintext ,
151
161
result_callback , plaintext_callback )
152
-
162
+
153
163
for i in range (1 , len (cipher_blocks )):
154
164
run_block_task (ctx , i , cipher_blocks [i - 1 ], cipher_blocks [i ], [])
155
165
156
166
return ctx
157
167
168
+
158
169
def run_block_task (ctx : Context , block_index , C0 , C1 , X1 ):
159
170
future = solve_block (ctx , block_index , C0 , C1 , X1 )
160
171
task = ctx .loop .create_task (future )
161
172
ctx .tasks .add (task )
162
173
163
- async def solve_block (
164
- ctx : Context ,
165
- block_index : int ,
166
- C0 : List [int ],
167
- C1 : List [int ],
168
- X1 : List [int ] = [],
169
- ) -> ResultType :
174
+
175
+ async def solve_block (ctx : Context , block_index : int , C0 : List [int ],
176
+ C1 : List [int ], X1 : List [int ] = []) -> ResultType :
170
177
# X1 = decrypt(C1)
171
178
# P1 = xor(C0, X1)
172
179
@@ -195,46 +202,56 @@ async def solve_block(
195
202
for byte in hits :
196
203
X1_test = [byte ^ padding , * X1 ]
197
204
run_block_task (ctx , block_index , C0 , C1 , X1_test )
198
-
205
+
199
206
return Pass (block_index , index , byte ^ padding ^ C0 [index ])
200
207
201
- async def get_oracle_hits (ctx : Context , C0 : List [int ], C1 : List [int ], index : int ):
202
-
208
+
209
+ async def get_oracle_hits (ctx : Context , C0 : List [int ], C1 : List [int ],
210
+ index : int ):
211
+
203
212
C0 = C0 .copy ()
204
213
futures = {}
205
-
214
+
206
215
for byte in range (256 ):
207
216
C0 [index ] = byte
208
217
ciphertext = bytes (C0 + C1 )
209
218
futures [byte ] = ctx .loop .run_in_executor (
210
219
ctx .executor , ctx .oracle , ciphertext )
211
-
220
+
212
221
hits = []
213
-
222
+
214
223
for byte , future in futures .items ():
215
224
is_valid = await future
216
225
if is_valid :
217
226
hits .append (byte )
218
-
227
+
219
228
return hits
220
229
221
- def update_latest_plaintext (ctx : Context , block_index : int , index : int , byte : int ):
230
+
231
+ def update_latest_plaintext (ctx : Context , block_index : int , index : int ,
232
+ byte : int ):
233
+
222
234
i = (block_index - 1 ) * ctx .block_size + index
223
235
ctx .latest_plaintext [i ] = byte
224
236
ctx .plaintext_callback (ctx .latest_plaintext )
225
237
226
- def update_plaintext (ctx : Context , block_index : int , C0 : List [int ], X1 : List [int ]):
238
+
239
+ def update_plaintext (ctx : Context , block_index : int , C0 : List [int ],
240
+ X1 : List [int ]):
241
+
227
242
assert len (C0 ) == len (X1 ) == ctx .block_size
228
243
block = compute_plaintext (C0 , X1 )
229
-
244
+
230
245
i = (block_index - 1 ) * ctx .block_size
231
246
ctx .latest_plaintext [i :i + ctx .block_size ] = block
232
247
ctx .plaintext [i :i + ctx .block_size ] = block
233
248
ctx .plaintext_callback (ctx .plaintext )
234
249
250
+
235
251
def compute_plaintext (C0 : List [int ], X1 : List [int ]):
236
252
return [c ^ x for c , x in zip (C0 , X1 )]
237
253
254
+
238
255
def convert_to_bytes (byte_list : List [int ], replacement = b' ' ):
239
256
'''
240
257
Convert a list of int into bytes, replace invalid byte with replacement.
@@ -249,6 +266,7 @@ def convert_to_bytes(byte_list: List[int], replacement=b' '):
249
266
byte_list [i ] = byte
250
267
return bytes (byte_list )
251
268
269
+
252
270
def remove_padding (data : Union [str , bytes , List [int ]]) -> bytes :
253
271
'''
254
272
Remove PKCS#7 padding bytes.
0 commit comments