|
22 | 22 |
|
23 | 23 | import logging
|
24 | 24 | import traceback
|
25 |
| -from typing import Union, Callable |
26 |
| -from concurrent.futures import ThreadPoolExecutor |
| 25 | +from typing import Optional, Union |
27 | 26 |
|
28 | 27 | from .encoding import to_bytes
|
| 28 | +from .solve import Fail, OracleFunc, ResultType, solve, remove_padding, convert_to_bytes |
29 | 29 |
|
30 | 30 | __all__ = [
|
31 | 31 | 'padding_oracle',
|
32 |
| - 'remove_padding' |
33 | 32 | ]
|
34 | 33 |
|
35 |
| - |
36 |
| -def remove_padding(data: Union[str, bytes]): |
37 |
| - ''' |
38 |
| - Remove PKCS#7 padding bytes. |
39 |
| -
|
40 |
| - Args: |
41 |
| - data (str | bytes) |
42 |
| -
|
43 |
| - Returns: |
44 |
| - data with padding removed (bytes) |
45 |
| - ''' |
46 |
| - data = to_bytes(data) |
47 |
| - return data[:-data[-1]] |
48 |
| - |
49 |
| - |
50 |
| -def _get_logger(): |
51 |
| - logger = logging.getLogger('padding_oracle') |
52 |
| - formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s') |
53 |
| - handler = logging.StreamHandler() |
54 |
| - handler.setFormatter(formatter) |
55 |
| - logger.addHandler(handler) |
56 |
| - return logger |
57 |
| - |
58 |
| - |
59 |
| -def padding_oracle(cipher: bytes, |
| 34 | +def padding_oracle(ciphertext: Union[bytes, str], |
60 | 35 | block_size: int,
|
61 |
| - oracle: Callable[[bytes], bool], |
| 36 | + oracle: OracleFunc, |
62 | 37 | num_threads: int = 1,
|
63 | 38 | log_level: int = logging.INFO,
|
64 |
| - chars=None, |
65 |
| - null: bytes = b' ') -> bytes: |
| 39 | + null_byte: bytes = b' ', |
| 40 | + return_raw: bool = False, |
| 41 | + ) -> bytes: |
66 | 42 | '''
|
67 |
| - Run padding oracle attack to decrypt cipher given a function to check wether the cipher |
68 |
| - can be decrypted successfully. |
| 43 | + Run padding oracle attack to decrypt ciphertext given a function to check wether the |
| 44 | + ciphertext can be decrypted successfully. |
69 | 45 |
|
70 | 46 | Args:
|
71 |
| - cipher (bytes|str) the cipher you want to decrypt |
72 |
| - block_size (int) block size (the cipher length should be multiple of this) |
73 |
| - oracle (function) a function: oracle(cipher: bytes) -> bool |
74 |
| - num_threads (int) how many oracle functions will be run in parallel (default: 1) |
75 |
| - log_level (int) log level (default: logging.INFO) |
76 |
| - chars (bytes|str) possible characters in your plaintext, None for all |
77 |
| - null (bytes|str) the default byte when plaintext are not set (default: b' ') |
| 47 | + cipher (bytes|str) the cipher you want to decrypt |
| 48 | + block_size (int) block size (the cipher length should be multiple of this) |
| 49 | + oracle (function) a function: oracle(cipher: bytes) -> bool |
| 50 | + num_threads (int) how many oracle functions will be run in parallel (default: 1) |
| 51 | + log_level (int) log level (default: logging.INFO) |
| 52 | + null_byte (bytes|str) the default byte when plaintext are not set (default: None) |
| 53 | + return_raw (bool) do not convert plaintext into bytes and unpad (default: False) |
78 | 54 |
|
79 | 55 | Returns:
|
80 |
| - plaintext (bytes) the decrypted plaintext |
| 56 | + plaintext (bytes|List[int]) the decrypted plaintext |
81 | 57 | '''
|
82 | 58 |
|
83 | 59 | # Check args
|
84 | 60 | assert callable(oracle), 'the oracle function should be callable'
|
85 |
| - assert oracle.__code__.co_argcount == 1, 'expect oracle function with only 1 argument' |
86 |
| - assert isinstance(cipher, (bytes, str)), 'cipher should have type bytes' |
| 61 | + assert isinstance(ciphertext, (bytes, str)), 'cipher should have type bytes' |
87 | 62 | assert isinstance(block_size, int), 'block_size should have type int'
|
88 |
| - assert len(cipher) % block_size == 0, 'cipher length should be multiple of block size' |
| 63 | + assert len(ciphertext) % block_size == 0, 'cipher length should be multiple of block size' |
89 | 64 | assert 1 <= num_threads <= 1000, 'num_threads should be in [1, 1000]'
|
90 |
| - assert isinstance(null, (bytes, str)), 'expect null with type bytes or str' |
91 |
| - assert len(null) == 1, 'null byte should have length of 1' |
92 |
| - assert isinstance(chars, (bytes, str)) or chars is None, 'chars should be None or type bytes' |
| 65 | + assert isinstance(null_byte, (bytes, str)), 'expect null with type bytes or str' |
| 66 | + assert len(null_byte) == 1, 'null byte should have length of 1' |
93 | 67 |
|
94 |
| - logger = _get_logger() |
| 68 | + logger = get_logger() |
95 | 69 | logger.setLevel(log_level)
|
96 | 70 |
|
97 |
| - cipher = to_bytes(cipher) |
98 |
| - null = to_bytes(null) |
| 71 | + ciphertext = to_bytes(ciphertext) |
| 72 | + null_byte = to_bytes(null_byte) |
99 | 73 |
|
100 |
| - if chars is None: |
101 |
| - chars = set(range(256)) |
102 |
| - else: |
103 |
| - chars = set(to_bytes(chars)) |
104 |
| - chars |= set(range(1, block_size + 1)) # include PCKS#7 padding bytes |
105 |
| - |
106 |
| - # Wrapper to handle exception from the oracle function |
107 |
| - def _oracle_wrapper(i: int, j: int, cipher: bytes): |
| 74 | + # Wrapper to handle exceptions from the oracle function |
| 75 | + def wrapped_oracle(ciphertext: bytes): |
108 | 76 | try:
|
109 |
| - return oracle(cipher) |
| 77 | + return oracle(ciphertext) |
110 | 78 | except Exception as e:
|
111 |
| - logger.error('unhandled error at block[{}][{}]: {}'.format(i, j, e)) |
112 |
| - logger.debug('error details at block[{}][{}]: {}'.format(i, j, traceback.format_exc())) |
| 79 | + logger.error('error calling oracle with {!r}'.format(ciphertext)) |
| 80 | + logger.debug('error details: {}'.format(traceback.format_exc())) |
113 | 81 | return False
|
114 | 82 |
|
115 |
| - # The plaintext bytes list to store the decrypted data |
116 |
| - plaintext = [null] * (len(cipher) - block_size) |
117 |
| - |
118 |
| - # Update the decrypted plaintext list |
119 |
| - def _update_plaintext(i: int, c: bytes): |
120 |
| - plaintext[i] = c |
121 |
| - logger.info('plaintext: {}'.format(b''.join(plaintext))) |
122 |
| - |
123 |
| - oracle_executor = ThreadPoolExecutor(max_workers=num_threads) |
| 83 | + def result_callback(result: ResultType): |
| 84 | + if isinstance(result, Fail): |
| 85 | + if result.is_critical: |
| 86 | + logger.critical(result.message) |
| 87 | + else: |
| 88 | + logger.error(result.message) |
124 | 89 |
|
125 |
| - # Block decrypting task to be run in parallel |
126 |
| - def _block_decrypt_task(i, prev: bytes, block: bytes): |
127 |
| - logger.debug('task={} prev={} block={}'.format(i, prev, block)) |
128 |
| - guess_list = list(prev) |
| 90 | + def plaintext_callback(plaintext: bytes): |
| 91 | + plaintext = convert_to_bytes(plaintext, null_byte) |
| 92 | + logger.info(f'plaintext: {plaintext}') |
129 | 93 |
|
130 |
| - for j in range(1, block_size + 1): |
131 |
| - oracle_hits = [] |
132 |
| - oracle_futures = {} |
133 |
| - |
134 |
| - for c in chars: |
135 |
| - k = c ^ j ^ prev[-j] |
136 |
| - |
137 |
| - if i == len(blocks) - 1 and j == 1 and k == prev[-j]: |
138 |
| - # skip the last padding byte if it is identical to the original cipher |
139 |
| - continue |
140 |
| - |
141 |
| - test_list = guess_list.copy() |
142 |
| - test_list[-j] = k |
143 |
| - oracle_futures[k] = oracle_executor.submit( |
144 |
| - _oracle_wrapper, i, j, bytes(test_list) + block) |
145 |
| - |
146 |
| - for k, future in oracle_futures.items(): |
147 |
| - if future.result(): |
148 |
| - oracle_hits.append(k) |
149 |
| - |
150 |
| - logger.debug( |
151 |
| - 'oracles at block[{}][{}] -> {}'.format(i, block_size - j, oracle_hits)) |
152 |
| - |
153 |
| - # Number of oracle hits should be 1, or we just ignore this block |
154 |
| - if len(oracle_hits) != 1: |
155 |
| - logfmt = 'at block[{}][{}]: expect only one hit, got {}. (skipped)' |
156 |
| - logger.error(logfmt.format(i, block_size-j, len(oracle_hits))) |
157 |
| - return |
158 |
| - |
159 |
| - guess_list[-j] = oracle_hits[0] |
160 |
| - |
161 |
| - p = guess_list[-j] ^ j ^ prev[-j] |
162 |
| - _update_plaintext(i * block_size - j, bytes([p])) |
163 |
| - |
164 |
| - for n in range(j): |
165 |
| - guess_list[-n-1] ^= j ^ (j + 1) |
166 |
| - |
167 |
| - blocks = [] |
168 |
| - |
169 |
| - for i in range(0, len(cipher), block_size): |
170 |
| - blocks.append(cipher[i:i + block_size]) |
171 |
| - |
172 |
| - logger.debug('blocks: {}'.format(blocks)) |
173 |
| - |
174 |
| - with ThreadPoolExecutor() as executor: |
175 |
| - futures = [] |
176 |
| - for i in reversed(range(1, len(blocks))): |
177 |
| - prev = b''.join(blocks[:i]) |
178 |
| - block = b''.join(blocks[i:i+1]) |
179 |
| - futures.append(executor.submit(_block_decrypt_task, i, prev, block)) |
180 |
| - for future in futures: |
181 |
| - future.result() |
182 |
| - |
183 |
| - oracle_executor.shutdown() |
| 94 | + plaintext = solve(ciphertext, block_size, wrapped_oracle, num_threads, |
| 95 | + result_callback, plaintext_callback) |
| 96 | + |
| 97 | + if not return_raw: |
| 98 | + plaintext = convert_to_bytes(plaintext, null_byte) |
| 99 | + plaintext = remove_padding(plaintext) |
| 100 | + |
| 101 | + return plaintext |
184 | 102 |
|
185 |
| - return b''.join(plaintext) |
| 103 | +def get_logger(): |
| 104 | + logger = logging.getLogger('padding_oracle') |
| 105 | + formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s') |
| 106 | + handler = logging.StreamHandler() |
| 107 | + handler.setFormatter(formatter) |
| 108 | + logger.addHandler(handler) |
| 109 | + return logger |
0 commit comments