From db63cea7fb2353c5771c580a32b4270d9977f86f Mon Sep 17 00:00:00 2001 From: Udit Kumar Agarwal Date: Fri, 8 Aug 2025 10:25:13 -0700 Subject: [PATCH] [SYCL] Don't use `zstd` context across threads. (#19747) `ZSTDCompressor` holds `zstd` context as its only data members. The idea behind `GetSingletonInstance()` method was to re-use these contexts for subsequent compression and decompressions. Re-using context across (de)compression reduces system memory usage. However, `zstd` contexts are not meant to be used concurrently, therefore, this PR makes `ZSTDCompressor` object thread-local, instead of static. Relevant excerpt from zstd doc (https://facebook.github.io/zstd/zstd_manual.html): > When decompressing many times, > it is recommended to allocate a context only once, > and re-use it for each successive compression operation. > This will make workload friendlier for system's memory. > Use one context per thread for parallel execution. --- sycl/source/detail/compression.hpp | 4 ++- .../compression/CompressionTests.cpp | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sycl/source/detail/compression.hpp b/sycl/source/detail/compression.hpp index 6362e6ec47884..b5185a45cfc17 100644 --- a/sycl/source/detail/compression.hpp +++ b/sycl/source/detail/compression.hpp @@ -33,7 +33,9 @@ class ZSTDCompressor { // Get the singleton instance of the ZSTDCompressor class. static ZSTDCompressor &GetSingletonInstance() { - static ZSTDCompressor instance; + // Use thread_local to ensure that each thread has its own instance. + // This avoids issues with concurrent access to the ZSTD contexts. + thread_local ZSTDCompressor instance; return instance; } diff --git a/sycl/unittests/compression/CompressionTests.cpp b/sycl/unittests/compression/CompressionTests.cpp index 2c30cace1b574..55e99dd3685fc 100644 --- a/sycl/unittests/compression/CompressionTests.cpp +++ b/sycl/unittests/compression/CompressionTests.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "../thread_safety/ThreadUtils.h" #include #include @@ -79,3 +80,36 @@ TEST(CompressionTest, EmptyInputTest) { std::string decompressedStr((char *)decompressedData.get(), decompressedSize); ASSERT_EQ(input, decompressedStr); } + +// Test to check for concurrent compression and decompression. +TEST(CompressionTest, ConcurrentCompressionDecompression) { + std::string data = "Concurrent compression and decompression test!"; + + constexpr size_t ThreadCount = 20; + + Barrier b(ThreadCount); + { + auto testCompressDecompress = [&](size_t threadId) { + b.wait(); + size_t compressedDataSize = 0; + auto compressedData = ZSTDCompressor::CompressBlob( + data.c_str(), data.size(), compressedDataSize, 3); + + ASSERT_NE(compressedData, nullptr); + ASSERT_GT(compressedDataSize, (size_t)0); + + size_t decompressedSize = 0; + auto decompressedData = ZSTDCompressor::DecompressBlob( + compressedData.get(), compressedDataSize, decompressedSize); + + ASSERT_NE(decompressedData, nullptr); + ASSERT_GT(decompressedSize, (size_t)0); + + std::string decompressedStr((char *)decompressedData.get(), + decompressedSize); + ASSERT_EQ(data, decompressedStr); + }; + + ::ThreadPool MPool(ThreadCount, testCompressDecompress); + } +}