@@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec"
1717#include "_zstdmodule.h"
1818#include "buffer.h"
1919#include "zstddict.h"
20+ #include "internal/pycore_lock.h" // PyMutex_IsLocked
2021
2122#include <stdbool.h> // bool
2223#include <stddef.h> // offsetof()
@@ -45,6 +46,9 @@ typedef struct {
4546 /* For ZstdDecompressor, 0 or 1.
4647 1 means the end of the first frame has been reached. */
4748 bool eof ;
49+
50+ /* Lock to protect the decompression context */
51+ PyMutex lock ;
4852} ZstdDecompressor ;
4953
5054#define ZstdDecompressor_CAST (op ) ((ZstdDecompressor *)op)
@@ -61,7 +65,6 @@ _get_DDict(ZstdDict *self)
6165 return self -> d_dict ;
6266 }
6367
64- Py_BEGIN_CRITICAL_SECTION (self );
6568 if (self -> d_dict == NULL ) {
6669 /* Create ZSTD_DDict instance from dictionary content */
6770 char * dict_buffer = PyBytes_AS_STRING (self -> dict_content );
@@ -83,7 +86,6 @@ _get_DDict(ZstdDict *self)
8386
8487 /* Don't lose any exception */
8588 ret = self -> d_dict ;
86- Py_END_CRITICAL_SECTION ();
8789
8890 return ret ;
8991}
@@ -134,9 +136,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
134136 }
135137
136138 /* Set parameter to compression context */
137- Py_BEGIN_CRITICAL_SECTION (self );
138139 zstd_ret = ZSTD_DCtx_setParameter (self -> dctx , key_v , value_v );
139- Py_END_CRITICAL_SECTION ();
140140
141141 /* Check error */
142142 if (ZSTD_isError (zstd_ret )) {
@@ -206,27 +206,21 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
206206 return -1 ;
207207 }
208208 /* Reference a prepared dictionary */
209- Py_BEGIN_CRITICAL_SECTION (self );
210209 zstd_ret = ZSTD_DCtx_refDDict (self -> dctx , d_dict );
211- Py_END_CRITICAL_SECTION ();
212210 }
213211 else if (type == DICT_TYPE_UNDIGESTED ) {
214212 /* Load a dictionary */
215- Py_BEGIN_CRITICAL_SECTION2 (self , zd );
216213 zstd_ret = ZSTD_DCtx_loadDictionary (
217214 self -> dctx ,
218215 PyBytes_AS_STRING (zd -> dict_content ),
219216 Py_SIZE (zd -> dict_content ));
220- Py_END_CRITICAL_SECTION2 ();
221217 }
222218 else if (type == DICT_TYPE_PREFIX ) {
223219 /* Load a prefix */
224- Py_BEGIN_CRITICAL_SECTION2 (self , zd );
225220 zstd_ret = ZSTD_DCtx_refPrefix (
226221 self -> dctx ,
227222 PyBytes_AS_STRING (zd -> dict_content ),
228223 Py_SIZE (zd -> dict_content ));
229- Py_END_CRITICAL_SECTION2 ();
230224 }
231225 else {
232226 /* Impossible code path */
@@ -268,8 +262,8 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
268262 Note, decompressing "an empty input" in any case will make it > 0.
269263*/
270264static PyObject *
271- decompress_impl (ZstdDecompressor * self , ZSTD_inBuffer * in ,
272- Py_ssize_t max_length )
265+ decompress_lock_held (ZstdDecompressor * self , ZSTD_inBuffer * in ,
266+ Py_ssize_t max_length )
273267{
274268 size_t zstd_ret ;
275269 ZSTD_outBuffer out ;
@@ -339,10 +333,8 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
339333}
340334
341335static void
342- decompressor_reset_session (ZstdDecompressor * self )
336+ decompressor_reset_session_lock_held (ZstdDecompressor * self )
343337{
344- // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here
345- // and ensure lock is always held
346338
347339 /* Reset variables */
348340 self -> in_begin = 0 ;
@@ -359,7 +351,8 @@ decompressor_reset_session(ZstdDecompressor *self)
359351}
360352
361353static PyObject *
362- stream_decompress (ZstdDecompressor * self , Py_buffer * data , Py_ssize_t max_length )
354+ stream_decompress_lock_held (ZstdDecompressor * self , Py_buffer * data ,
355+ Py_ssize_t max_length )
363356{
364357 ZSTD_inBuffer in ;
365358 PyObject * ret = NULL ;
@@ -456,7 +449,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
456449 assert (in .pos == 0 );
457450
458451 /* Decompress */
459- ret = decompress_impl (self , & in , max_length );
452+ ret = decompress_lock_held (self , & in , max_length );
460453 if (ret == NULL ) {
461454 goto error ;
462455 }
@@ -517,7 +510,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
517510
518511error :
519512 /* Reset decompressor's states/session */
520- decompressor_reset_session (self );
513+ decompressor_reset_session_lock_held (self );
521514
522515 Py_CLEAR (ret );
523516 return NULL ;
@@ -555,6 +548,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
555548 self -> unused_data = NULL ;
556549 self -> eof = 0 ;
557550 self -> dict = NULL ;
551+ self -> lock = (PyMutex ){0 };
558552
559553 /* needs_input flag */
560554 self -> needs_input = 1 ;
@@ -608,6 +602,10 @@ ZstdDecompressor_dealloc(PyObject *ob)
608602 ZSTD_freeDCtx (self -> dctx );
609603 }
610604
605+ if (PyMutex_IsLocked (& self -> lock )) {
606+ PyMutex_Unlock (& self -> lock );
607+ }
608+
611609 /* Py_CLEAR the dict after free decompression context */
612610 Py_CLEAR (self -> dict );
613611
@@ -639,7 +637,10 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
639637{
640638 PyObject * ret ;
641639
640+ PyMutex_Lock (& self -> lock );
641+
642642 if (!self -> eof ) {
643+ PyMutex_Unlock (& self -> lock );
643644 return Py_GetConstant (Py_CONSTANT_EMPTY_BYTES );
644645 }
645646 else {
@@ -656,6 +657,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
656657 }
657658 }
658659
660+ PyMutex_Unlock (& self -> lock );
659661 return ret ;
660662}
661663
@@ -693,10 +695,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
693695{
694696 PyObject * ret ;
695697 /* Thread-safe code */
696- Py_BEGIN_CRITICAL_SECTION (self );
697-
698- ret = stream_decompress (self , data , max_length );
699- Py_END_CRITICAL_SECTION ();
698+ PyMutex_Lock (& self -> lock );
699+ ret = stream_decompress_lock_held (self , data , max_length );
700+ PyMutex_Unlock (& self -> lock );
700701 return ret ;
701702}
702703
0 commit comments