Skip to content

Commit 52cb519

Browse files
committed
[SYCL] Fix race during image decompression
1 parent 7e292b9 commit 52cb519

File tree

3 files changed

+99
-11
lines changed

3 files changed

+99
-11
lines changed

sycl/source/detail/device_binary_image.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -710,22 +710,40 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
710710
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart));
711711
}
712712

713+
// Decompress the device binary image if it is compressed. This function is
714+
// thread-safe and will only decompress once even if called from multiple
715+
// threads.
713716
void CompressedRTDeviceBinaryImage::Decompress() {
717+
ImageState expected = ImageState::Compressed;
718+
ImageState desired = ImageState::DecompressionInProgress;
714719

715-
size_t CompressedDataSize =
716-
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
720+
// Decompress if not already done by another thread.
721+
if (DecompState.compare_exchange_strong(expected, desired)) {
722+
size_t CompressedDataSize =
723+
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
717724

718-
size_t DecompressedSize = 0;
719-
m_DecompressedData = ZSTDCompressor::DecompressBlob(
720-
reinterpret_cast<const char *>(Bin->BinaryStart), CompressedDataSize,
721-
DecompressedSize);
725+
size_t DecompressedSize = 0;
726+
m_DecompressedData = ZSTDCompressor::DecompressBlob(
727+
reinterpret_cast<const char *>(Bin->BinaryStart), CompressedDataSize,
728+
DecompressedSize);
722729

723-
Bin->BinaryStart =
724-
reinterpret_cast<const unsigned char *>(m_DecompressedData.get());
725-
Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;
730+
Bin->BinaryStart =
731+
reinterpret_cast<const unsigned char *>(m_DecompressedData.get());
732+
Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;
726733

727-
Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
728-
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
734+
Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
735+
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
736+
737+
DecompState.store(ImageState::Decompressed);
738+
} else {
739+
// Wait until the decompression is done by another thread.
740+
while (DecompState.load() == ImageState::DecompressionInProgress) {
741+
// Just spin.
742+
}
743+
}
744+
745+
assert(DecompState.load() == ImageState::Decompressed &&
746+
"Image should be decompressed by now");
729747
}
730748

731749
CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() {

sycl/source/detail/device_binary_image.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,14 @@ class DynRTDeviceBinaryImage : public RTDeviceBinaryImage {
309309
// actually used to build a program.
310310
// Also, frees the decompressed data in destructor.
311311
class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
312+
313+
// CompressedRTDeviceBinaryImage is in one of the following state.
314+
enum ImageState {
315+
Compressed = 0,
316+
DecompressionInProgress = 1,
317+
Decompressed = 2
318+
};
319+
312320
public:
313321
CompressedRTDeviceBinaryImage(sycl_device_binary Bin);
314322
~CompressedRTDeviceBinaryImage() override;
@@ -330,6 +338,9 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
330338
private:
331339
std::unique_ptr<char[]> m_DecompressedData;
332340
size_t m_ImageSize = 0;
341+
342+
// Atomic variable used to prevent race during image decompression.
343+
std::atomic<ImageState> DecompState{ImageState::Compressed};
333344
};
334345
#endif // SYCL_RT_ZSTD_AVAILABLE
335346

sycl/unittests/compression/CompressionTests.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "../thread_safety/ThreadUtils.h"
1010
#include <detail/compression.hpp>
11+
#include <detail/device_binary_image.hpp>
1112
#include <sycl/sycl.hpp>
1213

1314
#include <string>
@@ -113,3 +114,61 @@ TEST(CompressionTest, ConcurrentCompressionDecompression) {
113114
::ThreadPool MPool(ThreadCount, testCompressDecompress);
114115
}
115116
}
117+
118+
// Test to decompress CompressedRTDeviceImage using multiple threads.
119+
// The idea behind this test is to ensure that a device image is
120+
// decompressed only once even if multiple threads try to decompress
121+
// it at the same time.
122+
TEST(CompressionTest, ConcurrentDecompressionOfDeviceImage) {
123+
// Data to compress.
124+
std::string data = "Hello World! I'm about to get compressed :P";
125+
126+
// Compress this data.
127+
size_t compressedSize = 0;
128+
auto compressedData = ZSTDCompressor::CompressBlob(data.c_str(), data.size(),
129+
compressedSize, 1);
130+
131+
unsigned char *compressedDataPtr =
132+
reinterpret_cast<unsigned char *>(compressedData.get());
133+
134+
const char *EntryName = "Entry";
135+
_sycl_offload_entry_struct EntryStruct = {
136+
/*addr*/ nullptr, const_cast<char *>(EntryName), strlen(EntryName),
137+
/*flags*/ 0, /*reserved*/ 0};
138+
sycl_device_binary_struct BinStruct{/*Version*/ 1,
139+
/*Kind*/ 4,
140+
/*Format*/ SYCL_DEVICE_BINARY_TYPE_SPIRV,
141+
/*DeviceTargetSpec*/ nullptr,
142+
/*CompileOptions*/ nullptr,
143+
/*LinkOptions*/ nullptr,
144+
/*ManifestStart*/ nullptr,
145+
/*ManifestEnd*/ nullptr,
146+
/*BinaryStart*/ compressedDataPtr,
147+
/*BinaryEnd*/ compressedDataPtr +
148+
compressedSize,
149+
/*EntriesBegin*/ &EntryStruct,
150+
/*EntriesEnd*/ &EntryStruct + 1,
151+
/*PropertySetsBegin*/ nullptr,
152+
/*PropertySetsEnd*/ nullptr};
153+
sycl_device_binary Bin = &BinStruct;
154+
CompressedRTDeviceBinaryImage Img{Bin};
155+
156+
// Decompress the image with multiple threads.
157+
constexpr size_t ThreadCount = 20;
158+
Barrier b(ThreadCount);
159+
{
160+
auto testDecompress = [&](size_t threadId) {
161+
b.wait();
162+
Img.Decompress();
163+
164+
// Check if decompressed data is same as original data.
165+
// Img.getRawData will change if there's a race in image decompression
166+
// and the check will fail.
167+
for (size_t i = 0; i < Img.getSize(); ++i) {
168+
ASSERT_EQ(data[i], Img.getRawData().BinaryStart[i]);
169+
}
170+
};
171+
172+
::ThreadPool MPool(ThreadCount, testDecompress);
173+
}
174+
}

0 commit comments

Comments
 (0)