Skip to content

Commit 6c450b5

Browse files
committed
style: fix lint errors
1 parent 2807939 commit 6c450b5

File tree

9 files changed

+139
-76
lines changed

9 files changed

+139
-76
lines changed

src/padding_oracle/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,16 @@
2727
to_bytes, to_str,
2828
)
2929
from .legacy import padding_oracle
30+
31+
__all__ = [
32+
'solve',
33+
'convert_to_bytes',
34+
'remove_padding',
35+
'padding_oracle',
36+
'urlencode',
37+
'urldecode',
38+
'base64_encode',
39+
'base64_decode',
40+
'to_bytes',
41+
'to_str',
42+
]

src/padding_oracle/legacy.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,61 +22,75 @@
2222

2323
import logging
2424
import traceback
25-
from typing import Optional, Union
25+
from typing import List, Union
2626

2727
from .encoding import to_bytes
28-
from .solve import Fail, OracleFunc, ResultType, solve, remove_padding, convert_to_bytes
28+
from .solve import (
29+
solve, Fail, OracleFunc, ResultType,
30+
convert_to_bytes, remove_padding)
2931

3032
__all__ = [
3133
'padding_oracle',
3234
]
3335

36+
3437
def padding_oracle(ciphertext: Union[bytes, str],
3538
block_size: int,
3639
oracle: OracleFunc,
3740
num_threads: int = 1,
3841
log_level: int = logging.INFO,
3942
null_byte: bytes = b' ',
4043
return_raw: bool = False,
41-
) -> bytes:
44+
) -> Union[bytes, List[int]]:
4245
'''
43-
Run padding oracle attack to decrypt ciphertext given a function to check wether the
44-
ciphertext can be decrypted successfully.
46+
Run padding oracle attack to decrypt ciphertext given a function to check
47+
wether the ciphertext can be decrypted successfully.
4548
4649
Args:
4750
cipher (bytes|str) the cipher you want to decrypt
48-
block_size (int) block size (the cipher length should be multiple of this)
51+
block_size (int) block size (the cipher length should be
52+
multiple of this)
4953
oracle (function) a function: oracle(cipher: bytes) -> bool
50-
num_threads (int) how many oracle functions will be run in parallel (default: 1)
54+
num_threads (int) how many oracle functions will be run in
55+
parallel (default: 1)
5156
log_level (int) log level (default: logging.INFO)
52-
null_byte (bytes|str) the default byte when plaintext are not set (default: None)
53-
return_raw (bool) do not convert plaintext into bytes and unpad (default: False)
57+
null_byte (bytes|str) the default byte when plaintext are not
58+
set (default: None)
59+
return_raw (bool) do not convert plaintext into bytes and
60+
unpad (default: False)
5461
5562
Returns:
5663
plaintext (bytes|List[int]) the decrypted plaintext
5764
'''
5865

5966
# Check args
60-
assert callable(oracle), 'the oracle function should be callable'
61-
assert isinstance(ciphertext, (bytes, str)), 'cipher should have type bytes'
62-
assert isinstance(block_size, int), 'block_size should have type int'
63-
assert len(ciphertext) % block_size == 0, 'cipher length should be multiple of block size'
64-
assert 1 <= num_threads <= 1000, 'num_threads should be in [1, 1000]'
65-
assert isinstance(null_byte, (bytes, str)), 'expect null with type bytes or str'
66-
assert len(null_byte) == 1, 'null byte should have length of 1'
67+
if not callable(oracle):
68+
raise TypeError('the oracle function should be callable')
69+
if not isinstance(ciphertext, (bytes, str)):
70+
raise TypeError('cipher should have type bytes')
71+
if not isinstance(block_size, int):
72+
raise TypeError('block_size should have type int')
73+
if not len(ciphertext) % block_size == 0:
74+
raise ValueError('cipher length should be multiple of block size')
75+
if not 1 <= num_threads <= 1000:
76+
raise ValueError('num_threads should be in [1, 1000]')
77+
if not isinstance(null_byte, (bytes, str)):
78+
raise TypeError('expect null with type bytes or str')
79+
if not len(null_byte) == 1:
80+
raise ValueError('null byte should have length of 1')
6781

6882
logger = get_logger()
6983
logger.setLevel(log_level)
70-
84+
7185
ciphertext = to_bytes(ciphertext)
7286
null_byte = to_bytes(null_byte)
73-
87+
7488
# Wrapper to handle exceptions from the oracle function
7589
def wrapped_oracle(ciphertext: bytes):
7690
try:
7791
return oracle(ciphertext)
7892
except Exception as e:
79-
logger.error('error calling oracle with {!r}'.format(ciphertext))
93+
logger.error(f'error in oracle with {ciphertext!r}, {e}')
8094
logger.debug('error details: {}'.format(traceback.format_exc()))
8195
return False
8296

