Skip to content

Commit f40f008

Browse files
committed
Respond to review comments
1 parent 6f58a95 commit f40f008

File tree

5 files changed

+125
-70
lines changed

5 files changed

+125
-70
lines changed

Lib/compression/zstd/_zstdfile.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import io
22
from os import PathLike
3-
from _zstd import (ZstdCompressor, ZstdDecompressor, ZstdError,
4-
ZSTD_DStreamOutSize)
3+
from _zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize
54
from compression._common import _streams
65

76
__all__ = ('ZstdFile', 'open')

Lib/test/test_zstd.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,16 @@ def test_simple_compress_bad_args(self):
197197
self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})
198198

199199
# valid compression level range is [-(1<<17), 22]
200+
with self.assertRaises(ValueError):
201+
ZstdCompressor(23)
202+
with self.assertRaises(ValueError):
203+
ZstdCompressor(-(1<<17)-1)
200204
with self.assertRaises(ValueError):
201205
ZstdCompressor(2**31)
202206
with self.assertRaises(ValueError):
203-
ZstdCompressor(level=-(2**31))
207+
ZstdCompressor(level=-(2**1000))
204208
with self.assertRaises(ValueError):
205-
ZstdCompressor(options={2**31: 100})
209+
ZstdCompressor(level=(2**1000))
206210

207211
with self.assertRaises(ZstdError):
208212
ZstdCompressor(options={CompressionParameter.window_log: 100})
@@ -262,15 +266,22 @@ def test_compress_parameters(self):
262266
d1[CompressionParameter.ldm_bucket_size_log] = 2**31
263267
self.assertRaises(ValueError, ZstdCompressor, options=d1)
264268

265-
# clamp compressionLevel
269+
# out of bounds compression level
266270
level_min, level_max = CompressionParameter.compression_level.bounds()
267271
with self.assertRaises(ValueError):
268272
compress(b'', level_max+1)
269273
with self.assertRaises(ValueError):
270274
compress(b'', level_min-1)
271-
272-
compress(b'', options={CompressionParameter.compression_level:level_max+1})
273-
compress(b'', options={CompressionParameter.compression_level:level_min-1})
275+
with self.assertRaises(ValueError):
276+
compress(b'', 2**1000)
277+
with self.assertRaises(ValueError):
278+
compress(b'', -(2**1000))
279+
with self.assertRaises(ValueError):
280+
compress(b'', options={
281+
CompressionParameter.compression_level: level_max+1})
282+
with self.assertRaises(ValueError):
283+
compress(b'', options={
284+
CompressionParameter.compression_level: level_min-1})
274285

275286
# zstd lib doesn't support MT compression
276287
if not SUPPORT_MULTITHREADING:
@@ -390,12 +401,22 @@ def test_simple_decompress_bad_args(self):
390401
self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')
391402

392403
with self.assertRaises(ValueError):
393-
ZstdDecompressor(options={2**31 : 100})
404+
ZstdDecompressor(options={2**31: 100})
405+
with self.assertRaises(ValueError):
406+
ZstdDecompressor(options={2**1000: 100})
407+
with self.assertRaises(ValueError):
408+
ZstdDecompressor(options={-(2**31)-1: 100})
409+
with self.assertRaises(ValueError):
410+
ZstdDecompressor(options={-(2**1000): 100})
411+
with self.assertRaises(ValueError):
412+
ZstdDecompressor(options={0: 2**32})
413+
with self.assertRaises(ValueError):
414+
ZstdDecompressor(options={0: -(2**1000)})
394415

395416
with self.assertRaises(ZstdError):
396-
ZstdDecompressor(options={DecompressionParameter.window_log_max:100})
417+
ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
397418
with self.assertRaises(ZstdError):
398-
ZstdDecompressor(options={3333 : 100})
419+
ZstdDecompressor(options={3333: 100})
399420

400421
empty = compress(b'')
401422
lzd = ZstdDecompressor()
@@ -421,6 +442,20 @@ def test_decompress_parameters(self):
421442
r'\((?:32|64)-bit build\)')):
422443
decompress(b'', options=options)
423444

445+
# out of bounds deecompression parameter
446+
options[DecompressionParameter.window_log_max] = 2**31
447+
with self.assertRaises(ValueError):
448+
decompress(b'', options=options)
449+
options[DecompressionParameter.window_log_max] = -(2**32)-1
450+
with self.assertRaises(ValueError):
451+
decompress(b'', options=options)
452+
options[DecompressionParameter.window_log_max] = 2**1000
453+
with self.assertRaises(ValueError):
454+
decompress(b'', options=options)
455+
options[DecompressionParameter.window_log_max] = -(2**1000)
456+
with self.assertRaises(ValueError):
457+
decompress(b'', options=options)
458+
424459
def test_unknown_decompression_parameter(self):
425460
KEY = 100001234
426461
options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
@@ -1430,11 +1465,11 @@ def test_init_bad_mode(self):
14301465
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")
14311466

