@@ -25,11 +25,11 @@ constexpr bool is_reinterpret_cast_to_atomic_safe =
2525
2626enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
2727
28- std::map<MemSemantic, int > mem_semantic_map = {
29- {MemSemantic::ACQUIRE_RELEASE, static_cast < int >( std::memory_order_acq_rel) },
30- {MemSemantic::ACQUIRE, static_cast < int >( std::memory_order_acquire) },
31- {MemSemantic::RELEASE, static_cast < int >( std::memory_order_release) },
32- {MemSemantic::RELAXED, static_cast < int >( std::memory_order_relaxed) },
28+ std::map<MemSemantic, std::memory_order > mem_semantic_map = {
29+ {MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel},
30+ {MemSemantic::ACQUIRE, std::memory_order_acquire},
31+ {MemSemantic::RELEASE, std::memory_order_release},
32+ {MemSemantic::RELAXED, std::memory_order_relaxed},
3333};
3434
3535template <bool is_min, typename T>
@@ -85,7 +85,7 @@ template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
8585
8686class AtomicOp {
8787public:
88- AtomicOp (const uint64_t *ptr, size_t numel, int order)
88+ AtomicOp (const uint64_t *ptr, size_t numel, std::memory_order order)
8989 : ptr(ptr), numel(numel), order(order) {}
9090
9191 void apply () {
@@ -101,22 +101,21 @@ class AtomicOp {
101101
102102 const uint64_t *ptr;
103103 size_t numel;
104- int order;
104+ std::memory_order order;
105105};
106106
107107template <typename DType> class AtomicRMWOpBase : public AtomicOp {
108108public:
109109 AtomicRMWOpBase (const uint64_t *ptr, const void *val, void *ret,
110- const bool *mask, size_t numel, int order)
110+ const bool *mask, size_t numel, std::memory_order order)
111111 : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {}
112112
113113protected:
114114 void applyAt (void *loc, size_t i) override final {
115115 if (mask[i]) {
116116 DType *atomic_ptr = static_cast <DType *>(loc);
117- *(static_cast <DType *>(ret) + i) =
118- applyAtMasked (atomic_ptr, *(static_cast <const DType *>(val) + i),
119- std::memory_order (order));
117+ *(static_cast <DType *>(ret) + i) = applyAtMasked (
118+ atomic_ptr, *(static_cast <const DType *>(val) + i), order);
120119 }
121120 }
122121
@@ -315,7 +314,7 @@ void atomic_compare_exchange_strong(void *loc, void *expected,
315314class AtomicCASOp : public AtomicOp {
316315public:
317316 AtomicCASOp (const uint64_t *ptr, void *expected, const void *desired,
318- size_t itemsize, size_t numel, int order)
317+ size_t itemsize, size_t numel, std::memory_order order)
319318 : AtomicOp(ptr, numel, order), expected(expected), desired(desired),
320319 itemsize (itemsize) {}
321320
@@ -324,17 +323,16 @@ class AtomicCASOp : public AtomicOp {
324323 // Atomic operations perform bitwise comparison, so it's safe to
325324 // use number of bytes (itemsize) to determine the type of pointers
326325 if (itemsize == 1 ) {
327- atomic_compare_exchange_strong<uint8_t >(loc, expected, desired, i,
328- std::memory_order (order));
326+ atomic_compare_exchange_strong<uint8_t >(loc, expected, desired, i, order);
329327 } else if (itemsize == 2 ) {
330328 atomic_compare_exchange_strong<uint16_t >(loc, expected, desired, i,
331- std::memory_order ( order) );
329+ order);
332330 } else if (itemsize == 4 ) {
333331 atomic_compare_exchange_strong<uint32_t >(loc, expected, desired, i,
334- std::memory_order ( order) );
332+ order);
335333 } else if (itemsize == 8 ) {
336334 atomic_compare_exchange_strong<uint64_t >(loc, expected, desired, i,
337- std::memory_order ( order) );
335+ order);
338336 } else {
339337 throw std::invalid_argument (" Invalid byte size" );
340338 }
@@ -361,7 +359,7 @@ template <RMWOp Op> struct OpCreator {
361359 void *ret;
362360 const bool *mask;
363361 size_t numel;
364- int order;
362+ std::memory_order order;
365363 std::unique_ptr<AtomicOp> &atomic_op;
366364
367365 template <typename T> void create () {
@@ -375,7 +373,8 @@ template <RMWOp Op> struct OpCreator {
375373template <RMWOp Op, typename ... SupportedDTypes>
376374std::unique_ptr<AtomicOp>
377375makeAtomicRMWOp (pybind11::dtype dtype, const uint64_t *ptr, const void *val,
378- void *ret, const bool *mask, size_t numel, int order) {
376+ void *ret, const bool *mask, size_t numel,
377+ std::memory_order order) {
379378 // Iterate over all supported data types, make one that matches, and return
380379 std::unique_ptr<AtomicOp> atomic_op;
381380 OpCreator<Op> try_make_op{dtype, ptr, val, ret,
@@ -453,7 +452,7 @@ void init_triton_interpreter(py::module &&m) {
453452 m.def (" atomic_rmw" ,
454453 [](RMWOp rmw_op, py::array_t <uint64_t > ptr, py::array val,
455454 py::array_t <bool > mask, MemSemantic sem) -> py::array {
456- int order = mem_semantic_map[sem];
455+ std::memory_order order = mem_semantic_map[sem];
457456 int numel = ptr.size ();
458457 auto shape =
459458 std::vector<ptrdiff_t >(ptr.shape (), ptr.shape () + ptr.ndim ());
@@ -500,7 +499,7 @@ void init_triton_interpreter(py::module &&m) {
500499 m.def (" atomic_cas" ,
501500 [](py::array_t <uint64_t > ptr, py::array &cmp, py::array &val,
502501 MemSemantic sem) -> py::array {
503- int order = mem_semantic_map[sem];
502+ std::memory_order order = mem_semantic_map[sem];
504503 int numel = ptr.size ();
505504 auto shape =
506505 std::vector<ptrdiff_t >(ptr.shape (), ptr.shape () + ptr.ndim ());
0 commit comments