|
| 1 | +''' |
| 2 | +Copyright (c) 2020 Yuankui Lee |
| 3 | +
|
| 4 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 5 | +of this software and associated documentation files (the "Software"), to deal |
| 6 | +in the Software without restriction, including without limitation the rights |
| 7 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 8 | +copies of the Software, and to permit persons to whom the Software is |
| 9 | +furnished to do so, subject to the following conditions: |
| 10 | +
|
| 11 | +The above copyright notice and this permission notice shall be included in all |
| 12 | +copies or substantial portions of the Software. |
| 13 | +
|
| 14 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 15 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 16 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 17 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 18 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 19 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 20 | +SOFTWARE. |
| 21 | +''' |
| 22 | + |
| 23 | +import logging |
| 24 | +import traceback |
| 25 | +from typing import Callable, Union, List, Optional, Generator, Tuple |
| 26 | +from types import ModuleType |
| 27 | +from queue import Empty as QueueEmpty |
| 28 | +import multiprocessing.dummy |
| 29 | + |
| 30 | +from .encoding import to_bytes |
| 31 | + |
| 32 | + |
| 33 | +__all__ = [ |
| 34 | + 'Solver', 'solve', |
| 35 | + 'remove_padding', |
| 36 | + 'plaintext_list_to_bytes' |
| 37 | +] |
| 38 | + |
| 39 | + |
| 40 | +def remove_padding(data: Union[str, bytes]): |
| 41 | + ''' |
| 42 | + Remove PKCS#7 padding bytes. |
| 43 | +
|
| 44 | + Args: |
| 45 | + data (str | bytes) |
| 46 | +
|
| 47 | + Returns: |
| 48 | + data with padding removed (bytes) |
| 49 | + ''' |
| 50 | + data = to_bytes(data) |
| 51 | + return data[:-data[-1]] |
| 52 | + |
| 53 | + |
| 54 | +def plaintext_list_to_bytes(plaintext_list, unknown=b' '): |
| 55 | + plaintext_bytes = bytes(unknown if b is None else b |
| 56 | + for b in plaintext_list) |
| 57 | + return plaintext_bytes |
| 58 | + |
| 59 | + |
| 60 | +def solve(**kwargs): |
| 61 | + cipher = kwargs.pop('cipher') |
| 62 | + unknown = kwargs.pop('unknown', b' ') |
| 63 | + solver = Solver(**kwargs) |
| 64 | + plaintext = plaintext_list_to_bytes(solver.solve(cipher)) |
| 65 | + plaintext = remove_padding(plaintext) |
| 66 | + return plaintext |
| 67 | + |
| 68 | + |
| 69 | +class Solver: |
| 70 | + block_size: int = 16 # positive integer |
| 71 | + possible_bytes: bytes = bytes(range(256)) # bytes |
| 72 | + num_threads: int = 1 # positive integer |
| 73 | + validator: Optional[Callable[[bytes], bool]] = None # function(bytes) -> bool |
| 74 | + logger: logging.Logger = logging.getLogger(__name__) # Logger |
| 75 | + mp: ModuleType = multiprocessing.dummy # thread-based, or `multiprocessing` for process-based |
| 76 | + |
| 77 | + def __init__(self, |
| 78 | + block_size: int = None, |
| 79 | + possible_bytes: bytes = None, |
| 80 | + num_threads: int = None, |
| 81 | + validator: Callable = None, |
| 82 | + logger: logging.Logger = None, |
| 83 | + mp: ModuleType = None): |
| 84 | + if block_size is not None: |
| 85 | + self.block_size = block_size |
| 86 | + if possible_bytes is not None: |
| 87 | + self.possible_bytes = possible_bytes |
| 88 | + if num_threads is not None: |
| 89 | + self.num_threads = num_threads |
| 90 | + if validator is not None: |
| 91 | + self.validator = validator |
| 92 | + if logger is not None: |
| 93 | + self.logger = logger |
| 94 | + if mp is not None: |
| 95 | + self.mp = mp |
| 96 | + |
| 97 | + def check_params(self): |
| 98 | + assert isinstance(self.block_size, int) and self.block_size > 0, ( |
| 99 | + 'block_size should be a positive integer') |
| 100 | + assert isinstance(self.possible_bytes, |
| 101 | + bytes), 'possible_bytes should be bytes' |
| 102 | + assert isinstance(self.num_threads, int) and self.num_threads > 0, ( |
| 103 | + 'num_threads should be a positive integer') |
| 104 | + assert self.validator is not None and callable(self.validator), ( |
| 105 | + 'please implement the validator function') |
| 106 | + |
| 107 | + def oracle(self, validator): |
| 108 | + self.validator = validator |
| 109 | + |
| 110 | + def solve(self, cipher: bytes, unknown: bytes = b' ') -> List[Optional[int]]: |
| 111 | + plaintext_list = [None] * (len(cipher) - self.block_size) |
| 112 | + unknown = ord(unknown) |
| 113 | + |
| 114 | + for block_index, byte_index, byte in self.iter_solve(cipher): |
| 115 | + index = (block_index - 1) * self.block_size + byte_index |
| 116 | + plaintext_list[index] = byte |
| 117 | + |
| 118 | + self.logger.debug('decrypted list: {!r}'.format(plaintext_list)) |
| 119 | + |
| 120 | + plaintext = bytes( |
| 121 | + unknown if b is None else b for b in plaintext_list) |
| 122 | + self.logger.info('decrypted: {!r}'.format(plaintext)) |
| 123 | + |
| 124 | + return plaintext_list |
| 125 | + |
| 126 | + def iter_solve(self, cipher: bytes): |
| 127 | + # check cipher and divide cipher bytes into blocks |
| 128 | + assert len(cipher) % self.block_size == 0, ( |
| 129 | + 'invalid cipher length: {}'.format(len(cipher))) |
| 130 | + cipher_blocks = [] |
| 131 | + for i in range(0, len(cipher), self.block_size): |
| 132 | + cipher_blocks.append(cipher[i:i + self.block_size]) |
| 133 | + |
| 134 | + self.logger.debug('cipher blocks: {}'.format(cipher_blocks)) |
| 135 | + |
| 136 | + # check other params |
| 137 | + self.check_params() |
| 138 | + |
| 139 | + possible_bytes = set(self.possible_bytes) | set( |
| 140 | + range(1, self.block_size + 1)) |
| 141 | + |
| 142 | + self.logger.debug('creating pool and queue') |
| 143 | + |
| 144 | + pool = self.mp.Pool(self.num_threads) |
| 145 | + queue = self.mp.Queue() |
| 146 | + |
| 147 | + def _decrypt(block_index, block, prefix_bytes, queue): |
| 148 | + prefix_list = list(prefix_bytes) |
| 149 | + |
| 150 | + for n in range(1, self.block_size + 1): |
| 151 | + byte_index = self.block_size - n # byte index in the block |
| 152 | + validate_results = {} # async result handler for validator |
| 153 | + valid_bytes = [] # valid try, expect only one item if vulnerable |
| 154 | + |
| 155 | + for p in possible_bytes: |
| 156 | + b = p ^ n ^ prefix_bytes[-n] |
| 157 | + |
| 158 | + if block_index == len(cipher_blocks) - 1 and n == 1 and b == prefix_bytes[-n]: |
| 159 | + # skip the last padding byte if it is identical to the original cipher |
| 160 | + continue |
| 161 | + |
| 162 | + # modify prefix block and construct the cipher |
| 163 | + test_prefix_list = prefix_list.copy() |
| 164 | + test_prefix_list[-n] = b |
| 165 | + test_cipher = bytes(test_prefix_list) + block |
| 166 | + |
| 167 | + # add and run validation for constructed cipher |
| 168 | + validate_results[b] = pool.apply_async( |
| 169 | + self.validator, (test_cipher, )) |
| 170 | + |
| 171 | + has_exception_in_thread = False |
| 172 | + |
| 173 | + # collect valid bytes from validator results |
| 174 | + for b, result in validate_results.items(): |
| 175 | + is_valid = False |
| 176 | + try: |
| 177 | + is_valid = result.get() |
| 178 | + except: |
| 179 | + # catch exceptions generated in the thread |
| 180 | + self.logger.error('at block {} pos {}, unhandled error in validator:\n{}'.format( |
| 181 | + block_index, byte_index, traceback.format_exc())) |
| 182 | + has_exception_in_thread = True |
| 183 | + if is_valid: |
| 184 | + valid_bytes.append(b) |
| 185 | + |
| 186 | + self.logger.debug('at block {} pos {}, valid bytes are {}'.format( |
| 187 | + block_index, byte_index, valid_bytes)) |
| 188 | + |
| 189 | + if len(valid_bytes) != 1: |
| 190 | + # something goes wrong here, please check the validator |
| 191 | + self.logger.error('at block {} pos {}, expect only one valid byte, got {}'.format( |
| 192 | + block_index, byte_index, len(valid_bytes))) |
| 193 | + return |
| 194 | + elif has_exception_in_thread: |
| 195 | + self.logger.warning( |
| 196 | + 'at block {} pos {}, an exception was ignored') |
| 197 | + |
| 198 | + prefix_list[-n] = valid_bytes[0] |
| 199 | + for i in range(n): |
| 200 | + prefix_list[-i-1] ^= n ^ (n + 1) |
| 201 | + |
| 202 | + decrypted = valid_bytes[0] ^ n ^ prefix_bytes[-n] |
| 203 | + |
| 204 | + self.logger.debug('at block {} pos {}, decrypted a byte {!r}'.format( |
| 205 | + block_index, byte_index, bytes([decrypted]))) |
| 206 | + |
| 207 | + queue.put((block_index, byte_index, decrypted)) |
| 208 | + |
| 209 | + block_procs = [] |
| 210 | + |
| 211 | + for i in reversed(range(1, len(cipher_blocks))): |
| 212 | + prefix_bytes = b''.join(cipher_blocks[:i]) |
| 213 | + block = b''.join(cipher_blocks[i:i+1]) |
| 214 | + |
| 215 | + self.logger.debug( |
| 216 | + 'starting decryption process for block {}'.format(i)) |
| 217 | + p = self.mp.Process(target=_decrypt, args=( |
| 218 | + i, block, prefix_bytes, queue)) |
| 219 | + p.start() |
| 220 | + block_procs.append(p) |
| 221 | + |
| 222 | + while any(p.is_alive() for p in block_procs): |
| 223 | + try: |
| 224 | + yield queue.get(timeout=1) |
| 225 | + except QueueEmpty: |
| 226 | + continue |
| 227 | + |
| 228 | + self.logger.debug('shutting down pool and processes') |
| 229 | + |
| 230 | + for p in block_procs: |
| 231 | + p.join() |
| 232 | + |
| 233 | + pool.terminate() |
| 234 | + pool.join() |
| 235 | + |
| 236 | + self.logger.debug('end solving') |
| 237 | + |
| 238 | + def __repr__(self): |
| 239 | + return '<{}.{} (block_size={}, validator={}, num_threads={}, mp={}, logger={})>'.format( |
| 240 | + self.__module__, self.__class__.__qualname__, |
| 241 | + self.block_size, self.validator.__name__, |
| 242 | + self.num_threads, self.mp.__name__, self.logger) |
| 243 | + |
| 244 | + __str__ = __repr__ |
0 commit comments