Skip to content

Commit 4fc65c5

Browse files
committed
refactor: simplify solve
1 parent 0fe025b commit 4fc65c5

File tree

1 file changed

+52
-60
lines changed

1 file changed

+52
-60
lines changed

src/padding_oracle/solve.py

Lines changed: 52 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222

2323
import asyncio
2424
from concurrent.futures import ThreadPoolExecutor
25-
from typing import Awaitable, Callable, NamedTuple, Set, Union, List
25+
from collections import defaultdict
26+
from typing import (
27+
Optional, Union,
28+
Awaitable, Callable,
29+
NamedTuple, List, Dict, Set,
30+
)
2631

2732
from .encoding import to_bytes
2833

@@ -36,8 +41,7 @@
3641

3742
class Pass(NamedTuple):
3843
block_index: int
39-
index: int
40-
byte: int
44+
solved: List[int]
4145

4246

4347
class Fail(NamedTuple):
@@ -46,13 +50,7 @@ class Fail(NamedTuple):
4650
is_critical: bool = False
4751

4852

49-
class Done(NamedTuple):
50-
block_index: int
51-
C0: List[int]
52-
X1: List[int]
53-
54-
55-
ResultType = Union[Pass, Fail, Done]
53+
ResultType = Union[Pass, Fail]
5654

5755
OracleFunc = Callable[[bytes], bool]
5856
ResultCallback = Callable[[ResultType], bool]
@@ -68,7 +66,7 @@ class Context(NamedTuple):
6866

6967
tasks: Set[Awaitable[ResultType]]
7068

71-
latest_plaintext: List[int]
69+
solved_counts: Dict[int, int]
7270
plaintext: List[int]
7371

