Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions sycl/source/detail/device_binary_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,22 +710,29 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart));
}

// Decompresses the device image in place, exactly once.
//
// The work is wrapped in std::call_once on m_InitFlag, which makes this
// function thread-safe: when several threads call Decompress() concurrently,
// only one of them runs the decompression and the others block until it has
// finished. This prevents a race during image decompression. Per the
// std::call_once contract, should the callable exit via an exception the flag
// stays unset and a later call will try again.
void CompressedRTDeviceBinaryImage::Decompress() {
  auto DecompressFunc = [&]() {
    // Until decompression happens, [BinaryStart, BinaryEnd) spans the
    // compressed blob.
    size_t CompressedDataSize =
        static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);

    size_t DecompressedSize = 0;
    m_DecompressedData = ZSTDCompressor::DecompressBlob(
        reinterpret_cast<const char *>(Bin->BinaryStart), CompressedDataSize,
        DecompressedSize);

    // Re-point the binary descriptor at the decompressed buffer, which is
    // owned by m_DecompressedData.
    Bin->BinaryStart =
        reinterpret_cast<const unsigned char *>(m_DecompressedData.get());
    Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;

    // Re-detect the image format from the decompressed bytes.
    Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
    Format = static_cast<ur::DeviceBinaryType>(Bin->Format);

    // Publish the new state last, after Bin has been fully updated.
    m_IsCompressed.store(false);
  };

  std::call_once(m_InitFlag, DecompressFunc);
}

CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() {
Expand Down
8 changes: 7 additions & 1 deletion sycl/source/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <atomic>
#include <cstring>
#include <memory>
#include <mutex>

namespace sycl {
inline namespace _V1 {
Expand Down Expand Up @@ -321,7 +322,8 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
return m_ImageSize;
}

// Returns true until Decompress() has completed. Reads the atomic
// m_IsCompressed flag rather than inspecting m_DecompressedData, so the query
// does not race with a concurrent Decompress() writing that pointer.
bool IsCompressed() const { return m_IsCompressed.load(); }

void print() const override {
RTDeviceBinaryImage::print();
std::cerr << " COMPRESSED\n";
Expand All @@ -330,6 +332,10 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
private:
std::unique_ptr<char[]> m_DecompressedData;
size_t m_ImageSize = 0;

// Flag to ensure decompression happens only once.
std::once_flag m_InitFlag;
std::atomic<bool> m_IsCompressed{true};
};
#endif // SYCL_RT_ZSTD_AVAILABLE

Expand Down
59 changes: 59 additions & 0 deletions sycl/unittests/compression/CompressionTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "../thread_safety/ThreadUtils.h"
#include <detail/compression.hpp>
#include <detail/device_binary_image.hpp>
#include <sycl/sycl.hpp>

#include <string>
Expand Down Expand Up @@ -113,3 +114,61 @@ TEST(CompressionTest, ConcurrentCompressionDecompression) {
::ThreadPool MPool(ThreadCount, testCompressDecompress);
}
}

// Test to decompress a CompressedRTDeviceBinaryImage using multiple threads.
// The idea behind this test is to ensure that a device image is
// decompressed only once even if multiple threads try to decompress
// it at the same time.
TEST(CompressionTest, ConcurrentDecompressionOfDeviceImage) {
  // Data to compress.
  std::string data = "Hello World! I'm about to get compressed :P";

  // Compress this data.
  size_t compressedSize = 0;
  auto compressedData = ZSTDCompressor::CompressBlob(data.c_str(), data.size(),
                                                     compressedSize, 1);

  unsigned char *compressedDataPtr =
      reinterpret_cast<unsigned char *>(compressedData.get());

  // Build a minimal device binary descriptor whose payload is the compressed
  // blob created above.
  const char *EntryName = "Entry";
  _sycl_offload_entry_struct EntryStruct = {
      /*addr*/ nullptr, const_cast<char *>(EntryName), strlen(EntryName),
      /*flags*/ 0, /*reserved*/ 0};
  sycl_device_binary_struct BinStruct{/*Version*/ 1,
                                      /*Kind*/ 4,
                                      /*Format*/ SYCL_DEVICE_BINARY_TYPE_SPIRV,
                                      /*DeviceTargetSpec*/ nullptr,
                                      /*CompileOptions*/ nullptr,
                                      /*LinkOptions*/ nullptr,
                                      /*ManifestStart*/ nullptr,
                                      /*ManifestEnd*/ nullptr,
                                      /*BinaryStart*/ compressedDataPtr,
                                      /*BinaryEnd*/ compressedDataPtr +
                                          compressedSize,
                                      /*EntriesBegin*/ &EntryStruct,
                                      /*EntriesEnd*/ &EntryStruct + 1,
                                      /*PropertySetsBegin*/ nullptr,
                                      /*PropertySetsEnd*/ nullptr};
  sycl_device_binary Bin = &BinStruct;
  CompressedRTDeviceBinaryImage Img{Bin};

  // Decompress the image with multiple threads.
  constexpr size_t ThreadCount = 20;
  Barrier b(ThreadCount);
  {
    auto testDecompress = [&](size_t threadId) {
      // Line all threads up so they call Decompress() at the same time.
      b.wait();
      Img.Decompress();

      // The image must report exactly as many bytes as were compressed;
      // otherwise the byte-wise loop below could pass vacuously (size 0) or
      // index past the end of 'data' (size too large).
      ASSERT_EQ(data.size(), Img.getSize());

      // Check if decompressed data is same as original data.
      // Img.getRawData will change if there's a race in image decompression
      // and the check will fail.
      for (size_t i = 0; i < Img.getSize(); ++i) {
        ASSERT_EQ(data[i], Img.getRawData().BinaryStart[i]);
      }
    };

    ::ThreadPool MPool(ThreadCount, testDecompress);
  }
}
Loading