22
22
23
23
import asyncio
24
24
from concurrent .futures import ThreadPoolExecutor
25
- from typing import Awaitable , Callable , NamedTuple , Set , Union , List
25
+ from collections import defaultdict
26
+ from typing import (
27
+ Optional , Union ,
28
+ Awaitable , Callable ,
29
+ NamedTuple , List , Dict , Set ,
30
+ )
26
31
27
32
from .encoding import to_bytes
28
33
36
41
37
42
class Pass (NamedTuple ):
38
43
block_index : int
39
- index : int
40
- byte : int
44
+ solved : List [int ]
41
45
42
46
43
47
class Fail (NamedTuple ):
@@ -46,13 +50,7 @@ class Fail(NamedTuple):
46
50
is_critical : bool = False
47
51
48
52
49
- class Done (NamedTuple ):
50
- block_index : int
51
- C0 : List [int ]
52
- X1 : List [int ]
53
-
54
-
55
- ResultType = Union [Pass , Fail , Done ]
53
+ ResultType = Union [Pass , Fail ]
56
54
57
55
OracleFunc = Callable [[bytes ], bool ]
58
56
ResultCallback = Callable [[ResultType ], bool ]
@@ -68,7 +66,7 @@ class Context(NamedTuple):
68
66
69
67
tasks : Set [Awaitable [ResultType ]]
70
68
71
- latest_plaintext : List [ int ]
69
+ solved_counts : Dict [ int , int ]
72
70
plaintext : List [int ]
73
71
74
72
result_callback : ResultCallback
@@ -122,10 +120,10 @@ async def solve_async(ciphertext: bytes,
122
120
ctx .tasks .remove (task )
123
121
124
122
if isinstance (result , Pass ):
125
- update_latest_plaintext (
126
- ctx , result .block_index , result .index , result . byte )
127
- if isinstance ( result , Done ):
128
- update_plaintext ( ctx , result . block_index , result . C0 , result . X1 )
123
+ if len ( result . solved ) >= ctx . solved_counts [ result . block_index ]:
124
+ update_plaintext ( ctx , result .block_index , result .solved )
125
+ ctx . solved_counts [ result . block_index ] = len ( result . solved )
126
+ ctx . plaintext_callback ( ctx . plaintext )
129
127
130
128
if len (ctx .tasks ) == 0 :
131
129
break
@@ -151,59 +149,71 @@ def create_solve_context(ciphertext, block_size, oracle, parallel,
151
149
for i in range (0 , len (ciphertext ), block_size ):
152
150
cipher_blocks .append (ciphertext [i :i + block_size ])
153
151
152
+ solved_counts = defaultdict (lambda : 0 )
153
+
154
154
plaintext = [None ] * (len (cipher_blocks ) - 1 ) * block_size
155
- latest_plaintext = plaintext .copy ()
156
155
157
156
executor = ThreadPoolExecutor (parallel )
158
157
loop = asyncio .get_event_loop ()
159
158
ctx = Context (block_size , oracle , executor , loop , tasks ,
160
- latest_plaintext , plaintext ,
159
+ solved_counts , plaintext ,
161
160
result_callback , plaintext_callback )
162
161
163
162
for i in range (1 , len (cipher_blocks )):
164
- run_block_task (ctx , i , cipher_blocks [i - 1 ], cipher_blocks [i ], [])
163
+ add_solve_block_task (ctx , i , cipher_blocks [i - 1 ], cipher_blocks [i ], [])
165
164
166
165
return ctx
167
166
168
167
169
- def run_block_task (ctx : Context , block_index , C0 , C1 , X1 ):
170
- future = solve_block (ctx , block_index , C0 , C1 , X1 )
168
+ def add_solve_block_task (ctx : Context , block_index : int , C0 : List [int ],
169
+ C1 : List [int ], X1_suffix : List [int ]):
170
+ future = solve_block (ctx , block_index , C0 , C1 , X1_suffix )
171
171
task = ctx .loop .create_task (future )
172
172
ctx .tasks .add (task )
173
173
174
174
175
175
async def solve_block (ctx : Context , block_index : int , C0 : List [int ],
176
- C1 : List [int ], X1 : List [int ] = []) -> ResultType :
176
+ C1 : List [int ], X1_suffix : List [int ] = []) -> ResultType :
177
+
178
+ assert len (C0 ) == ctx .block_size
179
+ assert len (C1 ) == ctx .block_size
180
+ assert len (X1_suffix ) in range (ctx .block_size + 1 )
181
+
177
182
# X1 = decrypt(C1)
178
183
# P1 = xor(C0, X1)
184
+ C0_suffix = C0 [len (C0 )- len (X1_suffix ):]
185
+ P1_suffix = [c ^ x for c , x in zip (C0_suffix , X1_suffix )]
179
186
180
- if len (X1 ) == ctx .block_size :
181
- return Done (block_index , C0 , X1 )
187
+ if len (P1_suffix ) < ctx .block_size :
188
+ result = await exploit_oracle (ctx , block_index , C0 , C1 , X1_suffix )
189
+ if isinstance (result , Fail ):
190
+ return result
182
191
183
- assert len (C0 ) == ctx .block_size
184
- assert len (C1 ) == ctx .block_size
185
- assert len (X1 ) in range (ctx .block_size )
192
+ return Pass (block_index , P1_suffix )
186
193
187
- index = ctx .block_size - len (X1 ) - 1
188
- padding = len (X1 ) + 1
194
+
195
+ async def exploit_oracle (ctx : Context , block_index : int ,
196
+ C0 : List [int ], C1 : List [int ],
197
+ X1_suffix : List [int ]) -> Optional [Fail ]:
198
+ index = ctx .block_size - len (X1_suffix ) - 1
199
+ padding = len (X1_suffix ) + 1
189
200
190
201
C0_test = C0 .copy ()
191
- for i in range (len (X1 )):
192
- C0_test [- i - 1 ] = X1 [- i - 1 ] ^ padding
202
+ for i in range (len (X1_suffix )):
203
+ C0_test [- i - 1 ] = X1_suffix [- i - 1 ] ^ padding
193
204
hits = list (await get_oracle_hits (ctx , C0_test , C1 , index ))
194
205
195
- invalid = len (X1 ) == 0 and len (hits ) not in (1 , 2 )
196
- invalid |= len (X1 ) > 0 and len (hits ) != 1
206
+ # Check if the number of hits is invalid
207
+ invalid = len (X1_suffix ) == 0 and len (hits ) not in (1 , 2 )
208
+ invalid |= len (X1_suffix ) > 0 and len (hits ) != 1
197
209
if invalid :
198
- message = 'unexpected number of hits: block={} index={} n={}' \
199
- . format ( block_index , index , len ( hits ))
210
+ message = f'invalid number of hits: { len ( hits ) } '
211
+ message = f' { message } (block: { block_index } , byte: { index } )'
200
212
return Fail (block_index , message )
201
213
202
214
for byte in hits :
203
- X1_test = [byte ^ padding , * X1 ]
204
- run_block_task (ctx , block_index , C0 , C1 , X1_test )
205
-
206
- return Pass (block_index , index , byte ^ padding ^ C0 [index ])
215
+ X1_test = [byte ^ padding , * X1_suffix ]
216
+ add_solve_block_task (ctx , block_index , C0 , C1 , X1_test )
207
217
208
218
209
219
async def get_oracle_hits (ctx : Context , C0 : List [int ], C1 : List [int ],
@@ -228,28 +238,10 @@ async def get_oracle_hits(ctx: Context, C0: List[int], C1: List[int],
228
238
return hits
229
239
230
240
231
- def update_latest_plaintext (ctx : Context , block_index : int , index : int ,
232
- byte : int ):
233
-
234
- i = (block_index - 1 ) * ctx .block_size + index
235
- ctx .latest_plaintext [i ] = byte
236
- ctx .plaintext_callback (ctx .latest_plaintext )
237
-
238
-
239
- def update_plaintext (ctx : Context , block_index : int , C0 : List [int ],
240
- X1 : List [int ]):
241
-
242
- assert len (C0 ) == len (X1 ) == ctx .block_size
243
- block = compute_plaintext (C0 , X1 )
244
-
245
- i = (block_index - 1 ) * ctx .block_size
246
- ctx .latest_plaintext [i :i + ctx .block_size ] = block
247
- ctx .plaintext [i :i + ctx .block_size ] = block
248
- ctx .plaintext_callback (ctx .plaintext )
249
-
250
-
251
- def compute_plaintext (C0 : List [int ], X1 : List [int ]):
252
- return [c ^ x for c , x in zip (C0 , X1 )]
241
+ def update_plaintext (ctx : Context , block_index : int , solved_suffix : List [int ]):
242
+ j = block_index * ctx .block_size
243
+ i = j - len (solved_suffix )
244
+ ctx .plaintext [i :j ] = solved_suffix
253
245
254
246
255
247
def convert_to_bytes (byte_list : List [int ], replacement = b' ' ):
0 commit comments