@@ -710,15 +710,10 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
710710 static_cast <size_t >(Bin->BinaryEnd - Bin->BinaryStart ));
711711}
712712
713- // Decompress the device binary image if it is compressed. This function is
714- // thread-safe and will only decompress once even if called from multiple
715- // threads.
713+ // std::call_once ensures that this function is thread_safe and prevents
714+ // race during image decompression.
716715void CompressedRTDeviceBinaryImage::Decompress () {
717- ImageState expected = ImageState::Compressed;
718- ImageState desired = ImageState::DecompressionInProgress;
719-
720- // Decompress if not already done by another thread.
721- if (DecompState.compare_exchange_strong (expected, desired)) {
716+ auto DecompressFunc = [&]() {
722717 size_t CompressedDataSize =
723718 static_cast <size_t >(Bin->BinaryEnd - Bin->BinaryStart );
724719
@@ -733,17 +728,9 @@ void CompressedRTDeviceBinaryImage::Decompress() {
733728
734729 Bin->Format = ur::getBinaryImageFormat (Bin->BinaryStart , getSize ());
735730 Format = static_cast <ur::DeviceBinaryType>(Bin->Format );
731+ };
736732
737- DecompState.store (ImageState::Decompressed);
738- } else {
739- // Wait until the decompression is done by another thread.
740- while (DecompState.load () == ImageState::DecompressionInProgress) {
741- // Just spin.
742- }
743- }
744-
745- assert (DecompState.load () == ImageState::Decompressed &&
746- " Image should be decompressed by now" );
733+ std::call_once (InitFlag, DecompressFunc);
747734}
748735
749736CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage () {
0 commit comments