|
1 | 1 | from io import DEFAULT_BUFFER_SIZE, SEEK_SET |
2 | | -from lzma import FORMAT_XZ, LZMADecompressor, LZMAError |
| 2 | +from lzma import FORMAT_XZ, LZMACompressor, LZMADecompressor, LZMAError |
3 | 3 |
|
4 | | -from xz.common import XZError, create_xz_header, create_xz_index_footer |
| 4 | +from xz.common import ( |
| 5 | + XZError, |
| 6 | + create_xz_header, |
| 7 | + create_xz_index_footer, |
| 8 | + parse_xz_footer, |
| 9 | + parse_xz_index, |
| 10 | +) |
5 | 11 | from xz.io import IOAbstract, IOCombiner, IOStatic |
6 | 12 |
|
7 | 13 |
|
class BlockRead:
    """Read-side helper that decompresses a single XZ block.

    The block's file object is sandwiched between a generated XZ header
    (``create_xz_header``) and a generated index + footer
    (``create_xz_index_footer``) before being fed to a ``FORMAT_XZ``
    ``LZMADecompressor``.

    Attributes:
        length: expected uncompressed size of the block.
        fileobj: the ``IOCombiner`` built from header, block and footer
            (presumably their concatenation; verify against ``xz.io``).
        pos: number of uncompressed bytes produced since the last reset.
        decompressor: the current ``LZMADecompressor`` instance.
    """

    # maximum number of compressed bytes requested from ``fileobj`` per refill
    read_size = DEFAULT_BUFFER_SIZE

    def __init__(self, fileobj, check, unpadded_size, uncompressed_size):
        self.length = uncompressed_size
        header = IOStatic(create_xz_header(check))
        index_footer = IOStatic(
            create_xz_index_footer(check, [(unpadded_size, uncompressed_size)])
        )
        self.fileobj = IOCombiner(header, fileobj, index_footer)
        self.reset()

    def reset(self):
        """Rewind the input and start over with a fresh decompressor."""
        self.fileobj.seek(0, SEEK_SET)
        self.pos = 0
        self.decompressor = LZMADecompressor(format=FORMAT_XZ)

    def _next_input(self):
        """Return the compressed bytes to feed to the decompressor next.

        Nothing is read while the decompressor still holds pending input.

        Raises:
            XZError: if the decompressor wants input but ``fileobj`` is
                exhausted.
        """
        if not self.decompressor.needs_input:
            return b""
        chunk = self.fileobj.read(self.read_size)
        if not chunk:
            raise XZError("block: data eof")
        return chunk

    def decompress(self, pos, size):
        """Return uncompressed data located at offset ``pos`` of the block.

        One decompression step is performed, producing at most the bytes
        lying between the current position and ``pos + size``; everything
        before ``pos`` is discarded. The result may therefore be shorter
        than ``size`` (possibly empty) and the caller is expected to call
        again. A ``pos`` located before the current position restarts the
        decompression from the beginning of the block.

        Raises:
            XZError: if the decompressor already reached its end, or if the
                input is exhausted.
            LZMAError: if data comes out after ``length`` bytes were produced.
        """
        if pos < self.pos:
            # the decompressor cannot go backwards: start again from scratch
            self.reset()
        skip_before = pos - self.pos

        decompressor = self.decompressor
        # pylint: disable=using-constant-test
        if decompressor.eof:
            raise XZError("block: decompressor eof")

        data_output = decompressor.decompress(
            self._next_input(), skip_before + size
        )
        self.pos += len(data_output)

        if self.pos == self.length:
            # we reached the end of the block
            # according to the XZ specification, we must check the
            # remaining bytes of the block; this is mainly performed by the
            # decompressor itself when we consume it
            while not decompressor.eof:
                if self.decompress(self.pos, 1):
                    raise LZMAError("Corrupt input data")

        return data_output[skip_before:]
