-
-
Notifications
You must be signed in to change notification settings - Fork 19.1k
ENH: add support for reading .tar archives #44787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
c1823ef
9a85cba
e673061
a0d6386
6a8edef
d4e40c9
5f22df7
c6573ef
f3b6ed5
a4ac382
e66826b
941be37
e3369aa
0468e5f
2531ee0
57eba0a
38f7d54
887fd10
fc2e7f0
669d942
7d7d3c6
8b8b8ac
dd356f6
514014a
38971c7
e35d361
c5088fc
f6c5173
9a4fa07
0c31aa8
086c598
861faf0
9458ecb
d20f315
1066f1b
6b0e1e6
0d9ed18
37370c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2341,6 +2341,7 @@ def to_json( | |
default_handler: Callable[[Any], JSONSerializable] | None = None, | ||
lines: bool_t = False, | ||
compression: CompressionOptions = "infer", | ||
mode: str = "w", | ||
|
||
index: bool_t = True, | ||
indent: int | None = None, | ||
storage_options: StorageOptions = None, | ||
|
@@ -2604,6 +2605,7 @@ def to_json( | |
default_handler=default_handler, | ||
lines=lines, | ||
compression=compression, | ||
mode=mode, | ||
index=index, | ||
indent=indent, | ||
storage_options=storage_options, | ||
|
@@ -2923,6 +2925,7 @@ def to_pickle( | |
self, | ||
path, | ||
compression: CompressionOptions = "infer", | ||
mode: str = "wb", | ||
protocol: int = pickle.HIGHEST_PROTOCOL, | ||
storage_options: StorageOptions = None, | ||
) -> None: | ||
|
@@ -2990,6 +2993,7 @@ def to_pickle( | |
self, | ||
path, | ||
compression=compression, | ||
mode=mode, | ||
protocol=protocol, | ||
storage_options=storage_options, | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
from io import ( | ||
BufferedIOBase, | ||
BytesIO, | ||
FileIO, | ||
RawIOBase, | ||
StringIO, | ||
TextIOBase, | ||
|
@@ -17,6 +18,7 @@ | |
import mmap | ||
import os | ||
from pathlib import Path | ||
import tarfile | ||
from typing import ( | ||
IO, | ||
Any, | ||
|
@@ -260,7 +262,7 @@ def _get_filepath_or_buffer( | |
---------- | ||
filepath_or_buffer : a url, filepath (str, py.path.local or pathlib.Path), | ||
or buffer | ||
compression : {{'gzip', 'bz2', 'zip', 'xz', None}}, optional | ||
compression : {{'gzip', 'bz2', 'zip', 'xz', 'tar', None}}, optional | ||
encoding : the encoding to use to decode bytes, default is 'utf-8' | ||
mode : str, optional | ||
|
||
|
@@ -443,7 +445,17 @@ def file_path_to_url(path: str) -> str: | |
return urljoin("file:", pathname2url(path)) | ||
|
||
|
||
_compression_to_extension = {"gzip": ".gz", "bz2": ".bz2", "zip": ".zip", "xz": ".xz"} | ||
_extension_to_compression = { | ||
".tar": "tar", | ||
".tar.gz": "tar", | ||
".tar.bz2": "tar", | ||
".tar.xz": "tar", | ||
".gz": "gzip", | ||
".bz2": "bz2", | ||
".zip": "zip", | ||
".xz": "xz", | ||
} | ||
_supported_compressions = set(_extension_to_compression.values()) | ||
|
||
|
||
def get_compression_method( | ||
|
@@ -494,9 +506,9 @@ def infer_compression( | |
---------- | ||
filepath_or_buffer : str or file handle | ||
File path or object. | ||
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None} | ||
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', 'tar', None} | ||
If 'infer' and `filepath_or_buffer` is path-like, then detect | ||
compression from the following extensions: '.gz', '.bz2', '.zip', | ||
compression from the following extensions: '.gz', '.bz2', '.zip', '.tar', | ||
or '.xz' (otherwise no compression). | ||
|
||
Returns | ||
|
@@ -519,20 +531,18 @@ def infer_compression( | |
return None | ||
|
||
# Infer compression from the filename/URL extension | ||
for compression, extension in _compression_to_extension.items(): | ||
for extension, compression in _extension_to_compression.items(): | ||
if filepath_or_buffer.lower().endswith(extension): | ||
return compression | ||
return None | ||
|
||
# Compression has been specified. Check that it's valid | ||
if compression in _compression_to_extension: | ||
if compression in _supported_compressions: | ||
return compression | ||
|
||
# https://github.com/python/mypy/issues/5492 | ||
# Unsupported operand types for + ("List[Optional[str]]" and "List[str]") | ||
valid = ["infer", None] + sorted( | ||
_compression_to_extension | ||
) # type: ignore[operator] | ||
valid = ["infer", None] + sorted(_supported_compressions) # type: ignore[operator] | ||
msg = ( | ||
f"Unrecognized compression type: {compression}\n" | ||
f"Valid compression types are {valid}" | ||
|
@@ -677,7 +687,7 @@ def get_handle( | |
ioargs.encoding, | ||
ioargs.mode, | ||
errors, | ||
ioargs.compression["method"] not in _compression_to_extension, | ||
ioargs.compression["method"] not in _supported_compressions, | ||
) | ||
|
||
is_path = isinstance(handle, str) | ||
|
@@ -745,6 +755,25 @@ def get_handle( | |
f"Only one file per ZIP: {zip_names}" | ||
) | ||
|
||
# TAR Encoding | ||
elif compression == "tar": | ||
if is_path: | ||
handle = _BytesTarFile.open(name=handle, mode=ioargs.mode) | ||
else: | ||
handle = _BytesTarFile.open(fileobj=handle, mode=ioargs.mode) | ||
if handle.mode == "r": | ||
handles.append(handle) | ||
files = handle.getnames() | ||
if len(files) == 1: | ||
handle = handle.extractfile(files[0]) | ||
elif len(files) == 0: | ||
raise ValueError(f"Zero files found in TAR archive {path_or_buf}") | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
raise ValueError( | ||
"Multiple files found in TAR archive. " | ||
f"Only one file per TAR archive: {files}" | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
# XZ Compression | ||
elif compression == "xz": | ||
handle = get_lzma_file()(handle, ioargs.mode) | ||
|
@@ -823,6 +852,96 @@ def get_handle( | |
) | ||
|
||
|
||
class _BytesTarFile(tarfile.TarFile, BytesIO):
    """
    ``TarFile`` wrapper that can be used like a writable file handle.

    ``TarFile`` has no ``write`` method: members are added with ``addfile``,
    which needs the full payload and its size up front. This wrapper buffers
    every ``write`` call in memory and adds the buffered payload to the
    archive as a single member on ``flush``/``close``.

    Parameters
    ----------
    name : str, path object or buffer
        Path of the archive. Also used to infer the member name.
    mode : str
        One of the modes accepted by ``tarfile.TarFile`` ("r", "a", "w", "x").
    fileobj : file object
        Already opened file object to read from / write to. ``TarFile`` opens
        ``name`` itself when this is None.
    archive_name : str, optional
        Name of the member that the written data is stored under. When
        omitted, the name is inferred from ``name`` (see ``infer_filename``).
    **kwargs
        Forwarded to ``tarfile.TarFile``.
    """

    # GH 17778
    def __init__(
        self,
        name: "FilePath | ReadBuffer[bytes] | WriteBuffer[bytes]",
        mode: str,
        fileobj: FileIO,
        archive_name: "str | None" = None,
        **kwargs,
    ):
        self.archive_name = archive_name
        self.multiple_write_buffer: StringIO | BytesIO | None = None
        # TarFile.__init__ assigns ``self.closed``, which goes through the
        # property below, so the backing flag has to exist before calling it.
        self._is_closed = False

        super().__init__(name=name, mode=mode, fileobj=fileobj, **kwargs)

    @classmethod
    def open(cls, name=None, mode="r", **kwargs):
        """
        Open a tar archive, accepting file-style modes such as "rb" / "wb".

        The "b" flag is dropped because ``TarFile`` modes do not use it, and a
        plain "w" is extended with the compression implied by the extension
        of ``name`` (see ``extend_mode``).
        """
        mode = mode.replace("b", "")
        return super().open(name=name, mode=cls.extend_mode(name, mode), **kwargs)

    @classmethod
    def extend_mode(
        cls, name: "FilePath | ReadBuffer[bytes] | WriteBuffer[bytes]", mode: str
    ) -> str:
        """
        Append the tarfile compression flag matching the extension of ``name``.

        Only the plain write mode "w" is extended (to "w:gz", "w:xz" or
        "w:bz2"); every other mode, and any ``name`` that is not a path, is
        returned unchanged.
        """
        if mode != "w" or not isinstance(name, (os.PathLike, str)):
            return mode
        compression_by_suffix = {".gz": "gz", ".xz": "xz", ".bz2": "bz2"}
        compression = compression_by_suffix.get(Path(name).suffix)
        if compression is None:
            return mode
        return f"{mode}:{compression}"

    def infer_filename(self):
        """
        If an explicit archive_name is not given, we still want the file inside
        the tar archive not to be named something.tar, because that causes
        confusion (GH39465).

        Returns
        -------
        str or None
            File name of the archive with a trailing ".tar", ".tar.gz",
            ".tar.bz2" or ".tar.xz" removed, or None when the archive was not
            opened from a path.
        """
        if not isinstance(self.name, (os.PathLike, str)):
            return None

        filename = Path(self.name)
        if filename.suffix == ".tar":
            return filename.with_suffix("").name
        # ``Path.suffix`` is only the last extension (".gz" for "x.tar.gz"),
        # so the compressed case has to look at the last two extensions.
        suffixes = filename.suffixes
        if (
            filename.suffix in (".gz", ".bz2", ".xz")
            and len(suffixes) >= 2
            and suffixes[-2] == ".tar"
        ):
            return filename.with_suffix("").with_suffix("").name
        return filename.name

    def write(self, data):
        """
        Buffer ``data``; nothing reaches the archive before ``flush``.

        The buffer type (text or binary) is fixed by the first call. Returns
        whatever the underlying in-memory buffer reports as written.
        """
        # buffer multiple write calls, write on flush
        if self.multiple_write_buffer is None:
            self.multiple_write_buffer = (
                StringIO() if isinstance(data, str) else BytesIO()
            )
        return self.multiple_write_buffer.write(data)

    def flush(self) -> None:
        """
        Add the buffered payload to the archive as one member.

        The write buffer is closed afterwards, so only the first flush after
        writing has an effect and later ``write`` calls fail on the closed
        buffer.
        """
        # write to actual handle and close write buffer
        buffer = self.multiple_write_buffer
        if buffer is None or buffer.closed:
            return

        # TarFile needs a non-empty string
        archive_name = self.archive_name or self.infer_filename() or "tar"
        with buffer:
            value = buffer.getvalue()
        if isinstance(value, str):
            # addfile copies bytes and TarInfo.size is a byte count, so text
            # written through a StringIO buffer is encoded first.
            value = value.encode("utf-8")
        tarinfo = tarfile.TarInfo(name=archive_name)
        tarinfo.size = len(value)
        self.addfile(tarinfo, BytesIO(value))

    def close(self):
        """Flush the buffered payload into the archive, then close it."""
        self.flush()
        super().close()

    @property
    def closed(self):
        """Whether the archive has been closed."""
        # getattr keeps the property usable on a not yet initialised instance.
        return getattr(self, "_is_closed", False)

    @closed.setter
    def closed(self, value):
        # TarFile.__init__ and TarFile.close assign ``self.closed``. Because
        # ``closed`` is a property here, the value is recorded in a private
        # flag; the setter must not call close() itself, since TarFile.close
        # is still running when it assigns True.
        self._is_closed = bool(value)
|
||
|
||
# error: Definition of "__exit__" in base class "ZipFile" is incompatible with | ||
# definition in base class "BytesIO" [misc] | ||
# error: Definition of "__enter__" in base class "ZipFile" is incompatible with | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this entire function could be replaced with a call to `get_handle` (not needed in this PR).