diff --git a/Lib/zipfile.py b/Lib/zipfile.py index d99c0d76977750..ebeb6100c1887d 100644 --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -787,19 +787,12 @@ class ZipExtFile(io.BufferedIOBase): def __init__(self, fileobj, mode, zipinfo, pwd=None, close_fileobj=False): + self._zinfo = zipinfo self._fileobj = fileobj self._pwd = pwd self._close_fileobj = close_fileobj self._compress_type = zipinfo.compress_type - self._compress_left = zipinfo.compress_size - self._left = zipinfo.file_size - - self._decompressor = _get_decompressor(self._compress_type) - - self._eof = False - self._readbuffer = b'' - self._offset = 0 self.newlines = None @@ -808,33 +801,37 @@ def __init__(self, fileobj, mode, zipinfo, pwd=None, if hasattr(zipinfo, 'CRC'): self._expected_crc = zipinfo.CRC - self._running_crc = crc32(b'') else: self._expected_crc = None self._seekable = False try: if fileobj.seekable(): - self._orig_compress_start = fileobj.tell() - self._orig_compress_size = zipinfo.compress_size self._orig_file_size = zipinfo.file_size - self._orig_start_crc = self._running_crc self._seekable = True except AttributeError: pass self._decrypter = None - if pwd: - if zipinfo.flag_bits & 0x8: - # compare against the file type from extended local headers - check_byte = (zipinfo._raw_time >> 8) & 0xff - else: - # compare against the CRC otherwise - check_byte = (zipinfo.CRC >> 24) & 0xff - h = self._init_decrypter() - if h != check_byte: - raise RuntimeError("Bad password for file %r" % zipinfo.orig_filename) + # Compress start is the byte after the 'local file header' i.e. the + # start of 'encryption header' section, if present, or + # 'file data' otherwise. 
+ self._compress_start = fileobj.tell() + self.read_init() + + def read_init(self): + """Set or reset this ZipExtFile to read from the start of the file""" + self._running_crc = crc32(b'') + self._compress_left = self._zinfo.compress_size + self._left = self._zinfo.file_size + self._eof = False + self._readbuffer = b'' + self._offset = 0 + + if self._pwd: + self._init_decrypter() + self._decompressor = _get_decompressor(self._compress_type) def _init_decrypter(self): self._decrypter = _ZipDecrypter(self._pwd) @@ -845,7 +842,17 @@ def _init_decrypter(self): # and is used to check the correctness of the password. header = self._fileobj.read(12) self._compress_left -= 12 - return self._decrypter(header)[11] + computed_check_byte = self._decrypter(header)[11] + + if self._zinfo.flag_bits & 0x8: + # compare against the file type from extended local headers + check_byte = (self._zinfo._raw_time >> 8) & 0xff + else: + # compare against the CRC otherwise + check_byte = (self._zinfo.CRC >> 24) & 0xff + + if computed_check_byte != check_byte: + raise RuntimeError("Bad password for file %r" % self._zinfo.orig_filename) def __repr__(self): result = ['<%s.%s' % (self.__class__.__module__, @@ -1072,17 +1079,10 @@ def seek(self, offset, whence=0): read_offset = 0 elif read_offset < 0: # Position is before the current position. 
Reset the ZipExtFile - self._fileobj.seek(self._orig_compress_start) - self._running_crc = self._orig_start_crc - self._compress_left = self._orig_compress_size - self._left = self._orig_file_size - self._readbuffer = b'' - self._offset = 0 - self._decompressor = _get_decompressor(self._compress_type) - self._eof = False - read_offset = new_pos - if self._decrypter is not None: - self._init_decrypter() + self._fileobj.seek(self._compress_start) + self.read_init() + # read_init() rewound the member to its start; read forward to new_pos. + read_offset = new_pos while read_offset > 0: read_len = min(self.MAX_SEEK_READ, read_offset) diff --git a/Misc/NEWS.d/next/Library/2021-05-14-04-37-51.bpo-44128.cxicgo.rst b/Misc/NEWS.d/next/Library/2021-05-14-04-37-51.bpo-44128.cxicgo.rst new file mode 100644 index 00000000000000..b6da75e9c0a198 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-05-14-04-37-51.bpo-44128.cxicgo.rst @@ -0,0 +1 @@ +Minor refactor of zipfile.ZipExtFile \ No newline at end of file