Skip to content

Commit 9aec708

Browse files
gh-132983: Minor fixes and clean up for the _zstd module
1 parent ebf6d13 commit 9aec708

File tree

4 files changed

+109
-80
lines changed

4 files changed

+109
-80
lines changed

Lib/test/test_zstd.py

Lines changed: 50 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1239,18 +1239,37 @@ def test_train_dict_c(self):
12391239
# argument wrong type
12401240
with self.assertRaises(TypeError):
12411241
_zstd.train_dict({}, (), 100)
1242+
with self.assertRaises(TypeError):
1243+
_zstd.train_dict(bytearray(), (), 100)
12421244
with self.assertRaises(TypeError):
12431245
_zstd.train_dict(b'', 99, 100)
1246+
with self.assertRaises(TypeError):
1247+
_zstd.train_dict(b'', [], 100)
12441248
with self.assertRaises(TypeError):
12451249
_zstd.train_dict(b'', (), 100.1)
1250+
with self.assertRaises(TypeError):
1251+
_zstd.train_dict(b'', (99.1,), 100)
1252+
with self.assertRaises(ValueError):
1253+
_zstd.train_dict(b'abc', (4, -1), 100)
1254+
with self.assertRaises(ValueError):
1255+
_zstd.train_dict(b'abc', (2,), 100)
1256+
with self.assertRaises(ValueError):
1257+
_zstd.train_dict(b'', (99,), 100)
12461258

12471259
# size > size_t
12481260
with self.assertRaises(ValueError):
1249-
_zstd.train_dict(b'', (2**64+1,), 100)
1261+
_zstd.train_dict(b'', (2**1000,), 100)
1262+
with self.assertRaises(ValueError):
1263+
_zstd.train_dict(b'', (-2**1000,), 100)
12501264

12511265
# dict_size <= 0
12521266
with self.assertRaises(ValueError):
12531267
_zstd.train_dict(b'', (), 0)
1268+
with self.assertRaises(ValueError):
1269+
_zstd.train_dict(b'', (), -1)
1270+
1271+
with self.assertRaises(ZstdError):
1272+
_zstd.train_dict(b'', (), 1)
12541273

12551274
def test_finalize_dict_c(self):
12561275
with self.assertRaises(TypeError):
@@ -1259,22 +1278,51 @@ def test_finalize_dict_c(self):
12591278
# argument wrong type
12601279
with self.assertRaises(TypeError):
12611280
_zstd.finalize_dict({}, b'', (), 100, 5)
1281+
with self.assertRaises(TypeError):
1282+
_zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
12621283
with self.assertRaises(TypeError):
12631284
_zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
1285+
with self.assertRaises(TypeError):
1286+
_zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
12641287
with self.assertRaises(TypeError):
12651288
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
1289+
with self.assertRaises(TypeError):
1290+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
12661291
with self.assertRaises(TypeError):
12671292
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
12681293
with self.assertRaises(TypeError):
12691294
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
12701295

1296+
with self.assertRaises(ValueError):
1297+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
1298+
with self.assertRaises(ValueError):
1299+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
1300+
with self.assertRaises(ValueError):
1301+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)
1302+
12711303
# size > size_t
12721304
with self.assertRaises(ValueError):
1273-
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
1305+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
1306+
with self.assertRaises(ValueError):
1307+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)
12741308

12751309
# dict_size <= 0
12761310
with self.assertRaises(ValueError):
12771311
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
1312+
with self.assertRaises(ValueError):
1313+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
1314+
with self.assertRaises(OverflowError):
1315+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
1316+
with self.assertRaises(OverflowError):
1317+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)
1318+
1319+
with self.assertRaises(OverflowError):
1320+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
1321+
with self.assertRaises(OverflowError):
1322+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)
1323+
1324+
with self.assertRaises(ZstdError):
1325+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)
12781326

12791327
def test_train_buffer_protocol_samples(self):
12801328
def _nbytes(dat):

Modules/_zstd/_zstdmodule.c

