Skip to content

Commit 0d8ecf3

Browse files
committed
use std::call_once
1 parent 3d69ee5 commit 0d8ecf3

File tree

2 files changed

+8
-28
lines changed

2 files changed

+8
-28
lines changed

sycl/source/detail/device_binary_image.cpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -710,15 +710,10 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
710710
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart));
711711
}
712712

713-
// Decompress the device binary image if it is compressed. This function is
714-
// thread-safe and will only decompress once even if called from multiple
715-
// threads.
713+
// std::call_once ensures that this function is thread_safe and prevents
714+
// race during image decompression.
716715
void CompressedRTDeviceBinaryImage::Decompress() {
717-
ImageState expected = ImageState::Compressed;
718-
ImageState desired = ImageState::DecompressionInProgress;
719-
720-
// Decompress if not already done by another thread.
721-
if (DecompState.compare_exchange_strong(expected, desired)) {
716+
auto DecompressFunc = [&]() {
722717
size_t CompressedDataSize =
723718
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
724719

@@ -733,17 +728,9 @@ void CompressedRTDeviceBinaryImage::Decompress() {
733728

734729
Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
735730
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
731+
};
736732

737-
DecompState.store(ImageState::Decompressed);
738-
} else {
739-
// Wait until the decompression is done by another thread.
740-
while (DecompState.load() == ImageState::DecompressionInProgress) {
741-
// Just spin.
742-
}
743-
}
744-
745-
assert(DecompState.load() == ImageState::Decompressed &&
746-
"Image should be decompressed by now");
733+
std::call_once(InitFlag, DecompressFunc);
747734
}
748735

749736
CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() {

sycl/source/detail/device_binary_image.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <atomic>
2020
#include <cstring>
2121
#include <memory>
22+
#include <mutex>
2223

2324
namespace sycl {
2425
inline namespace _V1 {
@@ -309,14 +310,6 @@ class DynRTDeviceBinaryImage : public RTDeviceBinaryImage {
309310
// actually used to build a program.
310311
// Also, frees the decompressed data in destructor.
311312
class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
312-
313-
// CompressedRTDeviceBinaryImage is in one of the following state.
314-
enum ImageState {
315-
Compressed = 0,
316-
DecompressionInProgress = 1,
317-
Decompressed = 2
318-
};
319-
320313
public:
321314
CompressedRTDeviceBinaryImage(sycl_device_binary Bin);
322315
~CompressedRTDeviceBinaryImage() override;
@@ -339,8 +332,8 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
339332
std::unique_ptr<char[]> m_DecompressedData;
340333
size_t m_ImageSize = 0;
341334

342-
// Atomic variable used to prevent race during image decompression.
343-
std::atomic<ImageState> DecompState{ImageState::Compressed};
335+
// Flag to ensure decompression happens only once.
336+
std::once_flag InitFlag;
344337
};
345338
#endif // SYCL_RT_ZSTD_AVAILABLE
346339

0 commit comments

Comments
 (0)