Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2430,7 +2430,6 @@ def test_buffer_protocol(self):
self.assertEqual(f.write(arr), LENGTH)
self.assertEqual(f.tell(), LENGTH)

@unittest.skip("it fails for now, see gh-133885")
class FreeThreadingMethodTests(unittest.TestCase):

@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
Expand Down
41 changes: 21 additions & 20 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
#include "internal/pycore_lock.h" // PyMutex_IsLocked

#include <stddef.h> // offsetof()
#include <zstd.h> // ZSTD_*()
Expand All @@ -38,6 +39,9 @@ typedef struct {

/* Compression level */
int compression_level;

/* Lock to protect the compression context */
PyMutex lock;
} ZstdCompressor;

#define ZstdCompressor_CAST(op) ((ZstdCompressor *)op)
Expand Down Expand Up @@ -153,8 +157,6 @@ _get_CDict(ZstdDict *self, int compressionLevel)
PyObject *capsule;
ZSTD_CDict *cdict;

// TODO(emmatyping): refactor critical section code into a lock_held function
Py_BEGIN_CRITICAL_SECTION(self);

/* int level object */
level = PyLong_FromLong(compressionLevel);
Expand Down Expand Up @@ -212,7 +214,6 @@ _get_CDict(ZstdDict *self, int compressionLevel)
cdict = NULL;
success:
Py_XDECREF(level);
Py_END_CRITICAL_SECTION();
return cdict;
}

Expand Down Expand Up @@ -276,28 +277,22 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
}
/* Reference a prepared dictionary.
It overrides some compression context's parameters. */
Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
Py_END_CRITICAL_SECTION();
}
else if (type == DICT_TYPE_UNDIGESTED) {
/* Load a dictionary.
It doesn't override compression context's parameters. */
Py_BEGIN_CRITICAL_SECTION2(self, zd);
zstd_ret = ZSTD_CCtx_loadDictionary(
self->cctx,
PyBytes_AS_STRING(zd->dict_content),
Py_SIZE(zd->dict_content));
Py_END_CRITICAL_SECTION2();
}
else if (type == DICT_TYPE_PREFIX) {
/* Load a prefix */
Py_BEGIN_CRITICAL_SECTION2(self, zd);
zstd_ret = ZSTD_CCtx_refPrefix(
self->cctx,
PyBytes_AS_STRING(zd->dict_content),
Py_SIZE(zd->dict_content));
Py_END_CRITICAL_SECTION2();
}
else {
Py_UNREACHABLE();
Expand Down Expand Up @@ -339,6 +334,7 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,

self->use_multithread = 0;
self->dict = NULL;
self->lock = (PyMutex){0};

/* Compression context */
self->cctx = ZSTD_createCCtx();
Expand Down Expand Up @@ -403,6 +399,10 @@ ZstdCompressor_dealloc(PyObject *ob)
ZSTD_freeCCtx(self->cctx);
}

if (PyMutex_IsLocked(&self->lock)) {
PyMutex_Unlock(&self->lock);
}

/* Py_XDECREF the dict after free the compression context */
Py_CLEAR(self->dict);

Expand All @@ -412,8 +412,8 @@ ZstdCompressor_dealloc(PyObject *ob)
}

static PyObject *
compress_impl(ZstdCompressor *self, Py_buffer *data,
ZSTD_EndDirective end_directive)
compress_lock_held(ZstdCompressor *self, Py_buffer *data,
ZSTD_EndDirective end_directive)
{
ZSTD_inBuffer in;
ZSTD_outBuffer out;
Expand Down Expand Up @@ -495,7 +495,7 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out)
#endif

static PyObject *
compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
{
ZSTD_inBuffer in;
ZSTD_outBuffer out;
Expand Down Expand Up @@ -529,7 +529,7 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
goto error;
}

/* Like compress_impl(), output as much as possible. */
/* Like compress_lock_held(), output as much as possible. */
if (out.pos == out.size) {
if (_OutputBuffer_Grow(&buffer, &out) < 0) {
goto error;
Expand Down Expand Up @@ -588,14 +588,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
}

/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);
PyMutex_Lock(&self->lock);

/* Compress */
if (self->use_multithread && mode == ZSTD_e_continue) {
ret = compress_mt_continue_impl(self, data);
ret = compress_mt_continue_lock_held(self, data);
}
else {
ret = compress_impl(self, data, mode);
ret = compress_lock_held(self, data, mode);
}

if (ret) {
Expand All @@ -607,7 +607,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
Py_END_CRITICAL_SECTION();
PyMutex_Unlock(&self->lock);

return ret;
}
Expand Down Expand Up @@ -642,8 +642,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
}

/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);
ret = compress_impl(self, NULL, mode);
PyMutex_Lock(&self->lock);

ret = compress_lock_held(self, NULL, mode);

