diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 53ca592ea38828..a510b7a3d5d552 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -2430,83 +2430,54 @@ def test_buffer_protocol(self): self.assertEqual(f.write(arr), LENGTH) self.assertEqual(f.tell(), LENGTH) -@unittest.skip("it fails for now, see gh-133885") + class FreeThreadingMethodTests(unittest.TestCase): @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() - def test_compress_locking(self): + def test_compressor_cannot_share(self): input = b'a'* (16*_1K) num_threads = 8 - comp = ZstdCompressor() - parts = [] - for _ in range(num_threads): - res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK) - if res: - parts.append(res) - rest1 = comp.flush() - expected = b''.join(parts) + rest1 comp = ZstdCompressor() - output = [] - def run_method(method, input_data, output_data): - res = method(input_data, ZstdCompressor.FLUSH_BLOCK) - if res: - output_data.append(res) + def run_method(method, input_data): + with self.assertRaises(RuntimeError): + method(input_data, ZstdCompressor.FLUSH_BLOCK) threads = [] for i in range(num_threads): - thread = threading.Thread(target=run_method, args=(comp.compress, input, output)) + thread = threading.Thread(target=run_method, args=(comp.compress, input)) threads.append(thread) with threading_helper.start_threads(threads): pass - rest2 = comp.flush() - self.assertEqual(rest1, rest2) - actual = b''.join(output) + rest2 - self.assertEqual(expected, actual) - @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() - def test_decompress_locking(self): + def test_decompressor_cannot_share(self): input = compress(b'a'* (16*_1K)) num_threads = 8 # to ensure we decompress over multiple calls, set maxsize window_size = _1K * 16//num_threads - decomp = ZstdDecompressor() - parts = [] - for _ in range(num_threads): - res = decomp.decompress(input, window_size) - if res: - parts.append(res) - expected = b''.join(parts) - comp = ZstdDecompressor() - output = [] - def run_method(method, input_data, output_data): - res = method(input_data, window_size) - if res: - output_data.append(res) + def run_method(method, input_data): + with self.assertRaises(RuntimeError): + method(input_data, window_size) threads = [] for i in range(num_threads): - thread = threading.Thread(target=run_method, args=(comp.decompress, input, output)) + thread = threading.Thread(target=run_method, args=(comp.decompress, input)) threads.append(thread) with threading_helper.start_threads(threads): pass - actual = b''.join(output) - self.assertEqual(expected, actual) - - if __name__ == "__main__": unittest.main() diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index b36486442c6567..31fef0ec3d6022 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -52,4 +52,21 @@ extern void set_parameter_error(const _zstd_state* const state, int is_compress, int key_v, int value_v); +static inline int +check_object_shared(PyObject *ob, char *type) +{ +#if defined(Py_GIL_DISABLED) + if (!_Py_IsOwnedByCurrentThread(ob)) + { + PyErr_Format(PyExc_RuntimeError, + "%s cannot be shared across multiple threads.", + type); + return 1; + } + return 0; +#else + return 0; +#endif +} + #endif // !ZSTD_MODULE_H diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 38baee2be1e95b..6a788d6f6d5c62 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -575,6 +575,12 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, { PyObject *ret; + /* Check we are on the same thread as the compressor was created */ + if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0) + { + return NULL; + } + /* Check mode value */ if (mode != ZSTD_e_continue && mode != ZSTD_e_flush && @@ -587,9 +593,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, return NULL; } - /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - /* Compress */ if (self->use_multithread && mode == ZSTD_e_continue) { ret = compress_mt_continue_impl(self, data); @@ -607,7 +610,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); return ret; } @@ -632,6 +634,12 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) { PyObject *ret; + /* Check we are on the same thread as the compressor was created */ + if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0) + { + return NULL; + } + /* Check mode value */ if (mode != ZSTD_e_end && mode != ZSTD_e_flush) { PyErr_SetString(PyExc_ValueError, @@ -641,8 +649,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) return NULL; } - /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); ret = compress_impl(self, NULL, mode); if (ret) { @@ -654,7 +660,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); return ret; } diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 58f9c9f804e549..645e3fb954b777 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -639,6 +639,12 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self) { PyObject *ret; + /* Check we are on the same thread as the decompressor was created */ + if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0) + { + return NULL; + } + if (!self->eof) { return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES); } @@ -692,11 +698,12 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self, /*[clinic end generated code: output=a4302b3c940dbec6 input=6463dfdf98091caa]*/ { PyObject *ret; - /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - + /* Check we are on the same thread as the decompressor was created */ + if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0) + { + return NULL; + } ret = stream_decompress(self, data, max_length); - Py_END_CRITICAL_SECTION(); return ret; }