Skip to content

Commit eaf46a8

Browse files
committed
Revert "Remove _set_parameter_types"
This reverts commit bf4b07d. Checking the type of parameters is important to avoid confusing error messages.
1 parent bf4b07d commit eaf46a8

File tree

6 files changed

+144
-1
lines changed

6 files changed

+144
-1
lines changed

Lib/compression/zstd/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,7 @@ class Strategy(enum.IntEnum):
228228
btopt = _zstd._ZSTD_btopt
229229
btultra = _zstd._ZSTD_btultra
230230
btultra2 = _zstd._ZSTD_btultra2
231+
232+
233+
# Check validity of the CompressionParameter & DecompressionParameter types
234+
_zstd._set_parameter_types(CompressionParameter, DecompressionParameter)

Modules/_zstd/_zstdmodule.c

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,12 +510,49 @@ _zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer)
510510
return Py_BuildValue("KI", decompressed_size, dict_id);
511511
}
512512

513+
/*[clinic input]
514+
_zstd._set_parameter_types
515+
516+
c_parameter_type: object(subclass_of='&PyType_Type')
517+
CompressionParameter IntEnum type object
518+
d_parameter_type: object(subclass_of='&PyType_Type')
519+
DecompressionParameter IntEnum type object
520+
521+
Internal function, set CompressionParameter/DecompressionParameter types for validity check.
522+
[clinic start generated code]*/
523+
524+
static PyObject *
525+
_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
526+
PyObject *d_parameter_type)
527+
/*[clinic end generated code: output=a13d4890ccbd2873 input=4535545d903853d3]*/
528+
{
529+
_zstd_state* const mod_state = get_zstd_state(module);
530+
531+
if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
532+
PyErr_SetString(PyExc_ValueError,
533+
"The two arguments should be CompressionParameter and "
534+
"DecompressionParameter types.");
535+
return NULL;
536+
}
537+
538+
Py_XDECREF(mod_state->CParameter_type);
539+
Py_INCREF(c_parameter_type);
540+
mod_state->CParameter_type = (PyTypeObject*)c_parameter_type;
541+
542+
Py_XDECREF(mod_state->DParameter_type);
543+
Py_INCREF(d_parameter_type);
544+
mod_state->DParameter_type = (PyTypeObject*)d_parameter_type;
545+
546+
Py_RETURN_NONE;
547+
}
548+
513549
static PyMethodDef _zstd_methods[] = {
514550
_ZSTD__TRAIN_DICT_METHODDEF
515551
_ZSTD__FINALIZE_DICT_METHODDEF
516552
_ZSTD__GET_PARAM_BOUNDS_METHODDEF
517553
_ZSTD_GET_FRAME_SIZE_METHODDEF
518554
_ZSTD__GET_FRAME_INFO_METHODDEF
555+
_ZSTD__SET_PARAMETER_TYPES_METHODDEF
519556

520557
{0}
521558
};
@@ -729,6 +766,9 @@ static int _zstd_exec(PyObject *module) {
729766
ADD_STR_TO_STATE_MACRO(write);
730767
ADD_STR_TO_STATE_MACRO(flush);
731768

769+
mod_state->CParameter_type = NULL;
770+
mod_state->DParameter_type = NULL;
771+
732772
/* Add variables to module */
733773
if (add_vars_to_module(module) < 0) {
734774
return -1;
@@ -812,6 +852,9 @@ _zstd_traverse(PyObject *module, visitproc visit, void *arg)
812852
Py_VISIT(mod_state->ZstdDecompressor_type);
813853

814854
Py_VISIT(mod_state->ZstdError);
855+
856+
Py_VISIT(mod_state->CParameter_type);
857+
Py_VISIT(mod_state->DParameter_type);
815858
return 0;
816859
}
817860

@@ -833,6 +876,9 @@ _zstd_clear(PyObject *module)
833876
Py_CLEAR(mod_state->ZstdDecompressor_type);
834877

835878
Py_CLEAR(mod_state->ZstdError);
879+
880+
Py_CLEAR(mod_state->CParameter_type);
881+
Py_CLEAR(mod_state->DParameter_type);
836882
return 0;
837883
}
838884

Modules/_zstd/_zstdmodule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ struct _zstd_state {
5454
PyTypeObject *ZstdCompressor_type;
5555
PyTypeObject *ZstdDecompressor_type;
5656
PyObject *ZstdError;
57+
58+
PyTypeObject *CParameter_type;
59+
PyTypeObject *DParameter_type;
5760
};
5861

5962
typedef struct {

Modules/_zstd/clinic/_zstdmodule.c.h

Lines changed: 75 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/_zstd/compressor.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ _PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options,
6767
Py_ssize_t pos = 0;
6868

6969
while (PyDict_Next(level_or_options, &pos, &key, &value)) {
70+
/* Check key type */
71+
if (Py_TYPE(key) == mod_state->DParameter_type) {
72+
PyErr_SetString(PyExc_TypeError,
73+
"Key of compression option dict should "
74+
"NOT be DecompressionParameter.");
75+
return -1;
76+
}
77+
7078
int key_v = PyLong_AsInt(key);
7179
if (key_v == -1 && PyErr_Occurred()) {
7280
PyErr_SetString(PyExc_ValueError,

Modules/_zstd/decompressor.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
8080

8181
pos = 0;
8282
while (PyDict_Next(options, &pos, &key, &value)) {
83+
/* Check key type */
84+
if (Py_TYPE(key) == mod_state->CParameter_type) {
85+
PyErr_SetString(PyExc_TypeError,
86+
"Key of decompression options dict should "
87+
"NOT be CompressionParameter.");
88+
return -1;
89+
}
90+
8391
/* Both key & value should be 32-bit signed int */
8492
int key_v = PyLong_AsInt(key);
8593
if (key_v == -1 && PyErr_Occurred()) {

0 commit comments

Comments
 (0)