if (ret) {
self->last_mode = mode;
Expand All @@ -654,7 +655,7 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
Py_END_CRITICAL_SECTION();
PyMutex_Unlock(&self->lock);

return ret;
}
Expand Down
45 changes: 23 additions & 22 deletions Modules/_zstd/decompressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
#include "internal/pycore_lock.h" // PyMutex_IsLocked

#include <stdbool.h> // bool
#include <stddef.h> // offsetof()
Expand Down Expand Up @@ -45,6 +46,9 @@ typedef struct {
/* For ZstdDecompressor, 0 or 1.
1 means the end of the first frame has been reached. */
bool eof;

/* Lock to protect the decompression context */
PyMutex lock;
} ZstdDecompressor;

#define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op)
Expand All @@ -61,7 +65,6 @@ _get_DDict(ZstdDict *self)
return self->d_dict;
}

Py_BEGIN_CRITICAL_SECTION(self);
if (self->d_dict == NULL) {
/* Create ZSTD_DDict instance from dictionary content */
char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
Expand All @@ -83,7 +86,6 @@ _get_DDict(ZstdDict *self)

/* Don't lose any exception */
ret = self->d_dict;
Py_END_CRITICAL_SECTION();

return ret;
}
Expand Down Expand Up @@ -134,9 +136,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
}

/* Set parameter to compression context */
Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v);
Py_END_CRITICAL_SECTION();

/* Check error */
if (ZSTD_isError(zstd_ret)) {
Expand Down Expand Up @@ -206,27 +206,21 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
return -1;
}
/* Reference a prepared dictionary */
Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
Py_END_CRITICAL_SECTION();
}
else if (type == DICT_TYPE_UNDIGESTED) {
/* Load a dictionary */
Py_BEGIN_CRITICAL_SECTION2(self, zd);
zstd_ret = ZSTD_DCtx_loadDictionary(
self->dctx,
PyBytes_AS_STRING(zd->dict_content),
Py_SIZE(zd->dict_content));
Py_END_CRITICAL_SECTION2();
}
else if (type == DICT_TYPE_PREFIX) {
/* Load a prefix */
Py_BEGIN_CRITICAL_SECTION2(self, zd);
zstd_ret = ZSTD_DCtx_refPrefix(
self->dctx,
PyBytes_AS_STRING(zd->dict_content),
Py_SIZE(zd->dict_content));
Py_END_CRITICAL_SECTION2();
}
else {
/* Impossible code path */
Expand Down Expand Up @@ -268,8 +262,8 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
Note, decompressing "an empty input" in any case will make it > 0.
*/
static PyObject *
decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
Py_ssize_t max_length)
decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
Py_ssize_t max_length)
{
size_t zstd_ret;
ZSTD_outBuffer out;
Expand Down Expand Up @@ -339,10 +333,8 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
}

static void
decompressor_reset_session(ZstdDecompressor *self)
decompressor_reset_session_lock_held(ZstdDecompressor *self)
{
// TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here
// and ensure lock is always held

/* Reset variables */
self->in_begin = 0;
Expand All @@ -359,7 +351,8 @@ decompressor_reset_session(ZstdDecompressor *self)
}

static PyObject *
stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length)
stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data,
Py_ssize_t max_length)
{
ZSTD_inBuffer in;
PyObject *ret = NULL;
Expand Down Expand Up @@ -456,7 +449,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
assert(in.pos == 0);

/* Decompress */
ret = decompress_impl(self, &in, max_length);
ret = decompress_lock_held(self, &in, max_length);
if (ret == NULL) {
goto error;
}
Expand Down Expand Up @@ -517,7 +510,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length

error:
/* Reset decompressor's states/session */
decompressor_reset_session(self);
decompressor_reset_session_lock_held(self);

Py_CLEAR(ret);
return NULL;
Expand Down Expand Up @@ -555,6 +548,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
self->unused_data = NULL;
self->eof = 0;
self->dict = NULL;
self->lock = (PyMutex){0};

/* needs_input flag */
self->needs_input = 1;
Expand Down Expand Up @@ -608,6 +602,10 @@ ZstdDecompressor_dealloc(PyObject *ob)
ZSTD_freeDCtx(self->dctx);
}

if (PyMutex_IsLocked(&self->lock)) {
PyMutex_Unlock(&self->lock);
}

/* Py_CLEAR the dict after free decompression context */
Py_CLEAR(self->dict);

Expand Down Expand Up @@ -639,7 +637,10 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
{
PyObject *ret;

PyMutex_Lock(&self->lock);

if (!self->eof) {
PyMutex_Unlock(&self->lock);
return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
}
else {
Expand All @@ -656,6 +657,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
}
}

PyMutex_Unlock(&self->lock);
return ret;
}

Expand Down Expand Up @@ -693,10 +695,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
{
PyObject *ret;
/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);

ret = stream_decompress(self, data, max_length);
Py_END_CRITICAL_SECTION();
PyMutex_Lock(&self->lock);
ret = stream_decompress_lock_held(self, data, max_length);
PyMutex_Unlock(&self->lock);
return ret;
}

Expand Down
Loading