Skip to content

Commit f51c758

Browse files
emmatypingAA-Turner
andcommitted
Adopt many suggestions from AA-Turner for ZstdFile
* rename filename argument to file * improve __init__ mode and argument checking * docstring and error rewording * renamed self._closefp to self._close_fp * removed mode_code from __init__ * removed unneeded self._READER_CLASS Co-authored-by: Adam Turner <[email protected]>
1 parent 96d27b0 commit f51c758

File tree

2 files changed

+44
-54
lines changed

2 files changed

+44
-54
lines changed

Lib/compression/zstd/_zstdfile.py

Lines changed: 42 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class ZstdFile(_streams.BaseStream):
3131

3232
def __init__(
3333
self,
34-
filename,
34+
file,
35+
/,
3536
mode="r",
3637
*,
3738
level=None,
@@ -40,7 +41,7 @@ def __init__(
4041
):
4142
"""Open a zstd compressed file in binary mode.
4243
43-
filename can be either an actual file name (given as a str, bytes, or
44+
file can be either an actual file name (given as a str, bytes, or
4445
PathLike object), in which case the named file is opened, or it can be
4546
an existing file object to read from or write to.
4647
@@ -58,29 +59,23 @@ def __init__(
5859
See the function train_dict for how to train a ZstdDict on sample data.
5960
"""
6061
self._fp = None
61-
self._closefp = False
62+
self._close_fp = False
6263
self._mode = _MODE_CLOSED
6364

65+
if not isinstance(mode, str):
66+
raise ValueError("mode must be a str")
6467
# Read or write mode
65-
if mode in ("r", "rb"):
66-
if not isinstance(options, (type(None), dict)):
67-
raise TypeError(
68-
(
69-
"In read mode (decompression), options argument "
70-
"should be a dict object, that represents "
71-
"decompression options."
72-
)
73-
)
68+
if options is not None and not isinstance(options, dict):
69+
raise TypeError("options must be a dict or None")
70+
mode = mode.removesuffix("b") # handle rb, wb, xb, ab
71+
if mode == "r":
7472
if level is not None:
75-
raise TypeError("level argument should only be passed when "
76-
"writing.")
77-
mode_code = _MODE_READ
78-
elif mode in ("w", "wb", "a", "ab", "x", "xb"):
79-
if not isinstance(level, (type(None), int)):
80-
raise TypeError("level argument should be an int object.")
81-
if not isinstance(options, (type(None), dict)):
82-
raise TypeError("options argument should be an dict object.")
83-
mode_code = _MODE_WRITE
73+
raise TypeError("level is illegal in read mode")
74+
self._mode = _MODE_READ
75+
elif mode in {"w", "a", "x"}:
76+
if level is not None and not isinstance(level, int):
77+
raise TypeError("level must be int or None")
78+
self._mode = _MODE_WRITE
8479
self._compressor = ZstdCompressor(
8580
level=level, options=options, zstd_dict=zstd_dict
8681
)
@@ -89,17 +84,15 @@ def __init__(
8984
raise ValueError(f"Invalid mode: {mode!r}")
9085

9186
# File object
92-
if isinstance(filename, (str, bytes, PathLike)):
93-
if "b" not in mode:
94-
mode += "b"
95-
self._fp = io.open(filename, mode)
96-
self._closefp = True
97-
elif hasattr(filename, "read") or hasattr(filename, "write"):
98-
self._fp = filename
87+
if isinstance(file, (str, bytes, PathLike)):
88+
self._fp = io.open(file, f'{mode}b')
89+
self._close_fp = True
90+
elif ((mode == 'r' and hasattr(file, "read"))
91+
or (mode != 'r' and hasattr(file, "write"))):
92+
self._fp = file
9993
else:
100-
raise TypeError("filename must be a str, bytes, file or PathLike "
101-
"object")
102-
self._mode = mode_code
94+
raise TypeError("file must be a file-like object "
95+
"or a str, bytes, or PathLike object")
10396

10497
if self._mode == _MODE_READ:
10598
raw = _streams.DecompressReader(
@@ -114,15 +107,14 @@ def __init__(
114107
def close(self):
115108
"""Flush and close the file.
116109
117-
May be called more than once without error. Once the file is
118-
closed, any other operation on it will raise a ValueError.
110+
May be called multiple times. Once the file has been closed,
111+
any other operation on it will raise ValueError.
119112
"""
120-
# Nop if already closed
121113
if self._fp is None:
122114
return
123115
try:
124116
if self._mode == _MODE_READ:
125-
if hasattr(self, "_buffer") and self._buffer:
117+
if getattr(self, '_buffer', None):
126118
self._buffer.close()
127119
self._buffer = None
128120
elif self._mode == _MODE_WRITE:
@@ -131,11 +123,11 @@ def close(self):
131123
finally:
132124
self._mode = _MODE_CLOSED
133125
try:
134-
if self._closefp:
126+
if self._close_fp:
135127
self._fp.close()
136128
finally:
137129
self._fp = None
138-
self._closefp = False
130+
self._close_fp = False
139131

140132
def write(self, data):
141133
"""Write a bytes-like object *data* to the file.
@@ -161,9 +153,8 @@ def write(self, data):
161153
def flush(self, mode=FLUSH_BLOCK):
162154
"""Flush remaining data to the underlying stream.
163155
164-
The mode argument can be ZstdFile.FLUSH_BLOCK or ZstdFile.FLUSH_FRAME.
165-
Abuse of this method will reduce compression ratio, use it only when
166-
necessary.
156+
The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this
157+
method will reduce compression ratio, use it only when necessary.
167158
168159
If the program is interrupted afterwards, all data can be recovered.
169160
To ensure saving to disk, also need to use os.fsync(fd).
@@ -173,10 +164,10 @@ def flush(self, mode=FLUSH_BLOCK):
173164
if self._mode == _MODE_READ:
174165
return
175166
self._check_not_closed()
176-
if mode not in (self.FLUSH_BLOCK, self.FLUSH_FRAME):
177-
raise ValueError("mode argument wrong value, it should be "
178-
"ZstdCompressor.FLUSH_FRAME or "
179-
"ZstdCompressor.FLUSH_BLOCK.")
167+
if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
168+
raise ValueError("Invalid mode argument, expected either "
169+
"ZstdFile.FLUSH_FRAME or "
170+
"ZstdFile.FLUSH_BLOCK")
180171
if self._compressor.last_mode == mode:
181172
return
182173
# Flush zstd block/frame, and write.
@@ -270,8 +261,7 @@ def peek(self, size=-1):
270261
return self._buffer.peek(size)
271262

272263
def __next__(self):
273-
ret = self._buffer.readline()
274-
if ret:
264+
if ret := self._buffer.readline():
275265
return ret
276266
raise StopIteration
277267

@@ -319,7 +309,8 @@ def writable(self):
319309

320310
# Copied from lzma module
321311
def open(
322-
filename,
312+
file,
313+
/,
323314
mode="rb",
324315
*,
325316
level=None,
@@ -331,9 +322,9 @@ def open(
331322
):
332323
"""Open a zstd compressed file in binary or text mode.
333324
334-
filename can be either an actual file name (given as a str, bytes, or
335-
PathLike object), in which case the named file is opened, or it can be an
336-
existing file object to read from or write to.
325+
file can be either a file name (given as a str, bytes, or PathLike object),
326+
in which case the named file is opened, or it can be an existing file object
327+
to read from or write to.
337328
338329
The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a",
339330
"ab" for binary mode, or "rt", "wt", "xt", "at" for text mode.
@@ -370,7 +361,7 @@ def open(
370361

371362
zstd_mode = mode.replace("t", "")
372363
binary_file = ZstdFile(
373-
filename, zstd_mode, level=level, options=options, zstd_dict=zstd_dict
364+
file, zstd_mode, level=level, options=options, zstd_dict=zstd_dict
374365
)
375366

376367
if "t" in mode:

Lib/test/test_zstd/test_core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,10 +2121,9 @@ class T:
21212121
def read(self, size):
21222122
return b'a' * size
21232123

2124-
with self.assertRaises(AttributeError): # on close
2124+
with self.assertRaises(TypeError): # on creation
21252125
with ZstdFile(T(), 'w') as f:
2126-
with self.assertRaises(AttributeError): # on write
2127-
f.write(b'1234')
2126+
pass
21282127

21292128
# 3
21302129
with ZstdFile(io.BytesIO(), 'w') as f:

0 commit comments

Comments
 (0)