From b68577690830857e89da65c4e2d5a5471e0ed67f Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 23 Oct 2024 11:11:39 +0000
Subject: [PATCH 1/4] Use std atomic for 'interpreter.cc'

Signed-off-by: Anatoly Myachev
---
 python/src/interpreter.cc | 134 +++++++++++++++++++++-----------------
 1 file changed, 76 insertions(+), 58 deletions(-)

diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc
index 6ab7c6c75c..458e5127cf 100644
--- a/python/src/interpreter.cc
+++ b/python/src/interpreter.cc
@@ -1,3 +1,4 @@
+#include <atomic>
 #include
 #include
 #include
@@ -14,16 +15,16 @@ enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };
 enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
 
 std::map<MemSemantic, int> mem_semantic_map = {
-    {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL},
-    {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE},
-    {MemSemantic::RELEASE, __ATOMIC_RELEASE},
-    {MemSemantic::RELAXED, __ATOMIC_RELAXED},
+    {MemSemantic::ACQUIRE_RELEASE, static_cast<int>(std::memory_order_acq_rel)},
+    {MemSemantic::ACQUIRE, static_cast<int>(std::memory_order_acquire)},
+    {MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
+    {MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
 };
-
 // Use compiler builtin atomics instead of std::atomic which requires
 // each variable to be declared as atomic.
 // Currently work for clang and gcc.
-template <bool is_min, typename T> T atomic_cmp(T *ptr, T val, int order) {
+template <bool is_min, typename T>
+T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
   auto cmp = [](T old, T val) {
     if constexpr (is_min) {
       return old > val;
@@ -31,38 +32,29 @@ template <bool is_min, typename T> T atomic_cmp(T *ptr, T val, int order) {
       return old < val;
     }
   };
+
   // First load
-  T old_val = __atomic_load_n(ptr, order);
+  T old_val = ptr->load(order);
   while (cmp(old_val, val)) {
-    if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) {
+    if (ptr->compare_exchange_weak(old_val, val, order, order)) {
       break;
     }
   }
   return old_val;
 }
 
-template <typename T> T atomic_fadd(T *ptr, T val, int order) {
-  T old_val;
-  T new_val;
-  // First load
-  // Load ptr as if uint32_t or uint64_t and then memcpy to T
-  if constexpr (sizeof(T) == 4) {
-    uint32_t tmp = __atomic_load_n(reinterpret_cast<uint32_t *>(ptr), order);
-    std::memcpy(&old_val, &tmp, sizeof(T));
-  } else if constexpr (sizeof(T) == 8) {
-    uint64_t tmp = __atomic_load_n(reinterpret_cast<uint64_t *>(ptr), order);
-    std::memcpy(&old_val, &tmp, sizeof(T));
-  } else {
-    throw std::invalid_argument("Unsupported data type");
-  }
-  while (true) {
-    new_val = old_val + val;
-    if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order,
-                                  order)) {
-      break;
-    }
-  }
-  return old_val;
+template <typename T>
+T atomic_fadd(std::atomic<T> *loc, T value, std::memory_order order) {
+  static_assert(std::is_floating_point<T>::value,
+                "T must be a floating-point type");
+
+  T old_value = loc->load(order);
+  T new_value;
+  do {
+    new_value = old_value + value;
+  } while (!loc->compare_exchange_weak(old_value, new_value, order, order));
+
+  return old_value;
 }
 
 class AtomicOp {
@@ -95,13 +87,15 @@ template <typename DType> class AtomicRMWOpBase : public AtomicOp {
 protected:
   void applyAt(void *loc, size_t i) override final {
     if (mask[i]) {
+      std::atomic<DType> *atomic_ptr = static_cast<std::atomic<DType> *>(loc);
       *(static_cast<DType *>(ret) + i) =
-          applyAtMasked(static_cast<DType *>(loc),
-                        *(static_cast<const DType *>(val) + i), order);
+          applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
+                        std::memory_order(order));
     }
   }
 
-  virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0;
+  virtual DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                              std::memory_order order) = 0;
 
   const void *val;
   void *ret;
@@ -121,8 +115,9 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
-    return __atomic_fetch_add(loc, value, order);
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
+    return std::atomic_fetch_add(loc, value);
   }
 };
 
@@ -133,7 +128,9 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
+
     return atomic_fadd(loc, value, order);
   }
 };
 
@@ -145,8 +142,9 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
-    return __atomic_fetch_and(loc, value, order);
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
+    return std::atomic_fetch_and(loc, value);
   }
 };
 
@@ -157,8 +155,9 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
-    return __atomic_fetch_or(loc, value, order);
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
+    return std::atomic_fetch_or(loc, value);
  }
 };
 
@@ -169,8 +168,9 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
-    return __atomic_fetch_xor(loc, value, order);
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
+    return std::atomic_fetch_xor(loc, value);
   }
 };
 
@@ -182,7 +182,8 @@ class AtomicRMWOp
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
     return atomic_cmp(loc, value, order);
   }
 };
 
@@ -195,7 +196,8 @@ class AtomicRMWOp
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
     return atomic_cmp(loc, value, order);
   }
 };
 
@@ -207,8 +209,9 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(DType *loc, const DType value, int order) override {
-    return __atomic_exchange_n(loc, value, order);
+  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+                      std::memory_order order) override {
+    return loc->exchange(value, order);
   }
 };
 
@@ -224,25 +227,40 @@ class AtomicCASOp : public AtomicOp {
     // Atomic operations perform bitwise comparison, so it's safe to
     // use number of bytes (itemsize) to determine the type of pointers
     if (itemsize == 1) {
+      std::atomic<uint8_t> *atomic_loc =
+          reinterpret_cast<std::atomic<uint8_t> *>(loc);
       uint8_t desired_val = *(static_cast<const uint8_t *>(desired) + i);
-      __atomic_compare_exchange_n(static_cast<uint8_t *>(loc),
-                                  static_cast<uint8_t *>(expected) + i,
-                                  desired_val, false, order, order);
+      uint8_t *expected_uint = static_cast<uint8_t *>(expected);
+      // Perform the compare and exchange operation
+      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                          std::memory_order(order),
+                                          std::memory_order(order));
+
     } else if (itemsize == 2) {
+      std::atomic<uint16_t> *atomic_loc =
+          reinterpret_cast<std::atomic<uint16_t> *>(loc);
       uint16_t desired_val = *(static_cast<const uint16_t *>(desired) + i);
-      __atomic_compare_exchange_n(static_cast<uint16_t *>(loc),
-                                  static_cast<uint16_t *>(expected) + i,
-                                  desired_val, false, order, order);
+      uint16_t *expected_uint = static_cast<uint16_t *>(expected);
+      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                          std::memory_order(order),
+                                          std::memory_order(order));
     } else if (itemsize == 4) {
+      std::atomic<uint32_t> *atomic_loc =
+          reinterpret_cast<std::atomic<uint32_t> *>(loc);
       uint32_t desired_val = *(static_cast<const uint32_t *>(desired) + i);
-      __atomic_compare_exchange_n(static_cast<uint32_t *>(loc),
-                                  static_cast<uint32_t *>(expected) + i,
-                                  desired_val, false, order, order);
+      uint32_t *expected_uint = static_cast<uint32_t *>(expected);
+      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                          std::memory_order(order),
+                                          std::memory_order(order));
     } else if (itemsize == 8) {
       uint64_t desired_val = *(static_cast<const uint64_t *>(desired) + i);
-      __atomic_compare_exchange_n(static_cast<uint64_t *>(loc),
-                                  static_cast<uint64_t *>(expected) + i,
-                                  desired_val, false, order, order);
+      std::atomic<uint64_t> *atomic_loc =
+          static_cast<std::atomic<uint64_t> *>(loc);
+      uint64_t *expected_uint = static_cast<uint64_t *>(expected);
+      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                          std::memory_order(order),
+                                          std::memory_order(order));
+
     } else {
       // The ‘__atomic’ builtins can be used with any integral scalar or pointer
       // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are

From 36d139c12f90202f16eeccc9f95feb6d6cece185 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Mon, 28 Oct 2024 11:48:53 +0000
Subject: [PATCH 2/4] try to make reinterpret_cast more safe

Signed-off-by: Anatoly Myachev
---
 python/src/interpreter.cc | 207 +++++++++++++++++++++-------------
 1 file changed, 138 insertions(+), 69 deletions(-)

diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc
index 458e5127cf..2abe8ef58d 100644
--- a/python/src/interpreter.cc
+++ b/python/src/interpreter.cc
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include <mutex>
 #include
 #include
 #include
@@ -12,6 +13,16 @@ namespace {
 
 enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };
 
+std::mutex atomic_op_guard;
+
+template <typename T>
+constexpr bool is_reinterpret_cast_to_atomic_safe =
+    std::is_trivially_copyable_v<T> &&
+    std::is_trivially_copyable_v<std::atomic<T>> &&
+    std::is_standard_layout_v<T> && std::is_standard_layout_v<std::atomic<T>> &&
+    sizeof(T) == sizeof(std::atomic<T>) &&
+    alignof(T) == alignof(std::atomic<T>);
+
 enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
 
 std::map<MemSemantic, int> mem_semantic_map = {
@@ -20,11 +31,9 @@ std::map<MemSemantic, int> mem_semantic_map = {
     {MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
     {MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
 };
-// Use compiler builtin atomics instead of std::atomic which requires
-// each variable to be declared as atomic.
-// Currently work for clang and gcc.
+
 template <bool is_min, typename T>
-T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
+T atomic_cmp(T *ptr, T val, std::memory_order order) {
   auto cmp = [](T old, T val) {
     if constexpr (is_min) {
       return old > val;
@@ -33,26 +42,43 @@ T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
     }
   };
 
-  // First load
-  T old_val = ptr->load(order);
-  while (cmp(old_val, val)) {
-    if (ptr->compare_exchange_weak(old_val, val, order, order)) {
-      break;
+  T old_val;
+  if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+    std::atomic<T> *atomic_ptr = reinterpret_cast<std::atomic<T> *>(ptr);
+    old_val = atomic_ptr->load(order);
+    while (cmp(old_val, val)) {
+      if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) {
+        break;
+      }
+    }
+  } else {
+    const std::lock_guard<std::mutex> lock(atomic_op_guard);
+    old_val = *ptr;
+    if (cmp(old_val, val)) {
+      *ptr = val;
     }
   }
   return old_val;
 }
 
-template <typename T>
-T atomic_fadd(std::atomic<T> *loc, T value, std::memory_order order) {
+template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
   static_assert(std::is_floating_point<T>::value,
                 "T must be a floating-point type");
-
-  T old_value = loc->load(order);
-  T new_value;
-  do {
-    new_value = old_value + value;
-  } while (!loc->compare_exchange_weak(old_value, new_value, order, order));
+  T old_value;
+
+  if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+    T new_value;
+    std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
+    old_value = atomic_loc->load(order);
+    do {
+      new_value = old_value + value;
+    } while (
+        !atomic_loc->compare_exchange_weak(old_value, new_value, order, order));
+  } else {
+    const std::lock_guard<std::mutex> lock(atomic_op_guard);
+    old_value = *loc;
+    *loc = old_value + value;
+  }
 
   return old_value;
 }
@@ -87,14 +113,14 @@ template <typename DType> class AtomicRMWOpBase : public AtomicOp {
 protected:
   void applyAt(void *loc, size_t i) override final {
     if (mask[i]) {
-      std::atomic<DType> *atomic_ptr = static_cast<std::atomic<DType> *>(loc);
+      DType *atomic_ptr = static_cast<DType *>(loc);
       *(static_cast<DType *>(ret) + i) =
           applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
                         std::memory_order(order));
     }
   }
 
-  virtual DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  virtual DType applyAtMasked(DType *loc, const DType value,
                               std::memory_order order) = 0;
 
   const void *val;
@@ -115,9 +141,19 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_add(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc + value;
+    }
+    return old_val;
   }
 };
 
@@ -128,9 +164,8 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-
     return atomic_fadd(loc, value, order);
   }
 };
 
@@ -142,9 +177,19 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_and(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc & value;
+    }
+    return old_val;
   }
 };
 
@@ -155,9 +200,19 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_or(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc | value;
+    }
+    return old_val;
   }
 };
 
@@ -168,9 +223,19 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_xor(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc ^ value;
+    }
+    return old_val;
   }
 };
 
@@ -182,7 +247,7 @@ class AtomicRMWOp
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
     return atomic_cmp(loc, value, order);
   }
 };
 
@@ -196,7 +261,7 @@ class AtomicRMWOp
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
     return atomic_cmp(loc, value, order);
   }
 };
 
@@ -209,12 +274,44 @@ class AtomicRMWOp>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return loc->exchange(value, order);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = atomic_loc->exchange(value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = value;
+    }
+    return old_val;
   }
 };
 
+template <typename T>
+void atomic_compare_exchange_strong(void *loc, void *expected,
+                                    const void *desired, size_t i,
+                                    std::memory_order order) {
+  T desired_val = *(static_cast<const T *>(desired) + i);
+  T *expected_uint = static_cast<T *>(expected);
+
+  if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+    std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
+    atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                        order, order);
+  } else {
+    const std::lock_guard<std::mutex> lock(atomic_op_guard);
+    T *atomic_loc = static_cast<T *>(loc);
+    if (*atomic_loc == *(expected_uint + i)) {
+      *atomic_loc = desired_val;
+    } else {
+      *(expected_uint + i) = *atomic_loc;
+    }
+  }
+}
+
 class AtomicCASOp : public AtomicOp {
 public:
   AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired,
@@ -227,46 +324,18 @@ class AtomicCASOp : public AtomicOp {
     // Atomic operations perform bitwise comparison, so it's safe to
     // use number of bytes (itemsize) to determine the type of pointers
     if (itemsize == 1) {
-      std::atomic<uint8_t> *atomic_loc =
-          reinterpret_cast<std::atomic<uint8_t> *>(loc);
-      uint8_t desired_val = *(static_cast<const uint8_t *>(desired) + i);
-      uint8_t *expected_uint = static_cast<uint8_t *>(expected);
-      // Perform the compare and exchange operation
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
-
+      atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i,
+                                              std::memory_order(order));
     } else if (itemsize == 2) {
-      std::atomic<uint16_t> *atomic_loc =
-          reinterpret_cast<std::atomic<uint16_t> *>(loc);
-      uint16_t desired_val = *(static_cast<const uint16_t *>(desired) + i);
-      uint16_t *expected_uint = static_cast<uint16_t *>(expected);
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
+      atomic_compare_exchange_strong<uint16_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
     } else if (itemsize == 4) {
-      std::atomic<uint32_t> *atomic_loc =
-          reinterpret_cast<std::atomic<uint32_t> *>(loc);
-      uint32_t desired_val = *(static_cast<const uint32_t *>(desired) + i);
-      uint32_t *expected_uint = static_cast<uint32_t *>(expected);
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
+      atomic_compare_exchange_strong<uint32_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
     } else if (itemsize == 8) {
-      uint64_t desired_val = *(static_cast<const uint64_t *>(desired) + i);
-      std::atomic<uint64_t> *atomic_loc =
-          static_cast<std::atomic<uint64_t> *>(loc);
-      uint64_t *expected_uint = static_cast<uint64_t *>(expected);
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
-
+      atomic_compare_exchange_strong<uint64_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
     } else {
-      // The ‘__atomic’ builtins can be used with any integral scalar or pointer
-      // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are
-      // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the
-      // architecture.
-      // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html
       throw std::invalid_argument("Invalid byte size");
     }
   }

From ea2785583e95e7ca2f2354d35bb70d574413ff26 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Mon, 28 Oct 2024 12:26:52 +0000
Subject: [PATCH 3/4] std::memory_order instead of int

Signed-off-by: Anatoly Myachev
---
 python/src/interpreter.cc | 41 +++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc
index 2abe8ef58d..391fc00ea9 100644
--- a/python/src/interpreter.cc
+++ b/python/src/interpreter.cc
@@ -25,11 +25,11 @@ constexpr bool is_reinterpret_cast_to_atomic_safe =
 
 enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
 
-std::map<MemSemantic, int> mem_semantic_map = {
-    {MemSemantic::ACQUIRE_RELEASE, static_cast<int>(std::memory_order_acq_rel)},
-    {MemSemantic::ACQUIRE, static_cast<int>(std::memory_order_acquire)},
-    {MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
-    {MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
+std::map<MemSemantic, std::memory_order> mem_semantic_map = {
+    {MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel},
+    {MemSemantic::ACQUIRE, std::memory_order_acquire},
+    {MemSemantic::RELEASE, std::memory_order_release},
+    {MemSemantic::RELAXED, std::memory_order_relaxed},
 };
 
 template <bool is_min, typename T>
@@ -85,7 +85,7 @@ template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
 
 class AtomicOp {
 public:
-  AtomicOp(const uint64_t *ptr, size_t numel, int order)
+  AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order)
       : ptr(ptr), numel(numel), order(order) {}
 
   void apply() {
@@ -101,22 +101,21 @@ class AtomicOp {
 
   const uint64_t *ptr;
   size_t numel;
-  int order;
+  std::memory_order order;
 };
 
 template <typename DType> class AtomicRMWOpBase : public AtomicOp {
 public:
   AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret,
-                  const bool *mask, size_t numel, int order)
+                  const bool *mask, size_t numel, std::memory_order order)
       : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {}
 
 protected:
   void applyAt(void *loc, size_t i) override final {
     if (mask[i]) {
       DType *atomic_ptr = static_cast<DType *>(loc);
-      *(static_cast<DType *>(ret) + i) =
-          applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
-                        std::memory_order(order));
+      *(static_cast<DType *>(ret) + i) = applyAtMasked(
+          atomic_ptr, *(static_cast<const DType *>(val) + i), order);
     }
   }
 
@@ -315,7 +314,7 @@ void atomic_compare_exchange_strong(void *loc, void *expected,
 class AtomicCASOp : public AtomicOp {
 public:
   AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired,
-              size_t itemsize, size_t numel, int order)
+              size_t itemsize, size_t numel, std::memory_order order)
       : AtomicOp(ptr, numel, order), expected(expected), desired(desired),
         itemsize(itemsize) {}
 
@@ -324,17 +323,16 @@ class AtomicCASOp : public AtomicOp {
     // Atomic operations perform bitwise comparison, so it's safe to
     // use number of bytes (itemsize) to determine the type of pointers
     if (itemsize == 1) {
-      atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i,
-                                              std::memory_order(order));
+      atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i, order);
     } else if (itemsize == 2) {
       atomic_compare_exchange_strong<uint16_t>(loc, expected, desired, i,
-                                               std::memory_order(order));
+                                               order);
     } else if (itemsize == 4) {
       atomic_compare_exchange_strong<uint32_t>(loc, expected, desired, i,
-                                               std::memory_order(order));
+                                               order);
     } else if (itemsize == 8) {
       atomic_compare_exchange_strong<uint64_t>(loc, expected, desired, i,
-                                               std::memory_order(order));
+                                               order);
     } else {
       throw std::invalid_argument("Invalid byte size");
     }
@@ -361,7 +359,7 @@ template struct OpCreator {
   void *ret;
   const bool *mask;
   size_t numel;
-  int order;
+  std::memory_order order;
   std::unique_ptr<AtomicOp> &atomic_op;
 
   template void create() {
@@ -375,7 +373,8 @@ template
 std::unique_ptr<AtomicOp>
 makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val,
-                void *ret, const bool *mask, size_t numel, int order) {
+                void *ret, const bool *mask, size_t numel,
+                std::memory_order order) {
   // Iterate over all supported data types, make one that matches, and return
   std::unique_ptr<AtomicOp> atomic_op;
   OpCreator try_make_op{dtype, ptr, val, ret,
@@ -453,7 +452,7 @@ void init_triton_interpreter(py::module &&m) {
   m.def("atomic_rmw",
         [](RMWOp rmw_op, py::array_t<uint64_t> ptr, py::array val,
           py::array_t<bool> mask, MemSemantic sem) -> py::array {
-          int order = mem_semantic_map[sem];
+          std::memory_order order = mem_semantic_map[sem];
           int numel = ptr.size();
           auto shape =
               std::vector(ptr.shape(), ptr.shape() + ptr.ndim());
@@ -500,7 +499,7 @@ void init_triton_interpreter(py::module &&m) {
   m.def("atomic_cas",
         [](py::array_t<uint64_t> ptr, py::array &cmp, py::array &val,
           MemSemantic sem) -> py::array {
-          int order = mem_semantic_map[sem];
+          std::memory_order order = mem_semantic_map[sem];
          int numel = ptr.size();
          auto shape =
              std::vector(ptr.shape(), ptr.shape() + ptr.ndim());

From d58443fe187a4ade0f1d3d534967088b2d592b9d Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Mon, 28 Oct 2024 12:40:40 +0000
Subject: [PATCH 4/4] cleanup

Signed-off-by: Anatoly Myachev
---
 python/src/interpreter.cc | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc
index 391fc00ea9..bd7b215fe9 100644
--- a/python/src/interpreter.cc
+++ b/python/src/interpreter.cc
@@ -113,9 +113,9 @@ template <typename DType> class AtomicRMWOpBase : public AtomicOp {
 protected:
   void applyAt(void *loc, size_t i) override final {
     if (mask[i]) {
-      DType *atomic_ptr = static_cast<DType *>(loc);
-      *(static_cast<DType *>(ret) + i) = applyAtMasked(
-          atomic_ptr, *(static_cast<const DType *>(val) + i), order);
+      DType *ptr = static_cast<DType *>(loc);
+      *(static_cast<DType *>(ret) + i) =
+          applyAtMasked(ptr, *(static_cast<const DType *>(val) + i), order);
     }
   }
 
@@ -294,19 +294,19 @@ void atomic_compare_exchange_strong(void *loc, void *expected,
                                     const void *desired, size_t i,
                                     std::memory_order order) {
   T desired_val = *(static_cast<const T *>(desired) + i);
-  T *expected_uint = static_cast<T *>(expected);
+  T *expected_uint = static_cast<T *>(expected) + i;
 
   if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
     std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
-    atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                        order, order);
+    atomic_loc->compare_exchange_strong(*expected_uint, desired_val, order,
+                                        order);
   } else {
     const std::lock_guard<std::mutex> lock(atomic_op_guard);
     T *atomic_loc = static_cast<T *>(loc);
-    if (*atomic_loc == *(expected_uint + i)) {
+    if (*atomic_loc == *expected_uint) {
       *atomic_loc = desired_val;
     } else {
-      *(expected_uint + i) = *atomic_loc;
+      *expected_uint = *atomic_loc;
     }
   }
 }
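
Not part of the patches above: a minimal, self-contained sketch of the pattern PATCH 2 introduces, namely reinterpret_cast from a plain T * to std::atomic<T> * only when a layout/triviality check says the cast is safe, with a mutex fallback otherwise. The names here (cast_to_atomic_is_safe, guard, fetch_add) are illustrative stand-ins rather than the identifiers from interpreter.cc, and the check is a simplified variant of the trait the patch adds; whichever branch the compiler picks, the caller observes the same fetch-add result.

#include <atomic>
#include <cstdio>
#include <mutex>
#include <type_traits>

namespace {

std::mutex guard; // illustrative stand-in for the patch's atomic_op_guard

// Simplified variant of the trait added in PATCH 2: the cast is only used when
// std::atomic<T> has the same size, alignment, and layout properties as T.
template <typename T>
constexpr bool cast_to_atomic_is_safe =
    std::is_trivially_copyable_v<T> && std::is_standard_layout_v<T> &&
    std::is_standard_layout_v<std::atomic<T>> &&
    sizeof(T) == sizeof(std::atomic<T>) && alignof(T) == alignof(std::atomic<T>);

// Fetch-add on a plain T *: atomic path when the cast is safe, mutex path
// otherwise. Both paths return the previous value.
template <typename T>
T fetch_add(T *loc, T value, std::memory_order order) {
  if constexpr (cast_to_atomic_is_safe<T>) {
    auto *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
    return std::atomic_fetch_add_explicit(atomic_loc, value, order);
  } else {
    const std::lock_guard<std::mutex> lock(guard);
    T old = *loc;
    *loc = old + value;
    return old;
  }
}

} // namespace

int main() {
  int buffer[4] = {0, 1, 2, 3};
  int old = fetch_add(&buffer[2], 40, std::memory_order_acq_rel);
  std::printf("old=%d new=%d\n", old, buffer[2]); // prints: old=2 new=42
}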