Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Lib/compression/zstd/_zstdfile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import io
from os import PathLike
from _zstd import (ZstdCompressor, ZstdDecompressor, ZstdError,
ZSTD_DStreamOutSize)
from _zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize
from compression._common import _streams

__all__ = ('ZstdFile', 'open')
Expand Down
66 changes: 53 additions & 13 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,17 @@ def test_simple_compress_bad_args(self):
self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})

# valid compression level range is [-(1<<17), 22]
with self.assertRaises(ValueError):
ZstdCompressor(23)
with self.assertRaises(ValueError):
ZstdCompressor(-(1<<17)-1)
with self.assertRaises(ValueError):
ZstdCompressor(2**31)
with self.assertRaises(ValueError):
ZstdCompressor(options={2**31: 100})
ZstdCompressor(level=-(2**1000))
with self.assertRaises(ValueError):
ZstdCompressor(level=(2**1000))

with self.assertRaises(ZstdError):
ZstdCompressor(options={CompressionParameter.window_log: 100})
Expand Down Expand Up @@ -259,13 +266,22 @@ def test_compress_parameters(self):
d1[CompressionParameter.ldm_bucket_size_log] = 2**31
self.assertRaises(ValueError, ZstdCompressor, options=d1)

# clamp compressionLevel
# out of bounds compression level
level_min, level_max = CompressionParameter.compression_level.bounds()
compress(b'', level_max+1)
compress(b'', level_min-1)

compress(b'', options={CompressionParameter.compression_level:level_max+1})
compress(b'', options={CompressionParameter.compression_level:level_min-1})
with self.assertRaises(ValueError):
compress(b'', level_max+1)
with self.assertRaises(ValueError):
compress(b'', level_min-1)
with self.assertRaises(ValueError):
compress(b'', 2**1000)
with self.assertRaises(ValueError):
compress(b'', -(2**1000))
with self.assertRaises(ValueError):
compress(b'', options={
CompressionParameter.compression_level: level_max+1})
with self.assertRaises(ValueError):
compress(b'', options={
CompressionParameter.compression_level: level_min-1})

# zstd lib doesn't support MT compression
if not SUPPORT_MULTITHREADING:
Expand Down Expand Up @@ -385,12 +401,22 @@ def test_simple_decompress_bad_args(self):
self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')

with self.assertRaises(ValueError):
ZstdDecompressor(options={2**31 : 100})
ZstdDecompressor(options={2**31: 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={2**1000: 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={-(2**31)-1: 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={-(2**1000): 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={0: 2**32})
with self.assertRaises(ValueError):
ZstdDecompressor(options={0: -(2**1000)})

with self.assertRaises(ZstdError):
ZstdDecompressor(options={DecompressionParameter.window_log_max:100})
ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
with self.assertRaises(ZstdError):
ZstdDecompressor(options={3333 : 100})
ZstdDecompressor(options={3333: 100})

empty = compress(b'')
lzd = ZstdDecompressor()
Expand All @@ -416,6 +442,20 @@ def test_decompress_parameters(self):
r'\((?:32|64)-bit build\)')):
decompress(b'', options=options)

# out of bounds deecompression parameter
options[DecompressionParameter.window_log_max] = 2**31
with self.assertRaises(ValueError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = -(2**32)-1
with self.assertRaises(ValueError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = 2**1000
with self.assertRaises(ValueError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = -(2**1000)
with self.assertRaises(ValueError):
decompress(b'', options=options)

def test_unknown_decompression_parameter(self):
KEY = 100001234
options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
Expand Down Expand Up @@ -1425,11 +1465,11 @@ def test_init_bad_mode(self):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")

with self.assertRaisesRegex(TypeError,
r"NOT be a CompressionParameter"):
r"not be a CompressionParameter"):
ZstdFile(io.BytesIO(), 'rb',
options={CompressionParameter.compression_level:5})
with self.assertRaisesRegex(TypeError,
r"NOT be a DecompressionParameter"):
r"not be a DecompressionParameter"):
ZstdFile(io.BytesIO(), 'wb',
options={DecompressionParameter.window_log_max:21})

Expand Down Expand Up @@ -1468,7 +1508,7 @@ def test_init_close_fp(self):
tmp_f.write(DAT_130K_C)
filename = tmp_f.name

with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
ZstdFile(filename, options={'a':'b'})

# for PyPy
Expand Down
169 changes: 97 additions & 72 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,98 +49,108 @@ typedef struct {
#include "clinic/compressor.c.h"

static int
_zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options,
const char *arg_name, const char* arg_type)
_zstd_set_c_level(ZstdCompressor *self, const int level)
{
size_t zstd_ret;
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
/* Set integer compression level */
int min_level = ZSTD_minCLevel();
int max_level = ZSTD_maxCLevel();
if (level < min_level || level > max_level) {
PyErr_Format(PyExc_ValueError,
"%zd not in valid range %d <= compression level <= %d.",
level, min_level, max_level);
return -1;
}

/* Integer compression level */
if (PyLong_Check(level_or_options)) {
int level = PyLong_AsInt(level_or_options);
if (level == -1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Compression level should be an int value between "
"%d and %d.", ZSTD_minCLevel(), ZSTD_maxCLevel());
return -1;
}

/* Save for generating ZSTD_CDICT */
self->compression_level = level;
/* Save for generating ZSTD_CDICT */
self->compression_level = level;

/* Set compressionLevel to compression context */
zstd_ret = ZSTD_CCtx_setParameter(self->cctx,
ZSTD_c_compressionLevel,
level);
/* Set compressionLevel to compression context */
size_t zstd_ret = ZSTD_CCtx_setParameter(
self->cctx, ZSTD_c_compressionLevel, level);

/* Check error */
if (ZSTD_isError(zstd_ret)) {
set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
/* Check error */
if (ZSTD_isError(zstd_ret)) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
}
return 0;
set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
return -1;
}
return 0;
}

/* Options dict */
if (PyDict_Check(level_or_options)) {
PyObject *key, *value;
Py_ssize_t pos = 0;
static int
_zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)
{
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
}

while (PyDict_Next(level_or_options, &pos, &key, &value)) {
/* Check key type */
if (Py_TYPE(key) == mod_state->DParameter_type) {
PyErr_SetString(PyExc_TypeError,
"Key of compression options dict should "
"NOT be a DecompressionParameter attribute.");
return -1;
}
if (!PyDict_Check(options)) {
PyErr_Format(PyExc_TypeError,
"invalid type for options, expected dict");
return -1;
}

int key_v = PyLong_AsInt(key);
if (key_v == -1 && PyErr_Occurred()) {
Py_ssize_t pos = 0;
PyObject *key, *value;
while (PyDict_Next(options, &pos, &key, &value)) {
/* Check key type */
if (Py_TYPE(key) == mod_state->DParameter_type) {
PyErr_SetString(PyExc_TypeError,
"compression options dictionary key must not be a "
"DecompressionParameter attribute");
return -1;
}

Py_INCREF(key);
int key_v = PyLong_AsInt(key);
if (key_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError,
"Key of options dict should be either a "
"CompressionParameter attribute or an int.");
return -1;
"dictionary key must be less than 2**31");
}
return -1;
}

int value_v = PyLong_AsInt(value);
if (value_v == -1 && PyErr_Occurred()) {
Py_INCREF(value);
int value_v = PyLong_AsInt(value);
if (value_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError,
"Value of options dict should be an int.");
return -1;
"dictionary value must be less than 2**31");
}
return -1;
}