14321467
with self.assertRaisesRegex(TypeError,
1433-
r"NOT be a CompressionParameter"):
1468+
r"not be a CompressionParameter"):
14341469
ZstdFile(io.BytesIO(), 'rb',
14351470
options={CompressionParameter.compression_level:5})
14361471
with self.assertRaisesRegex(TypeError,
1437-
r"NOT be a DecompressionParameter"):
1472+
r"not be a DecompressionParameter"):
14381473
ZstdFile(io.BytesIO(), 'wb',
14391474
options={DecompressionParameter.window_log_max:21})
14401475

@@ -1473,7 +1508,7 @@ def test_init_close_fp(self):
14731508
tmp_f.write(DAT_130K_C)
14741509
filename = tmp_f.name
14751510

1476-
with self.assertRaises(ValueError):
1511+
with self.assertRaises(TypeError):
14771512
ZstdFile(filename, options={'a':'b'})
14781513

14791514
# for PyPy

Modules/_zstd/clinic/compressor.c.h

Lines changed: 4 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/_zstd/compressor.c

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,32 @@ typedef struct {
4949
#include "clinic/compressor.c.h"
5050

5151
static int
52-
_zstd_set_c_level(ZstdCompressor *self, const Py_ssize_t level)
52+
_zstd_set_c_level(ZstdCompressor *self, const int level)
5353
{
5454
/* Set integer compression level */
55-
const int min_level = ZSTD_minCLevel();
56-
const int max_level = ZSTD_maxCLevel();
55+
int min_level = ZSTD_minCLevel();
56+
int max_level = ZSTD_maxCLevel();
5757
if (level < min_level || level > max_level) {
5858
PyErr_Format(PyExc_ValueError,
59-
"compression level %zd not in valid range %d <= level <= %d.",
59+
"%zd not in valid range %d <= compression level <= %d.",
6060
level, min_level, max_level);
6161
return -1;
6262
}
6363

6464
/* Save for generating ZSTD_CDICT */
65-
self->compression_level = (int)level;
65+
self->compression_level = level;
6666

6767
/* Set compressionLevel to compression context */
68-
const size_t zstd_ret = ZSTD_CCtx_setParameter(
69-
self->cctx, ZSTD_c_compressionLevel, (int)level);
68+
size_t zstd_ret = ZSTD_CCtx_setParameter(
69+
self->cctx, ZSTD_c_compressionLevel, level);
7070

7171
/* Check error */
7272
if (ZSTD_isError(zstd_ret)) {
73-
const _zstd_state* const st = PyType_GetModuleState(Py_TYPE(self));
74-
if (st == NULL) {
73+
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
74+
if (mod_state == NULL) {
7575
return -1;
7676
}
77-
set_zstd_error(st, ERR_SET_C_LEVEL, zstd_ret);
77+
set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
7878
return -1;
7979
}
8080
return 0;
@@ -83,14 +83,14 @@ _zstd_set_c_level(ZstdCompressor *self, const Py_ssize_t level)
8383
static int
8484
_zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)
8585
{
86-
/* Set options dict */
87-
_zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
86+
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
8887
if (mod_state == NULL) {
8988
return -1;
9089
}
9190

9291
if (!PyDict_Check(options)) {
93-
PyErr_Format(PyExc_TypeError, "invalid type for options, expected dict");
92+
PyErr_Format(PyExc_TypeError,
93+
"invalid type for options, expected dict");
9494
return -1;
9595
}
9696

@@ -100,32 +100,38 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)
100100
/* Check key type */
101101
if (Py_TYPE(key) == mod_state->DParameter_type) {
102102
PyErr_SetString(PyExc_TypeError,
103-
"key should NOT be DecompressionParameter.");
103+
"compression options dictionary key must not be a "
104+
"DecompressionParameter attribute");
104105
return -1;
105106
}
106107

107-
const int key_v = PyLong_AsInt(key);
108+
Py_INCREF(key);
109+
int key_v = PyLong_AsInt(key);
108110
if (key_v == -1 && PyErr_Occurred()) {
109-
PyErr_SetString(PyExc_ValueError,
110-
"key should be either a "
111-
"CompressionParameter attribute or an int.");
111+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
112+
PyErr_SetString(PyExc_ValueError,
113+
"dictionary key must be less than 2**31");
114+
}
112115
return -1;
113116
}
114117

