Skip to content

Commit 1c3071d

Browse files
committed
Disallow sharing zstd (de)compressor contexts
According to the zstd author, it is not possible to share Zstandard objects across thread boundaries. To resolve this, we check if the object was created on the current thread and raise a RuntimeError if it is not. The tests are updated to ensure that the error is raised if a (de)compression context is shared across threads.
1 parent de70614 commit 1c3071d

File tree

4 files changed

+51
-51
lines changed

4 files changed

+51
-51
lines changed

Lib/test/test_zstd.py

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,83 +2430,54 @@ def test_buffer_protocol(self):
24302430
self.assertEqual(f.write(arr), LENGTH)
24312431
self.assertEqual(f.tell(), LENGTH)
24322432

2433-
@unittest.skip("it fails for now, see gh-133885")
2433+
24342434
class FreeThreadingMethodTests(unittest.TestCase):
24352435

24362436
@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
24372437
@threading_helper.reap_threads
24382438
@threading_helper.requires_working_threading()
2439-
def test_compress_locking(self):
2439+
def test_compressor_cannot_share(self):
24402440
input = b'a'* (16*_1K)
24412441
num_threads = 8
24422442

2443-
comp = ZstdCompressor()
2444-
parts = []
2445-
for _ in range(num_threads):
2446-
res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK)
2447-
if res:
2448-
parts.append(res)
2449-
rest1 = comp.flush()
2450-
expected = b''.join(parts) + rest1
24512443

24522444
comp = ZstdCompressor()
2453-
output = []
2454-
def run_method(method, input_data, output_data):
2455-
res = method(input_data, ZstdCompressor.FLUSH_BLOCK)
2456-
if res:
2457-
output_data.append(res)
2445+
def run_method(method, input_data):
2446+
with self.assertRaises(RuntimeError):
2447+
method(input_data, ZstdCompressor.FLUSH_BLOCK)
24582448
threads = []
24592449

24602450
for i in range(num_threads):
2461-
thread = threading.Thread(target=run_method, args=(comp.compress, input, output))
2451+
thread = threading.Thread(target=run_method, args=(comp.compress, input))
24622452

24632453
threads.append(thread)
24642454

24652455
with threading_helper.start_threads(threads):
24662456
pass
24672457

2468-
rest2 = comp.flush()
2469-
self.assertEqual(rest1, rest2)
2470-
actual = b''.join(output) + rest2
2471-
self.assertEqual(expected, actual)
2472-
24732458
@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
24742459
@threading_helper.reap_threads
24752460
@threading_helper.requires_working_threading()
2476-
def test_decompress_locking(self):
2461+
def test_decompressor_cannot_share(self):
24772462
input = compress(b'a'* (16*_1K))
24782463
num_threads = 8
24792464
# to ensure we decompress over multiple calls, set maxsize
24802465
window_size = _1K * 16//num_threads
24812466

2482-
decomp = ZstdDecompressor()
2483-
parts = []
2484-
for _ in range(num_threads):
2485-
res = decomp.decompress(input, window_size)
2486-
if res:
2487-
parts.append(res)
2488-
expected = b''.join(parts)
2489-
24902467
comp = ZstdDecompressor()
2491-
output = []
2492-
def run_method(method, input_data, output_data):
2493-
res = method(input_data, window_size)
2494-
if res:
2495-
output_data.append(res)
2468+
def run_method(method, input_data):
2469+
with self.assertRaises(RuntimeError):
2470+
method(input_data, window_size)
24962471
threads = []
24972472

24982473
for i in range(num_threads):
2499-
thread = threading.Thread(target=run_method, args=(comp.decompress, input, output))
2474+
thread = threading.Thread(target=run_method, args=(comp.decompress, input))
25002475

25012476
threads.append(thread)
25022477

25032478
with threading_helper.start_threads(threads):
25042479
pass
25052480

2506-
actual = b''.join(output)
2507-
self.assertEqual(expected, actual)
2508-
2509-
25102481

25112482
if __name__ == "__main__":
25122483
unittest.main()

Modules/_zstd/_zstdmodule.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,21 @@ extern void
5252
set_parameter_error(const _zstd_state* const state, int is_compress,
5353
int key_v, int value_v);
5454

55+
static inline int
56+
check_object_shared(PyObject *ob, char *type)
57+
{
58+
#if defined(Py_GIL_DISABLED)
59+
if (!_Py_IsOwnedByCurrentThread(ob))
60+
{
61+
PyErr_Format(PyExc_RuntimeError,
62+
"%s cannot be shared across multiple threads.",
63+
type);
64+
return 1;
65+
}
66+
return 0;
67+
#else
68+
return 0;
69+
#endif
70+
}
71+
5572
#endif // !ZSTD_MODULE_H

Modules/_zstd/compressor.c

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,12 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
575575
{
576576
PyObject *ret;
577577

578+
/* Check we are on the same thread as the compressor was created */
579+
if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0)
580+
{
581+
return NULL;
582+
}
583+
578584
/* Check mode value */
579585
if (mode != ZSTD_e_continue &&
580586
mode != ZSTD_e_flush &&
@@ -587,9 +593,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
587593
return NULL;
588594
}
589595

590-
/* Thread-safe code */
591-
Py_BEGIN_CRITICAL_SECTION(self);
592-
593596
/* Compress */
594597
if (self->use_multithread && mode == ZSTD_e_continue) {
595598
ret = compress_mt_continue_impl(self, data);
@@ -607,7 +610,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
607610
/* Resetting cctx's session never fail */
608611
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
609612
}
610-
Py_END_CRITICAL_SECTION();
611613

612614
return ret;
613615
}
@@ -632,6 +634,12 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
632634
{
633635
PyObject *ret;
634636

637+
/* Check we are on the same thread as the compressor was created */
638+
if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0)
639+
{
640+
return NULL;
641+
}
642+
635643
/* Check mode value */
636644
if (mode != ZSTD_e_end && mode != ZSTD_e_flush) {
637645
PyErr_SetString(PyExc_ValueError,
@@ -641,8 +649,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
641649
return NULL;
642650
}
643651

644-
/* Thread-safe code */
645-
Py_BEGIN_CRITICAL_SECTION(self);
646652
ret = compress_impl(self, NULL, mode);
647653

648654
if (ret) {
@@ -654,7 +660,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
654660
/* Resetting cctx's session never fail */
655661
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
656662
}
657-
Py_END_CRITICAL_SECTION();
658663

659664
return ret;
660665
}

Modules/_zstd/decompressor.c

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,12 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
639639
{
640640
PyObject *ret;
641641

642+
/* Check we are on the same thread as the decompressor was created */
643+
if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0)
644+
{
645+
return NULL;
646+
}
647+
642648
if (!self->eof) {
643649
return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
644650
}
@@ -692,11 +698,12 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
692698
/*[clinic end generated code: output=a4302b3c940dbec6 input=6463dfdf98091caa]*/
693699
{
694700
PyObject *ret;
695-
/* Thread-safe code */
696-
Py_BEGIN_CRITICAL_SECTION(self);
697-
701+
/* Check we are on the same thread as the decompressor was created */
702+
if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0)
703+
{
704+
return NULL;
705+
}
698706
ret = stream_decompress(self, data, max_length);
699-
Py_END_CRITICAL_SECTION();
700707
return ret;
701708
}
702709

0 commit comments

Comments
 (0)