Lines changed: 22 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,9 @@ set_zstd_error(const _zstd_state* const state,
2828
char *msg;
2929
assert(ZSTD_isError(zstd_ret));
3030

31+
if (state == NULL) {
32+
return;
33+
}
3134
switch (type) {
3235
case ERR_DECOMPRESS:
3336
msg = "Unable to decompress Zstandard data: %s";
@@ -174,7 +177,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
174177
Py_ssize_t sizes_sum;
175178
Py_ssize_t i;
176179

177-
chunks_number = Py_SIZE(samples_sizes);
180+
chunks_number = PyTuple_GET_SIZE(samples_sizes);
178181
if ((size_t) chunks_number > UINT32_MAX) {
179182
PyErr_Format(PyExc_ValueError,
180183
"The number of samples should be <= %u.", UINT32_MAX);
@@ -188,20 +191,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
188191
return -1;
189192
}
190193

191-
sizes_sum = 0;
194+
sizes_sum = PyBytes_GET_SIZE(samples_bytes);
192195
for (i = 0; i < chunks_number; i++) {
193-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
194-
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
195-
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
196-
PyErr_Format(PyExc_ValueError,
197-
"Items in samples_sizes should be an int "
198-
"object, with a value between 0 and %u.", SIZE_MAX);
196+
size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
197+
(*chunk_sizes)[i] = size;
198+
if (size == (size_t)-1 && PyErr_Occurred()) {
199+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
200+
goto sum_error;
201+
}
199202
return -1;
200203
}
201-
sizes_sum += (*chunk_sizes)[i];
204+
if ((size_t)sizes_sum < size) {
205+
goto sum_error;
206+
}
207+
sizes_sum -= size;
202208
}
203209

204-
if (sizes_sum != Py_SIZE(samples_bytes)) {
210+
if (sizes_sum != 0) {
211+
sum_error:
205212
PyErr_SetString(PyExc_ValueError,
206213
"The samples size tuple doesn't match the "
207214
"concatenation's size.");
@@ -257,7 +264,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
257264

258265
/* Train the dictionary */
259266
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
260-
char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
267+
const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
261268
Py_BEGIN_ALLOW_THREADS
262269
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
263270
samples_buffer,
@@ -507,17 +514,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
507514
{
508515
_zstd_state* mod_state = get_zstd_state(module);
509516

510-
if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
511-
PyErr_SetString(PyExc_ValueError,
512-
"The two arguments should be CompressionParameter and "
513-
"DecompressionParameter types.");
514-
return NULL;
515-
}
516-
517-
Py_XSETREF(
518-
mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type));
519-
Py_XSETREF(
520-
mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type));
517+
Py_INCREF(c_parameter_type);
518+
Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
519+
Py_INCREF(d_parameter_type);
520+
Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);
521521

522522
Py_RETURN_NONE;
523523
}
@@ -580,7 +580,6 @@ do { \
580580
return -1;
581581
}
582582
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
583-
Py_DECREF(mod_state->ZstdError);
584583
return -1;
585584
}
586585

Modules/_zstd/compressor.c

Lines changed: 16 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -71,9 +71,6 @@ _zstd_set_c_level(ZstdCompressor *self, int level)
7171
/* Check error */
7272
if (ZSTD_isError(zstd_ret)) {
7373
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
74-
if (mod_state == NULL) {
75-
return -1;
76-
}
7774
set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
7875
return -1;
7976
}
@@ -203,16 +200,16 @@ _get_CDict(ZstdDict *self, int compressionLevel)
203200
goto error;
204201
}
205202

206-
/* Add PyCapsule object to self->c_dicts */
207-
ret = PyDict_SetItem(self->c_dicts, level, capsule);
203+
/* Add PyCapsule object to self->c_dicts if it is not already present. */
204+
PyObject *result;
205+
ret = PyDict_SetDefaultRef(self->c_dicts, level, capsule, &result);
208206
if (ret < 0) {
209207
goto error;
210208
}
209+
Py_DECREF(capsule);
210+
capsule = result;
211211
}
212-
else {
213-
/* ZSTD_CDict instance already exists */
214-
cdict = PyCapsule_GetPointer(capsule, NULL);
215-
}
212+
cdict = PyCapsule_GetPointer(capsule, NULL);
216213
goto success;
217214

218215
error:
@@ -272,11 +269,7 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
272269
int type, ret;
273270

