diff --git a/sycl/source/detail/compression.hpp b/sycl/source/detail/compression.hpp index 6362e6ec47884..b5185a45cfc17 100644 --- a/sycl/source/detail/compression.hpp +++ b/sycl/source/detail/compression.hpp @@ -33,7 +33,9 @@ class ZSTDCompressor { // Get the singleton instance of the ZSTDCompressor class. static ZSTDCompressor &GetSingletonInstance() { - static ZSTDCompressor instance; + // Use thread_local to ensure that each thread has its own instance. + // This avoids issues with concurrent access to the ZSTD contexts. + thread_local ZSTDCompressor instance; return instance; } diff --git a/sycl/unittests/compression/CompressionTests.cpp b/sycl/unittests/compression/CompressionTests.cpp index 2c30cace1b574..55e99dd3685fc 100644 --- a/sycl/unittests/compression/CompressionTests.cpp +++ b/sycl/unittests/compression/CompressionTests.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "../thread_safety/ThreadUtils.h" #include #include @@ -79,3 +80,36 @@ TEST(CompressionTest, EmptyInputTest) { std::string decompressedStr((char *)decompressedData.get(), decompressedSize); ASSERT_EQ(input, decompressedStr); } + +// Test to check for concurrent compression and decompression. +TEST(CompressionTest, ConcurrentCompressionDecompression) { + std::string data = "Concurrent compression and decompression test!"; + + constexpr size_t ThreadCount = 20; + + Barrier b(ThreadCount); + { + auto testCompressDecompress = [&](size_t threadId) { + b.wait(); + size_t compressedDataSize = 0; + auto compressedData = ZSTDCompressor::CompressBlob( + data.c_str(), data.size(), compressedDataSize, 3); + + ASSERT_NE(compressedData, nullptr); + ASSERT_GT(compressedDataSize, (size_t)0); + + size_t decompressedSize = 0; + auto decompressedData = ZSTDCompressor::DecompressBlob( + compressedData.get(), compressedDataSize, decompressedSize); + + ASSERT_NE(decompressedData, nullptr); + ASSERT_GT(decompressedSize, (size_t)0); + + std::string decompressedStr((char *)decompressedData.get(), + decompressedSize); + ASSERT_EQ(data, decompressedStr); + }; + + ::ThreadPool MPool(ThreadCount, testCompressDecompress); + } +}