| 63 | + |
| 64 | + |
class BlockWrite:
    """Write-side helper that compresses data into a single XZ block.

    A ``FORMAT_XZ`` ``LZMACompressor`` produces a whole stream; the stream
    header, index and footer it emits are checked and left out, so that only
    the block itself is written to ``fileobj``.

    Attributes:
        fileobj: destination file object.
        check: integrity check type handed to the compressor.
        compressor: the ``LZMACompressor`` instance.
        pos: offset in ``fileobj`` where the next compressed bytes go.
    """

    def __init__(self, fileobj, check, preset, filters):
        self.fileobj = fileobj
        self.check = check
        self.compressor = LZMACompressor(FORMAT_XZ, check, preset, filters)
        self.pos = 0
        # the first call to the compressor must yield exactly the stream
        # header; it is consumed here and never written to the file
        emitted_header = self.compressor.compress(b"")
        if emitted_header != create_xz_header(check):
            raise XZError("block: compressor header")

    def _write(self, data):
        """Write ``data`` at the current offset and advance; no-op if empty."""
        if not data:
            return
        self.fileobj.seek(self.pos)
        self.fileobj.write(data)
        self.pos += len(data)

    def compress(self, data):
        """Compress ``data`` and store whatever the compressor emits."""
        compressed = self.compressor.compress(data)
        self._write(compressed)

    def finish(self):
        """Flush the compressor and write the end of the block.

        Returns:
            The single index record ``(unpadded_size, uncompressed_size)``.

        Raises:
            XZError: if the footer's check differs from ``self.check``, or if
                the index does not hold exactly one record.
        """
        data = self.compressor.flush()

        # footer: the XZ stream footer occupies the last 12 bytes
        footer_size = 12
        check, backward_size = parse_xz_footer(data[-footer_size:])
        if check != self.check:
            raise XZError("block: compressor footer check")

        # index: the backward_size bytes located right before the footer
        index_start = -footer_size - backward_size
        records = parse_xz_index(data[index_start:-footer_size])
        if len(records) != 1:
            raise XZError("block: compressor index records length")

        # remaining block data: everything located before the index
        self._write(data[:index_start])

        record = records[0]  # (unpadded_size, uncompressed_size)
        return record
| 100 | + |
| 101 | + |
class XZBlock(IOAbstract):
    """File-like view over the uncompressed content of one XZ block.

    The real work is delegated to ``self.operation``, which is ``None``, a
    ``BlockRead`` or a ``BlockWrite`` depending on the last access. The
    ``_read`` / ``_write`` / ``_write_after`` / ``_truncate`` methods are
    presumably hooks invoked by ``IOAbstract`` (not visible here; verify
    against ``xz.io``).
    """

    def __init__(
        self,
        fileobj,
        check,
        unpadded_size,
        uncompressed_size,
        preset=None,
        filters=None,
    ):
        super().__init__(uncompressed_size)
        self.fileobj = fileobj
        self.check = check
        self.unpadded_size = unpadded_size
        # compression settings, only used when the block gets written
        self.preset = preset
        self.filters = filters
        # no reader/writer exists until the first access
        self.operation = None

    @property
    def uncompressed_size(self):
        """Uncompressed size of the block (the ``_length`` attribute)."""
        return self._length

    def _read(self, size):
        """Decompress up to ``size`` bytes located at the current position.

        Raises:
            XZError: wrapping any ``LZMAError`` raised while decompressing.
        """
        # enforce read mode
        reader = self.operation
        if not isinstance(reader, BlockRead):
            # _write_end is not defined in this class (presumably inherited
            # from IOAbstract); it runs before the reader is created
            self._write_end()
            reader = BlockRead(
                self.fileobj,
                self.check,
                self.unpadded_size,
                self.uncompressed_size,
            )
            self.operation = reader

        # read data
        try:
            data = reader.decompress(self._pos, size)
        except LZMAError as ex:
            raise XZError(f"block: error while decompressing: {ex}") from ex
        return data

    def writable(self):
        """Tell whether a write is in progress or the block is still empty."""
        if isinstance(self.operation, BlockWrite):
            return True
        return not self._length

    def _write(self, data):
        """Compress ``data`` into the block and return ``len(data)``."""
        # enforce write mode
        writer = self.operation
        if not isinstance(writer, BlockWrite):
            writer = BlockWrite(
                self.fileobj,
                self.check,
                self.preset,
                self.filters,
            )
            self.operation = writer

        # write data
        writer.compress(data)
        return len(data)

    def _write_after(self):
        """Finish a pending write and record the resulting unpadded size.

        Raises:
            XZError: if the size reported by the compressor differs from
                ``uncompressed_size``.
        """
        writer = self.operation
        if not isinstance(writer, BlockWrite):
            return
        self.unpadded_size, uncompressed_size = writer.finish()
        if uncompressed_size != self.uncompressed_size:
            raise XZError("block: compressor uncompressed size")
        self.operation = None

    def _truncate(self, size):
        """Resize the block by seeking to ``size`` and issuing an empty write."""
        # thanks to the writable method, we are sure that length is zero
        # so we don't need to handle the case of truncating in middle of the block
        self.seek(size)
        self.write(b"")
0 commit comments