Skip to content

Commit ee45268

Browse files
committed
gh-61206: support zstd in zipimport
1 parent e05182f commit ee45268

File tree

2 files changed

+108
-25
lines changed

2 files changed

+108
-25
lines changed

Lib/test/test_zipimport.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from test.support import import_helper
1616
from test.support import os_helper
1717

18-
from zipfile import ZipFile, ZipInfo, ZIP_STORED, ZIP_DEFLATED
18+
from zipfile import ZipFile, ZipInfo, ZIP_STORED, ZIP_DEFLATED, ZIP_ZSTANDARD
1919

2020
import zipimport
2121
import linecache
@@ -194,19 +194,38 @@ def testAFakeZlib(self):
194194
# occur in that case (builtin modules are always found first),
195195
# so we'll simply skip it then. Bug #765456.
196196
#
197-
if "zlib" in sys.builtin_module_names:
198-
self.skipTest('zlib is a builtin module')
199-
if "zlib" in sys.modules:
200-
del sys.modules["zlib"]
201-
files = {"zlib.py": test_src}
197+
if self.compression == ZIP_DEFLATED:
198+
mod_name = "zlib"
199+
if zipimport._zlib_decompress: # validate attr name
200+
# reset the cached import to avoid test order dependencies
201+
zipimport._zlib_decompress = None # reset cache
202+
elif self.compression == ZIP_ZSTANDARD:
203+
mod_name = "_zstd"
204+
if zipimport._zstd_decompressor_class: # validate attr name
205+
# reset the cached import to avoid test order dependencies
206+
zipimport._zstd_decompressor_class = None
207+
else:
208+
mod_name = "zlib" # the ZIP_STORED case below
209+
210+
if mod_name in sys.builtin_module_names:
211+
self.skipTest(f"{mod_name} is a builtin module")
212+
if mod_name in sys.modules:
213+
del sys.modules[mod_name]
214+
files = {f"{mod_name}.py": test_src}
202215
try:
203-
self.doTest(".py", files, "zlib")
216+
self.doTest(".py", files, mod_name)
204217
except ImportError:
205-
if self.compression != ZIP_DEFLATED:
206-
self.fail("expected test to not raise ImportError")
207-
else:
208218
if self.compression != ZIP_STORED:
209-
self.fail("expected test to raise ImportError")
219+
# Expected - fake compression module can't decompress
220+
pass
221+
else:
222+
self.fail("expected test to not raise ImportError for uncompressed")
223+
else:
224+
if self.compression == ZIP_STORED:
225+
# Expected - no compression needed, so fake module works
226+
pass
227+
else:
228+
self.fail("expected test to raise ImportError for compressed zip with fake compression module")
210229

211230
def testPy(self):
212231
files = {TESTMOD + ".py": test_src}
@@ -1008,10 +1027,15 @@ def assertDataEntry(name):
10081027

10091028

10101029
@support.requires_zlib()
1011-
class CompressedZipImportTestCase(UncompressedZipImportTestCase):
1030+
class DeflateCompressedZipImportTestCase(UncompressedZipImportTestCase):
10121031
compression = ZIP_DEFLATED
10131032

10141033

1034+
@support.requires_zstd()
1035+
class ZStdCompressedZipImportTestCase(UncompressedZipImportTestCase):
1036+
compression = ZIP_ZSTANDARD
1037+
1038+
10151039
class BadFileZipImportTestCase(unittest.TestCase):
10161040
def assertZipFailure(self, filename):
10171041
self.assertRaises(zipimport.ZipImportError,

Lib/zipimport.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -603,11 +603,16 @@ def _read_directory(archive):
603603
)
604604

605605
_importing_zlib = False
606+
_zlib_decompress = None
606607

607608
# Return the zlib.decompress function object; raise ZipImportError if zlib
608609
# can't be imported. The function is cached when found, so subsequent calls
609610
# don't import zlib again.
610-
def _get_decompress_func():
611+
def _get_zlib_decompress_func():
612+
global _zlib_decompress
613+
if _zlib_decompress:
614+
return _zlib_decompress
615+
611616
global _importing_zlib
612617
if _importing_zlib:
613618
# Someone has a zlib.py[co] in their Zip file
@@ -617,15 +622,62 @@ def _get_decompress_func():
617622