@@ -93,13 +107,14 @@ def plaintext_callback(plaintext: bytes):
93107

94108
plaintext = solve(ciphertext, block_size, wrapped_oracle, num_threads,
95109
result_callback, plaintext_callback)
96-
110+
97111
if not return_raw:
98112
plaintext = convert_to_bytes(plaintext, null_byte)
99113
plaintext = remove_padding(plaintext)
100-
114+
101115
return plaintext
102116

117+
103118
def get_logger():
104119
logger = logging.getLogger('padding_oracle')
105120
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')

src/padding_oracle/solve.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import asyncio
2424
from concurrent.futures import ThreadPoolExecutor
25-
from typing import Any, Callable, NamedTuple, Set, Union, List
25+
from typing import Callable, NamedTuple, Set, Union, List
2626

2727
from .encoding import to_bytes
2828

@@ -33,16 +33,19 @@
3333
'remove_padding',
3434
]
3535

36+
3637
class Pass(NamedTuple):
3738
block_index: int
3839
index: int
3940
byte: int
4041

42+
4143
class Fail(NamedTuple):
4244
block_index: int
4345
message: str
4446
is_critical: bool = False
4547

48+
4649
class Done(NamedTuple):
4750
block_index: int
4851
C0: List[int]
@@ -54,26 +57,28 @@ class Done(NamedTuple):
5457
OracleFunc = Callable[[bytes], bool]
5558
ResultCallback = Callable[[ResultType], bool]
5659
PlainTextCallback = Callable[[List[int]], bool]
57-
60+
5861

5962
class Context(NamedTuple):
6063
block_size: int
6164
oracle: OracleFunc
62-
63-
executor: ThreadPoolExecutor
65+
66+
executor: ThreadPoolExecutor
6467
loop: asyncio.AbstractEventLoop
65-
68+
6669
tasks: Set[asyncio.Task[ResultType]]
6770

6871
latest_plaintext: List[int]
6972
plaintext: List[int]
70-
73+
7174
result_callback: ResultCallback
7275
plaintext_callback: PlainTextCallback
7376

77+
7478
def dummy_callback(*a, **ka):
7579
pass
7680