274271
/* Check ZstdDict */
275-
ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type);
276-
if (ret < 0) {
277-
return -1;
278-
}
279-
else if (ret > 0) {
272+
if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) {
280273
/* When compressing, use undigested dictionary by default. */
281274
zd = (ZstdDict*)dict;
282275
type = DICT_TYPE_UNDIGESTED;
@@ -289,14 +282,14 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
289282
/* Check (ZstdDict, type) */
290283
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
291284
/* Check ZstdDict */
292-
ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0),
293-
(PyObject*)mod_state->ZstdDict_type);
294-
if (ret < 0) {
295-
return -1;
296-
}
297-
else if (ret > 0) {
298-
/* type == -1 may indicate an error. */
285+
if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0),
286+
mod_state->ZstdDict_type) &&
287+
PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
288+
{
299289
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
290+
if (type == -1 && PyErr_Occurred()) {
291+
return -1;
292+
}
300293
if (type == DICT_TYPE_DIGESTED
301294
|| type == DICT_TYPE_UNDIGESTED
302295
|| type == DICT_TYPE_PREFIX)
@@ -481,9 +474,7 @@ compress_lock_held(ZstdCompressor *self, Py_buffer *data,
481474
/* Check error */
482475
if (ZSTD_isError(zstd_ret)) {
483476
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
484-
if (mod_state != NULL) {
485-
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
486-
}
477+
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
487478
goto error;
488479
}
489480

@@ -553,9 +544,7 @@ compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
553544
/* Check error */
554545
if (ZSTD_isError(zstd_ret)) {
555546
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
556-
if (mod_state != NULL) {
557-
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
558-
}
547+
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
559548
goto error;
560549
}
561550

Modules/_zstd/decompressor.c

Lines changed: 21 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -61,24 +61,23 @@ _get_DDict(ZstdDict *self)
6161
assert(PyMutex_IsLocked(&self->lock));
6262
ZSTD_DDict *ret;
6363

64-
/* Already created */
65-
if (self->d_dict != NULL) {
66-
return self->d_dict;
67-
}
68-
6964
if (self->d_dict == NULL) {
7065
/* Create ZSTD_DDict instance from dictionary content */
7166
Py_BEGIN_ALLOW_THREADS
7267
ret = ZSTD_createDDict(self->dict_buffer, self->dict_len);
7368
Py_END_ALLOW_THREADS
74-
self->d_dict = ret;
75-
76-
if (self->d_dict == NULL) {
77-
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
78-
if (mod_state != NULL) {
79-
PyErr_SetString(mod_state->ZstdError,
80-
"Failed to create a ZSTD_DDict instance from "
81-
"Zstandard dictionary content.");
69+
if (self->d_dict != NULL) {
70+
ZSTD_freeDDict(ret);
71+
}
72+
else {
73+
self->d_dict = ret;
74+
if (self->d_dict == NULL) {
75+
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
76+
if (mod_state != NULL) {
77+
PyErr_SetString(mod_state->ZstdError,
78+
"Failed to create a ZSTD_DDict instance from "
79+
"Zstandard dictionary content.");
80+
}
8281
}
8382
}
8483
}
@@ -189,11 +188,7 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
189188
int type, ret;
190189

191190
/* Check ZstdDict */
192-
ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type);
193-
if (ret < 0) {
194-
return -1;
195-
}
196-
else if (ret > 0) {
191+
if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) {
197192
/* When decompressing, use digested dictionary by default. */
198193
zd = (ZstdDict*)dict;
199194
type = DICT_TYPE_DIGESTED;
@@ -206,14 +201,14 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
206201
/* Check (ZstdDict, type) */
207202
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
208203
/* Check ZstdDict */
209-
ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0),
210-
(PyObject*)mod_state->ZstdDict_type);
211-
if (ret < 0) {
212-
return -1;
213-
}
214-
else if (ret > 0) {
215-
/* type == -1 may indicate an error. */
204+
if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0),
205+
mod_state->ZstdDict_type) &&
206+
PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
207+
{
216208
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
209+
if (type == -1 && PyErr_Occurred()) {
210+
return -1;
211+
}
217212
if (type == DICT_TYPE_DIGESTED
218213
|| type == DICT_TYPE_UNDIGESTED
219214
|| type == DICT_TYPE_PREFIX)
@@ -282,9 +277,7 @@ decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
282277
/* Check error */
283278
if (ZSTD_isError(zstd_ret)) {
284279
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
285-
if (mod_state != NULL) {
286-
set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
287-
}
280+
set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
288281
goto error;
289282
}
290283

0 commit comments

Comments
 (0)