618623
_importing_zlib = True
619624
try:
620-
from zlib import decompress
625+
from zlib import decompress as _zlib_decompress
621626
except Exception:
622627
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
623628
raise ZipImportError("can't decompress data; zlib not available")
624629
finally:
625630
_importing_zlib = False
626631

627632
_bootstrap._verbose_message('zipimport: zlib available')
628-
return decompress
633+
return _zlib_decompress
634+
635+
636+
_importing_zstd = False
637+
_zstd_decompressor_class = None
638+
639+
# Return the _zstd.ZstdDecompressor class, importing it on first use and
# caching it in _zstd_decompressor_class for later calls.  Raises
# ZipImportError when _zstd cannot be imported, or when this function is
# re-entered while that import is still running.
def _get_zstd_decompressor_class():
    global _zstd_decompressor_class, _importing_zstd

    cached = _zstd_decompressor_class
    if cached:
        return cached

    if _importing_zstd:
        # A _zstd.py[co] inside the Zip file would send us back here while
        # the import below is still in progress; bail out instead of
        # recursing until the stack overflows.
        _bootstrap._verbose_message("zipimport: zstd UNAVAILABLE")
        raise ZipImportError("can't decompress data; zstd not available")

    _importing_zstd = True
    try:
        from _zstd import ZstdDecompressor as _zstd_decompressor_class
    except Exception:
        _bootstrap._verbose_message("zipimport: zstd UNAVAILABLE")
        raise ZipImportError("can't decompress data; zstd not available")
    finally:
        # Always clear the re-entrancy flag, whether or not the import worked.
        _importing_zstd = False

    _bootstrap._verbose_message("zipimport: zstd available")
    return _zstd_decompressor_class
664+
665+
666+
def _zstd_decompress(data):
    # Minimal stand-in for compression.zstd.decompress().  That module cannot
    # be imported here because the stdlib itself may be loaded via zipimport.
    #
    # Each pass uses a fresh decompressor object; whatever it reports as
    # unused_data is fed to the next pass until nothing is left.
    chunks = []
    remaining = data
    while True:
        decompressor_class = _get_zstd_decompressor_class()
        decompressor = decompressor_class()
        chunks.append(decompressor.decompress(remaining))
        if not decompressor.eof:
            # The input ran out before the decompressor saw the end marker.
            raise ZipImportError("zipimport: zstd compressed data ended before "
                                 "the end-of-stream marker")
        remaining = decompressor.unused_data
        if not remaining:
            return b"".join(chunks)
680+
629681

630682
# Given a path to a Zip file and a toc_entry, return the (uncompressed) data.
631683
def _get_data(archive, toc_entry):
@@ -659,16 +711,23 @@ def _get_data(archive, toc_entry):
659711
if len(raw_data) != data_size:
660712
raise OSError("zipimport: can't read data")
661713

662-
if compress == 0:
663-
# data is not compressed
664-
return raw_data
665-
666-
# Decompress with zlib
667-
try:
668-
decompress = _get_decompress_func()
669-
except Exception:
670-
raise ZipImportError("can't decompress data; zlib not available")
671-
return decompress(raw_data, -15)
714+
match compress:
715+
case 0: # stored
716+
return raw_data
717+
case 8: # deflate aka zlib
718+
try:
719+
decompress = _get_zlib_decompress_func()
720+
except Exception:
721+
raise ZipImportError("can't decompress data; zlib not available")
722+
return decompress(raw_data, -15)
723+
case 93: # zstd
724+
try:
725+
return _zstd_decompress(raw_data)
726+
except Exception:
727+
raise ZipImportError("could not decompress zstd data")
728+
# bz2 and lzma could be added, but are largely obsolete.
729+
case _:
730+
raise ZipImportError(f"zipimport: unsupported compression {compress}")
672731

673732

674733
# Lenient date/time comparison function. The precision of the mtime

0 commit comments

Comments
 (0)