diff --git a/sycl/source/detail/device_binary_image.cpp b/sycl/source/detail/device_binary_image.cpp index 0b58e9c5a7279..f339b7280fbce 100644 --- a/sycl/source/detail/device_binary_image.cpp +++ b/sycl/source/detail/device_binary_image.cpp @@ -710,22 +710,29 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage( static_cast(Bin->BinaryEnd - Bin->BinaryStart)); } +// std::call_once makes this function thread-safe: the image is decompressed +// only once, even when several threads call Decompress() concurrently. void CompressedRTDeviceBinaryImage::Decompress() { + auto DecompressFunc = [&]() { + size_t CompressedDataSize = + static_cast(Bin->BinaryEnd - Bin->BinaryStart); - size_t CompressedDataSize = - static_cast(Bin->BinaryEnd - Bin->BinaryStart); + size_t DecompressedSize = 0; + m_DecompressedData = ZSTDCompressor::DecompressBlob( + reinterpret_cast(Bin->BinaryStart), CompressedDataSize, + DecompressedSize); - size_t DecompressedSize = 0; - m_DecompressedData = ZSTDCompressor::DecompressBlob( - reinterpret_cast(Bin->BinaryStart), CompressedDataSize, - DecompressedSize); + Bin->BinaryStart = + reinterpret_cast(m_DecompressedData.get()); + Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize; - Bin->BinaryStart = - reinterpret_cast(m_DecompressedData.get()); - Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize; + Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize()); + Format = static_cast(Bin->Format); - Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize()); - Format = static_cast(Bin->Format); + m_IsCompressed.store(false); + }; + + std::call_once(m_InitFlag, DecompressFunc); } CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() { diff --git a/sycl/source/detail/device_binary_image.hpp b/sycl/source/detail/device_binary_image.hpp index ac4fb92d3f9a8..075229effb3ec 100644 --- a/sycl/source/detail/device_binary_image.hpp +++ b/sycl/source/detail/device_binary_image.hpp @@ -19,6 +19,7 @@ #include #include #include +#include namespace sycl { inline 
namespace _V1 { @@ -321,7 +322,8 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage { return m_ImageSize; } - bool IsCompressed() const { return m_DecompressedData.get() == nullptr; } + bool IsCompressed() const { return m_IsCompressed.load(); } + void print() const override { RTDeviceBinaryImage::print(); std::cerr << " COMPRESSED\n"; @@ -330,6 +332,10 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage { private: std::unique_ptr m_DecompressedData; size_t m_ImageSize = 0; + + // Passed to std::call_once so decompression happens only once. + std::once_flag m_InitFlag; + std::atomic m_IsCompressed{true}; }; #endif // SYCL_RT_ZSTD_AVAILABLE diff --git a/sycl/unittests/compression/CompressionTests.cpp b/sycl/unittests/compression/CompressionTests.cpp index 55e99dd3685fc..9ce8dda9bc830 100644 --- a/sycl/unittests/compression/CompressionTests.cpp +++ b/sycl/unittests/compression/CompressionTests.cpp @@ -8,6 +8,7 @@ #include "../thread_safety/ThreadUtils.h" #include +#include #include #include @@ -113,3 +114,61 @@ TEST(CompressionTest, ConcurrentCompressionDecompression) { ::ThreadPool MPool(ThreadCount, testCompressDecompress); } } + +// Test to decompress CompressedRTDeviceBinaryImage using multiple threads. +// The idea behind this test is to ensure that a device image is +// decompressed only once even if multiple threads try to decompress +// it at the same time. +TEST(CompressionTest, ConcurrentDecompressionOfDeviceImage) { + // Data to compress. + std::string data = "Hello World! I'm about to get compressed :P"; + + // Compress this data. 
+ size_t compressedSize = 0; + auto compressedData = ZSTDCompressor::CompressBlob(data.c_str(), data.size(), + compressedSize, 1); + + unsigned char *compressedDataPtr = + reinterpret_cast(compressedData.get()); + + const char *EntryName = "Entry"; + _sycl_offload_entry_struct EntryStruct = { + /*addr*/ nullptr, const_cast(EntryName), strlen(EntryName), + /*flags*/ 0, /*reserved*/ 0}; + sycl_device_binary_struct BinStruct{/*Version*/ 1, + /*Kind*/ 4, + /*Format*/ SYCL_DEVICE_BINARY_TYPE_SPIRV, + /*DeviceTargetSpec*/ nullptr, + /*CompileOptions*/ nullptr, + /*LinkOptions*/ nullptr, + /*ManifestStart*/ nullptr, + /*ManifestEnd*/ nullptr, + /*BinaryStart*/ compressedDataPtr, + /*BinaryEnd*/ compressedDataPtr + + compressedSize, + /*EntriesBegin*/ &EntryStruct, + /*EntriesEnd*/ &EntryStruct + 1, + /*PropertySetsBegin*/ nullptr, + /*PropertySetsEnd*/ nullptr}; + sycl_device_binary Bin = &BinStruct; + CompressedRTDeviceBinaryImage Img{Bin}; + + // Decompress the image with multiple threads. + constexpr size_t ThreadCount = 20; + Barrier b(ThreadCount); + { + auto testDecompress = [&](size_t threadId) { + b.wait(); + Img.Decompress(); + + // Check that the decompressed data matches the original data. + // Img.getRawData() will change if there's a race in image decompression + // and the check will fail. + for (size_t i = 0; i < Img.getSize(); ++i) { + ASSERT_EQ(data[i], Img.getRawData().BinaryStart[i]); + } + }; + + ::ThreadPool MPool(ThreadCount, testDecompress); + } +}