115-
// TODO(emmatyping): check bounds when there is a value error here for better
116-
// error message?
118+
Py_INCREF(value);
117119
int value_v = PyLong_AsInt(value);
118120
if (value_v == -1 && PyErr_Occurred()) {
119-
PyErr_SetString(PyExc_ValueError,
120-
"options dict value should be an int.");
121+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
122+
PyErr_SetString(PyExc_ValueError,
123+
"dictionary value must be less than 2**31");
124+
}
121125
return -1;
122126
}
123127

124128
if (key_v == ZSTD_c_compressionLevel) {
125-
/* Save for generating ZSTD_CDICT */
126-
self->compression_level = value_v;
129+
if (_zstd_set_c_level(self, value_v) < 0) {
130+
return -1;
131+
}
132+
continue;
127133
}
128-
else if (key_v == ZSTD_c_nbWorkers) {
134+
if (key_v == ZSTD_c_nbWorkers) {
129135
/* From the zstd library docs:
130136
1. When nbWorkers >= 1, triggers asynchronous mode when
131137
used with ZSTD_compressStream2().
@@ -138,7 +144,7 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)
138144
}
139145

140146
/* Set parameter to compression context */
141-
const size_t zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);
147+
size_t zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);
142148
if (ZSTD_isError(zstd_ret)) {
143149
set_parameter_error(mod_state, 1, key_v, value_v);
144150
return -1;
@@ -323,7 +329,7 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
323329
/*[clinic input]
324330
@classmethod
325331
_zstd.ZstdCompressor.__new__ as _zstd_ZstdCompressor_new
326-
level: Py_ssize_t(c_default='PY_SSIZE_T_MIN', accept={int, NoneType}) = None
332+
level: object = None
327333
The compression level to use. Defaults to COMPRESSION_LEVEL_DEFAULT.
328334
options: object = None
329335
A dict object that contains advanced compression parameters.
@@ -337,9 +343,9 @@ function instead.
337343
[clinic start generated code]*/
338344

339345
static PyObject *
340-
_zstd_ZstdCompressor_new_impl(PyTypeObject *type, Py_ssize_t level,
346+
_zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
341347
PyObject *options, PyObject *zstd_dict)
342-
/*[clinic end generated code: output=a857ec0dc29fc5e2 input=9899740b24d11319]*/
348+
/*[clinic end generated code: output=cdef61eafecac3d7 input=92de0211ae20ffdc]*/
343349
{
344350
ZstdCompressor* self = PyObject_GC_New(ZstdCompressor, type);
345351
if (self == NULL) {
@@ -364,19 +370,34 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, Py_ssize_t level,
364370
/* Last mode */
365371
self->last_mode = ZSTD_e_end;
366372

367-
if (level != PY_SSIZE_T_MIN && options != Py_None) {
373+
if (level != Py_None && options != Py_None) {
368374
PyErr_SetString(PyExc_RuntimeError,
369375
"Only one of level or options should be used.");
370376
goto error;
371377
}
372378

373-
/* Set compressLevel/options to compression context */
374-
if (level != PY_SSIZE_T_MIN) {
375-
if (_zstd_set_c_level(self, level) < 0) {
379+
/* Set compression level */
380+
if (level != Py_None) {
381+
if (!PyLong_Check(level)) {
382+
PyErr_SetString(PyExc_TypeError,
383+
"invalid type for level, expected int");
384+
goto error;
385+
}
386+
int level_v = PyLong_AsInt(level);
387+
if (level_v == -1 && PyErr_Occurred()) {
388+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
389+
PyErr_Format(PyExc_ValueError,
390+
"%zd not in valid range %d <= compression level <= %d.",
391+
level, ZSTD_minCLevel(), ZSTD_maxCLevel());
392+
}
393+
goto error;
394+
}
395+
if (_zstd_set_c_level(self, level_v) < 0) {
376396
goto error;
377397
}
378398
}
379399

400+
/* Set options dictionary */
380401
if (options != Py_None) {
381402
if (_zstd_set_c_parameters(self, options) < 0) {
382403
goto error;
@@ -693,6 +714,8 @@ PyDoc_STRVAR(ZstdCompressor_last_mode_doc,
693714
static PyMemberDef ZstdCompressor_members[] = {
694715
{"last_mode", Py_T_INT, offsetof(ZstdCompressor, last_mode),
695716
Py_READONLY, ZstdCompressor_last_mode_doc},
717+
{"compression_level", Py_T_INT, offsetof(ZstdCompressor, compression_level),
718+
Py_READONLY, NULL},
696719
{NULL}
697720
};
698721

0 commit comments

Comments
 (0)