7472
result_callback: ResultCallback
@@ -122,10 +120,10 @@ async def solve_async(ciphertext: bytes,
122120
ctx.tasks.remove(task)
123121

124122
if isinstance(result, Pass):
125-
update_latest_plaintext(
126-
ctx, result.block_index, result.index, result.byte)
127-
if isinstance(result, Done):
128-
update_plaintext(ctx, result.block_index, result.C0, result.X1)
123+
if len(result.solved) >= ctx.solved_counts[result.block_index]:
124+
update_plaintext(ctx, result.block_index, result.solved)
125+
ctx.solved_counts[result.block_index] = len(result.solved)
126+
ctx.plaintext_callback(ctx.plaintext)
129127

130128
if len(ctx.tasks) == 0:
131129
break
@@ -151,59 +149,71 @@ def create_solve_context(ciphertext, block_size, oracle, parallel,
151149
for i in range(0, len(ciphertext), block_size):
152150
cipher_blocks.append(ciphertext[i:i+block_size])
153151

152+
solved_counts = defaultdict(lambda: 0)
153+
154154
plaintext = [None] * (len(cipher_blocks) - 1) * block_size
155-
latest_plaintext = plaintext.copy()
156155

157156
executor = ThreadPoolExecutor(parallel)
158157
loop = asyncio.get_event_loop()
159158
ctx = Context(block_size, oracle, executor, loop, tasks,
160-
latest_plaintext, plaintext,
159+
solved_counts, plaintext,
161160
result_callback, plaintext_callback)
162161

163162
for i in range(1, len(cipher_blocks)):
164-
run_block_task(ctx, i, cipher_blocks[i-1], cipher_blocks[i], [])
163+
add_solve_block_task(ctx, i, cipher_blocks[i-1], cipher_blocks[i], [])
165164

166165
return ctx
167166

168167

169-
def run_block_task(ctx: Context, block_index, C0, C1, X1):
170-
future = solve_block(ctx, block_index, C0, C1, X1)
168+
def add_solve_block_task(ctx: Context, block_index: int, C0: List[int],
169+
C1: List[int], X1_suffix: List[int]):
170+
future = solve_block(ctx, block_index, C0, C1, X1_suffix)
171171
task = ctx.loop.create_task(future)
172172
ctx.tasks.add(task)
173173

174174

175175
async def solve_block(ctx: Context, block_index: int, C0: List[int],
176-
C1: List[int], X1: List[int] = []) -> ResultType:
176+
C1: List[int], X1_suffix: List[int] = []) -> ResultType:
177+
178+
assert len(C0) == ctx.block_size
179+
assert len(C1) == ctx.block_size
180+
assert len(X1_suffix) in range(ctx.block_size + 1)
181+
177182
# X1 = decrypt(C1)
178183
# P1 = xor(C0, X1)
184+
C0_suffix = C0[len(C0)-len(X1_suffix):]
185+
P1_suffix = [c ^ x for c, x in zip(C0_suffix, X1_suffix)]
179186

180-
if len(X1) == ctx.block_size:
181-
return Done(block_index, C0, X1)
187+
if len(P1_suffix) < ctx.block_size:
188+
result = await exploit_oracle(ctx, block_index, C0, C1, X1_suffix)
189+
if isinstance(result, Fail):
190+
return result
182191

183-
assert len(C0) == ctx.block_size
184-
assert len(C1) == ctx.block_size
185-
assert len(X1) in range(ctx.block_size)
192+
return Pass(block_index, P1_suffix)
186193

187-
index = ctx.block_size - len(X1) - 1
188-
padding = len(X1) + 1
194+
195+
async def exploit_oracle(ctx: Context, block_index: int,
196+
C0: List[int], C1: List[int],
197+
X1_suffix: List[int]) -> Optional[Fail]:
198+
index = ctx.block_size - len(X1_suffix) - 1
199+
padding = len(X1_suffix) + 1
189200

190201
C0_test = C0.copy()
191-
for i in range(len(X1)):
192-
C0_test[-i-1] = X1[-i-1] ^ padding
202+
for i in range(len(X1_suffix)):
203+
C0_test[-i-1] = X1_suffix[-i-1] ^ padding
193204
hits = list(await get_oracle_hits(ctx, C0_test, C1, index))
194205

195-
invalid = len(X1) == 0 and len(hits) not in (1, 2)
196-
invalid |= len(X1) > 0 and len(hits) != 1
206+
# Check if the number of hits is invalid
207+
invalid = len(X1_suffix) == 0 and len(hits) not in (1, 2)
208+
invalid |= len(X1_suffix) > 0 and len(hits) != 1
197209
if invalid:
198-
message = 'unexpected number of hits: block={} index={} n={}' \
199-
.format(block_index, index, len(hits))
210+
message = f'invalid number of hits: {len(hits)}'
211+
message = f'{message} (block: {block_index}, byte: {index})'
200212
return Fail(block_index, message)
201213

202214
for byte in hits:
203-
X1_test = [byte ^ padding, *X1]
204-
run_block_task(ctx, block_index, C0, C1, X1_test)
205-
206-
return Pass(block_index, index, byte ^ padding ^ C0[index])
215+
X1_test = [byte ^ padding, *X1_suffix]
216+
add_solve_block_task(ctx, block_index, C0, C1, X1_test)
207217

208218

209219
async def get_oracle_hits(ctx: Context, C0: List[int], C1: List[int],
@@ -228,28 +238,10 @@ async def get_oracle_hits(ctx: Context, C0: List[int], C1: List[int],
228238
return hits
229239

230240

231-
def update_latest_plaintext(ctx: Context, block_index: int, index: int,
232-
byte: int):
233-
234-
i = (block_index - 1) * ctx.block_size + index
235-
ctx.latest_plaintext[i] = byte
236-
ctx.plaintext_callback(ctx.latest_plaintext)
237-
238-
239-
def update_plaintext(ctx: Context, block_index: int, C0: List[int],
240-
X1: List[int]):
241-
242-
assert len(C0) == len(X1) == ctx.block_size
243-
block = compute_plaintext(C0, X1)
244-
245-
i = (block_index - 1) * ctx.block_size
246-
ctx.latest_plaintext[i:i+ctx.block_size] = block
247-
ctx.plaintext[i:i+ctx.block_size] = block
248-
ctx.plaintext_callback(ctx.plaintext)
249-
250-
251-
def compute_plaintext(C0: List[int], X1: List[int]):
252-
return [c ^ x for c, x in zip(C0, X1)]
241+
def update_plaintext(ctx: Context, block_index: int, solved_suffix: List[int]):
242+
j = block_index * ctx.block_size
243+
i = j - len(solved_suffix)
244+
ctx.plaintext[i:j] = solved_suffix
253245

254246

255247
def convert_to_bytes(byte_list: List[int], replacement=b' '):

0 commit comments

Comments
 (0)