Skip to content

Commit 18f54e7

Browse files
authored
Add constant enumerating supported formats (#1)
1 parent 69e70dd commit 18f54e7

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tests/test_xtarfile.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from tempfile import mkstemp
77
from unittest import TestCase
88

9+
from xtarfile.xtarfile import SUPPORTED_FORMATS
910
from xtarfile.xtarfile import get_compression
1011
from xtarfile.xtarfile import xtarfile_open
11-
from xtarfile.xtarfile import HANDLERS
1212

1313

1414
class FileExtensionIdContext:
@@ -54,11 +54,9 @@ def test_falls_back_to_extension(self):
5454

5555
class OpenTests(TestCase):
5656
def test_roundtrip(self):
57-
plugins = [key for (key, value) in HANDLERS.items() if value]
58-
compressors = ['gz', 'bz2', 'xz'] + plugins
5957
contexts = (ExplicitOpenIdContext, FileExtensionIdContext)
6058

61-
for compressor, ctx in product(compressors, contexts):
59+
for compressor, ctx in product(SUPPORTED_FORMATS, contexts):
6260
context = ctx(self, compressor)
6361
with self.subTest(compressor=compressor, context=str(context)):
6462
self._test_roundtrip(context)

xtarfile/xtarfile.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
from itertools import chain
12
from tarfile import open as tarfile_open
23

34
from xtarfile.zstd import ZstandardTarfile
45

56

6-
HANDLERS = {
7+
_HANDLERS = {
78
'zstd': ZstandardTarfile
89
}
910

11+
_NATIVE_FORMATS = ('gz', 'bz2', 'xz')
12+
13+
SUPPORTED_FORMATS = frozenset(chain(_HANDLERS.keys(), _NATIVE_FORMATS))
14+
1015

1116
def get_compression(path: str, mode: str) -> str:
1217
for delim in (':', '|'):
@@ -24,10 +29,10 @@ def get_compression(path: str, mode: str) -> str:
2429
def xtarfile_open(path: str, mode: str, **kwargs):
2530
compression = get_compression(path, mode)
2631

27-
if not compression or compression in ('gz', 'bz2', 'xz'):
32+
if not compression or compression in _NATIVE_FORMATS:
2833
return tarfile_open(path, mode, **kwargs)
2934

30-
handler_class = HANDLERS.get(compression)
35+
handler_class = _HANDLERS.get(compression)
3136
if handler_class is not None:
3237
handler = handler_class(**kwargs)
3338
if mode.startswith('r'):

0 commit comments

Comments
 (0)