Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2430,10 +2430,8 @@ def test_buffer_protocol(self):
self.assertEqual(f.write(arr), LENGTH)
self.assertEqual(f.tell(), LENGTH)

@unittest.skip("it fails for now, see gh-133885")
class FreeThreadingMethodTests(unittest.TestCase):

@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_compress_locking(self):
Expand Down Expand Up @@ -2470,7 +2468,6 @@ def run_method(method, input_data, output_data):
actual = b''.join(output) + rest2
self.assertEqual(expected, actual)

@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_decompress_locking(self):
Expand Down Expand Up @@ -2506,6 +2503,27 @@ def run_method(method, input_data, output_data):
actual = b''.join(output)
self.assertEqual(expected, actual)

@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_compress_shared_dict(self):
num_threads = 8

def run_method(b):
level = threading.get_ident() % 2
# sync threads to increase chance of contention on
# capsule storing dictionary levels
b.wait()
ZstdCompressor(level=level, zstd_dict=TRAINED_DICT.as_digested_dict)
threads = []

b = threading.Barrier(num_threads)
for i in range(num_threads):
thread = threading.Thread(target=run_method, args=(b,))

threads.append(thread)

with threading_helper.start_threads(threads):
pass


if __name__ == "__main__":
Expand Down
60 changes: 31 additions & 29 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
#include "internal/pycore_lock.h" // PyMutex_IsLocked

#include <stddef.h> // offsetof()
#include <zstd.h> // ZSTD_*()
Expand All @@ -38,6 +39,9 @@ typedef struct {

/* Compression level */
int compression_level;

/* Lock to protect the compression context */
PyMutex lock;
} ZstdCompressor;

#define ZstdCompressor_CAST(op) ((ZstdCompressor *)op)
Expand Down Expand Up @@ -153,8 +157,6 @@ _get_CDict(ZstdDict *self, int compressionLevel)
PyObject *capsule;
ZSTD_CDict *cdict;

// TODO(emmatyping): refactor critical section code into a lock_held function
Py_BEGIN_CRITICAL_SECTION(self);

/* int level object */
level = PyLong_FromLong(compressionLevel);
Expand All @@ -163,12 +165,12 @@ _get_CDict(ZstdDict *self, int compressionLevel)
}

/* Get PyCapsule object from self->c_dicts */
capsule = PyDict_GetItemWithError(self->c_dicts, level);
if (capsule == NULL) {
if (PyErr_Occurred()) {
goto error;
}
int result = PyDict_GetItemRef(self->c_dicts, level, &capsule);
if (result < 0) {
goto error;
}

if (capsule == NULL) {
/* Create ZSTD_CDict instance */
char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
Py_ssize_t dict_len = Py_SIZE(self->dict_content);
Expand All @@ -195,24 +197,26 @@ _get_CDict(ZstdDict *self, int compressionLevel)
goto error;
}

/* Add PyCapsule object to self->c_dicts */
if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) {
Py_DECREF(capsule);
/* Add PyCapsule object to self->c_dicts if not already inserted */
PyObject *capsule_value;
int result = PyDict_SetDefaultRef(self->c_dicts, level, capsule,
&capsule_value);
if (result < 0) {
goto error;
}
Py_DECREF(capsule);
Py_XDECREF(capsule_value);
}
else {
/* ZSTD_CDict instance already exists */
cdict = PyCapsule_GetPointer(capsule, NULL);
Py_DECREF(capsule);
}
goto success;

error:
cdict = NULL;
success:
Py_XDECREF(level);
Py_END_CRITICAL_SECTION();
return cdict;
}

Expand Down Expand Up @@ -276,28 +280,22 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
}
/* Reference a prepared dictionary.
It overrides some compression context's parameters. */
Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
Py_END_CRITICAL_SECTION();
}
else if (type == DICT_TYPE_UNDIGESTED) {
/* Load a dictionary.
It doesn't override compression context's parameters. */
Py_BEGIN_CRITICAL_SECTION2(self, zd);
zstd_ret = ZSTD_CCtx_loadDictionary(
self->cctx,
PyBytes_AS_STRING(zd->dict_content),
Py_SIZE(zd->dict_content));
Py_END_CRITICAL_SECTION2();
}
else if (type == DICT_TYPE_PREFIX) {
/* Load a prefix */
Py_BEGIN_CRITICAL_SECTION2(self, zd);
zstd_ret = ZSTD_CCtx_refPrefix(
self->cctx,
PyBytes_AS_STRING(zd->dict_content),
Py_SIZE(zd->dict_content));
Py_END_CRITICAL_SECTION2();
}
else {
Py_UNREACHABLE();
Expand Down Expand Up @@ -339,6 +337,7 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,

self->use_multithread = 0;
self->dict = NULL;
self->lock = (PyMutex){0};

/* Compression context */
self->cctx = ZSTD_createCCtx();
Expand Down Expand Up @@ -403,6 +402,8 @@ ZstdCompressor_dealloc(PyObject *ob)
ZSTD_freeCCtx(self->cctx);
}

assert(!PyMutex_IsLocked(&self->lock));

/* Py_XDECREF the dict after free the compression context */
Py_CLEAR(self->dict);

Expand All @@ -412,8 +413,8 @@ ZstdCompressor_dealloc(PyObject *ob)
}

static PyObject *
compress_impl(ZstdCompressor *self, Py_buffer *data,
ZSTD_EndDirective end_directive)
compress_lock_held(ZstdCompressor *self, Py_buffer *data,
ZSTD_EndDirective end_directive)
{
ZSTD_inBuffer in;
ZSTD_outBuffer out;
Expand Down Expand Up @@ -495,7 +496,7 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out)
#endif

static PyObject *
compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
{
ZSTD_inBuffer in;
ZSTD_outBuffer out;
Expand Down Expand Up @@ -529,7 +530,7 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
goto error;
}

/* Like compress_impl(), output as much as possible. */
/* Like compress_lock_held(), output as much as possible. */
if (out.pos == out.size) {
if (_OutputBuffer_Grow(&buffer, &out) < 0) {
goto error;
Expand Down Expand Up @@ -588,14 +589,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
}

/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);
PyMutex_Lock(&self->lock);

/* Compress */
if (self->use_multithread && mode == ZSTD_e_continue) {
ret = compress_mt_continue_impl(self, data);
ret = compress_mt_continue_lock_held(self, data);
}
else {
ret = compress_impl(self, data, mode);
ret = compress_lock_held(self, data, mode);
}

if (ret) {
Expand All @@ -607,7 +608,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
Py_END_CRITICAL_SECTION();
PyMutex_Unlock(&self->lock);

return ret;
}
Expand Down Expand Up @@ -642,8 +643,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
}

/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);
ret = compress_impl(self, NULL, mode);
PyMutex_Lock(&self->lock);

ret = compress_lock_held(self, NULL, mode);

if (ret) {
self->last_mode = mode;
Expand All @@ -654,7 +656,7 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
Py_END_CRITICAL_SECTION();
PyMutex_Unlock(&self->lock);

return ret;
}
Expand Down
Loading
Loading