81+
7782
def solve(ciphertext: bytes,
7883
block_size: int,
7984
oracle: OracleFunc,
@@ -87,6 +92,7 @@ def solve(ciphertext: bytes,
8792
result_callback, plaintext_callback)
8893
return loop.run_until_complete(future)
8994

95+
9096
async def solve_async(ciphertext: bytes,
9197
block_size: int,
9298
oracle: OracleFunc,
@@ -96,43 +102,47 @@ async def solve_async(ciphertext: bytes,
96102
) -> List[int]:
97103

98104
ciphertext = list(ciphertext)
99-
assert len(ciphertext) % block_size == 0, \
100-
'ciphertext length must be a multiple of block_size'
101-
assert len(ciphertext) // block_size > 1, \
102-
'cannot solve with only one block'
105+
106+
if not len(ciphertext) % block_size == 0:
107+
raise ValueError('ciphertext length must be a multiple of block_size')
108+
if not len(ciphertext) // block_size > 1:
109+
raise ValueError('cannot solve with only one block')
103110

104111
ctx = create_solve_context(ciphertext, block_size, oracle, parallel,
105112
result_callback, plaintext_callback)
106113

107114
while True:
108-
done_tasks, _ = await asyncio.wait(ctx.tasks, return_when=asyncio.FIRST_COMPLETED)
109-
115+
done_tasks, _ = await asyncio.wait(ctx.tasks,
116+
return_when=asyncio.FIRST_COMPLETED)
117+
110118
for task in done_tasks:
111119
result = await task
112-
120+
113121
ctx.result_callback(result)
114122
ctx.tasks.remove(task)
115-
123+
116124
if isinstance(result, Pass):
117-
update_latest_plaintext(ctx, result.block_index, result.index, result.byte)
125+
update_latest_plaintext(
126+
ctx, result.block_index, result.index, result.byte)
118127
if isinstance(result, Done):
119128
update_plaintext(ctx, result.block_index, result.C0, result.X1)
120-
129+
121130
if len(ctx.tasks) == 0:
122131
break
123-
132+
124133
# Check if any block failed
125134
error_block_indices = set()
126-
135+
127136
for i, byte in enumerate(ctx.plaintext):
128137
if byte is None:
129138
error_block_indices.add(i // block_size + 1)
130-
139+
131140
for idx in error_block_indices:
132141
result_callback(Fail(idx, f'cannot decrypt cipher block {idx}', True))
133-
142+
134143
return ctx.plaintext
135144

145+
136146
def create_solve_context(ciphertext, block_size, oracle, parallel,
137147
result_callback, plaintext_callback) -> Context:
138148
tasks = set()
@@ -143,30 +153,27 @@ def create_solve_context(ciphertext, block_size, oracle, parallel,
143153

144154
plaintext = [None] * (len(cipher_blocks) - 1) * block_size
145155
latest_plaintext = plaintext.copy()
146-
156+
147157
executor = ThreadPoolExecutor(parallel)
148158
loop = asyncio.get_running_loop()
149159
ctx = Context(block_size, oracle, executor, loop, tasks,
150160
latest_plaintext, plaintext,
151161
result_callback, plaintext_callback)
152-
162+
153163
for i in range(1, len(cipher_blocks)):
154164
run_block_task(ctx, i, cipher_blocks[i-1], cipher_blocks[i], [])
155165

156166
return ctx
157167

168+
158169
def run_block_task(ctx: Context, block_index, C0, C1, X1):
159170
future = solve_block(ctx, block_index, C0, C1, X1)
160171
task = ctx.loop.create_task(future)
161172
ctx.tasks.add(task)
162173

163-
async def solve_block(
164-
ctx: Context,
165-
block_index: int,
166-
C0: List[int],
167-
C1: List[int],
168-
X1: List[int] = [],
169-
) -> ResultType:
174+
175+
async def solve_block(ctx: Context, block_index: int, C0: List[int],
176+
C1: List[int], X1: List[int] = []) -> ResultType:
170177
# X1 = decrypt(C1)
171178
# P1 = xor(C0, X1)
172179

@@ -195,46 +202,56 @@ async def solve_block(
195202
for byte in hits:
196203
X1_test = [byte ^ padding, *X1]
197204
run_block_task(ctx, block_index, C0, C1, X1_test)
198-
205+
199206
return Pass(block_index, index, byte ^ padding ^ C0[index])
200207

201-
async def get_oracle_hits(ctx: Context, C0: List[int], C1: List[int], index: int):
202-
208+
209+
async def get_oracle_hits(ctx: Context, C0: List[int], C1: List[int],
210+
index: int):
211+
203212
C0 = C0.copy()
204213
futures = {}
205-
214+
206215
for byte in range(256):
207216
C0[index] = byte
208217
ciphertext = bytes(C0 + C1)
209218
futures[byte] = ctx.loop.run_in_executor(
210219
ctx.executor, ctx.oracle, ciphertext)
211-
220+
212221
hits = []
213-
222+
214223
for byte, future in futures.items():
215224
is_valid = await future
216225
if is_valid:
217226
hits.append(byte)
218-
227+
219228
return hits
220229

221-
def update_latest_plaintext(ctx: Context, block_index: int, index: int, byte: int):
230+
231+
def update_latest_plaintext(ctx: Context, block_index: int, index: int,
232+
byte: int):
233+
222234
i = (block_index - 1) * ctx.block_size + index
223235
ctx.latest_plaintext[i] = byte
224236
ctx.plaintext_callback(ctx.latest_plaintext)
225237

226-
def update_plaintext(ctx: Context, block_index: int, C0: List[int], X1: List[int]):
238+
239+
def update_plaintext(ctx: Context, block_index: int, C0: List[int],
240+
X1: List[int]):
241+
227242
assert len(C0) == len(X1) == ctx.block_size
228243
block = compute_plaintext(C0, X1)
229-
244+
230245
i = (block_index - 1) * ctx.block_size
231246
ctx.latest_plaintext[i:i+ctx.block_size] = block
232247
ctx.plaintext[i:i+ctx.block_size] = block
233248
ctx.plaintext_callback(ctx.plaintext)
234249

250+
235251
def compute_plaintext(C0: List[int], X1: List[int]):
236252
return [c ^ x for c, x in zip(C0, X1)]
237253

254+
238255
def convert_to_bytes(byte_list: List[int], replacement=b' '):
239256
'''
240257
Convert a list of int into bytes, replace invalid byte with replacement.
@@ -249,6 +266,7 @@ def convert_to_bytes(byte_list: List[int], replacement=b' '):
249266
byte_list[i] = byte
250267
return bytes(byte_list)
251268

269+
252270
def remove_padding(data: Union[str, bytes, List[int]]) -> bytes:
253271
'''
254272
Remove PKCS#7 padding bytes.

tests/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
import sys, os
1+
import sys
2+
import os
3+
24

35
def get_src_dir():
46
current_dir = os.path.dirname(__file__)
57
src_dir = os.path.join(current_dir, '..', 'src')
68
return src_dir
79

10+
811
sys.path.insert(0, get_src_dir())

0 commit comments

Comments
 (0)