if (key_v == ZSTD_c_compressionLevel) {
/* Save for generating ZSTD_CDICT */
self->compression_level = value_v;
if (key_v == ZSTD_c_compressionLevel) {
if (_zstd_set_c_level(self, value_v) < 0) {
return -1;
}
else if (key_v == ZSTD_c_nbWorkers) {
/* From the zstd library docs:
1. When nbWorkers >= 1, triggers asynchronous mode when
used with ZSTD_compressStream2().
2, Default value is `0`, aka "single-threaded mode" : no
worker is spawned, compression is performed inside
caller's thread, all invocations are blocking. */
if (value_v != 0) {
self->use_multithread = 1;
}
continue;
}
if (key_v == ZSTD_c_nbWorkers) {
/* From the zstd library docs:
1. When nbWorkers >= 1, triggers asynchronous mode when
used with ZSTD_compressStream2().
2, Default value is `0`, aka "single-threaded mode" : no
worker is spawned, compression is performed inside
caller's thread, all invocations are blocking. */
if (value_v != 0) {
self->use_multithread = 1;
}
}

/* Set parameter to compression context */
zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);
if (ZSTD_isError(zstd_ret)) {
set_parameter_error(mod_state, 1, key_v, value_v);
return -1;
}
/* Set parameter to compression context */
size_t zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);
if (ZSTD_isError(zstd_ret)) {
set_parameter_error(mod_state, 1, key_v, value_v);
return -1;
}
return 0;
}
PyErr_Format(PyExc_TypeError,
"Invalid type for %s. Expected %s", arg_name, arg_type);
return -1;
return 0;
}

static void
Expand Down Expand Up @@ -366,15 +376,30 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
goto error;
}

/* Set compressLevel/options to compression context */
/* Set compression level */
if (level != Py_None) {
if (_zstd_set_c_parameters(self, level, "level", "int") < 0) {
if (!PyLong_Check(level)) {
PyErr_SetString(PyExc_TypeError,
"invalid type for level, expected int");
goto error;
}
int level_v = PyLong_AsInt(level);
if (level_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_Format(PyExc_ValueError,
"%zd not in valid range %d <= compression level <= %d.",
level, ZSTD_minCLevel(), ZSTD_maxCLevel());
}
goto error;
}
if (_zstd_set_c_level(self, level_v) < 0) {
goto error;
}
}

/* Set options dictionary */
if (options != Py_None) {
if (_zstd_set_c_parameters(self, options, "options", "dict") < 0) {
if (_zstd_set_c_parameters(self, options) < 0) {
goto error;
}
}
Expand Down
Loading
Loading