Skip to content

Commit ea27855

Browse files
committed
std::memory_order instead of int
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 36d139c commit ea27855

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

python/src/interpreter.cc

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ constexpr bool is_reinterpret_cast_to_atomic_safe =
2525

2626
enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
2727

28-
std::map<MemSemantic, int> mem_semantic_map = {
29-
{MemSemantic::ACQUIRE_RELEASE, static_cast<int>(std::memory_order_acq_rel)},
30-
{MemSemantic::ACQUIRE, static_cast<int>(std::memory_order_acquire)},
31-
{MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
32-
{MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
28+
std::map<MemSemantic, std::memory_order> mem_semantic_map = {
29+
{MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel},
30+
{MemSemantic::ACQUIRE, std::memory_order_acquire},
31+
{MemSemantic::RELEASE, std::memory_order_release},
32+
{MemSemantic::RELAXED, std::memory_order_relaxed},
3333
};
3434

3535
template <bool is_min, typename T>
@@ -85,7 +85,7 @@ template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
8585

8686
class AtomicOp {
8787
public:
88-
AtomicOp(const uint64_t *ptr, size_t numel, int order)
88+
AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order)
8989
: ptr(ptr), numel(numel), order(order) {}
9090

9191
void apply() {
@@ -101,22 +101,21 @@ class AtomicOp {
101101

102102
const uint64_t *ptr;
103103
size_t numel;
104-
int order;
104+
std::memory_order order;
105105
};
106106

107107
template <typename DType> class AtomicRMWOpBase : public AtomicOp {
108108
public:
109109
AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret,
110-
const bool *mask, size_t numel, int order)
110+
const bool *mask, size_t numel, std::memory_order order)
111111
: AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {}
112112

113113
protected:
114114
void applyAt(void *loc, size_t i) override final {
115115
if (mask[i]) {
116116
DType *atomic_ptr = static_cast<DType *>(loc);
117-
*(static_cast<DType *>(ret) + i) =
118-
applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
119-
std::memory_order(order));
117+
*(static_cast<DType *>(ret) + i) = applyAtMasked(
118+
atomic_ptr, *(static_cast<const DType *>(val) + i), order);
120119
}
121120
}
122121

@@ -315,7 +314,7 @@ void atomic_compare_exchange_strong(void *loc, void *expected,
315314
class AtomicCASOp : public AtomicOp {
316315
public:
317316
AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired,
318-
size_t itemsize, size_t numel, int order)
317+
size_t itemsize, size_t numel, std::memory_order order)
319318
: AtomicOp(ptr, numel, order), expected(expected), desired(desired),
320319
itemsize(itemsize) {}
321320

@@ -324,17 +323,16 @@ class AtomicCASOp : public AtomicOp {
324323
// Atomic operations perform bitwise comparison, so it's safe to
325324
// use number of bytes (itemsize) to determine the type of pointers
326325
if (itemsize == 1) {
327-
atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i,
328-
std::memory_order(order));
326+
atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i, order);
329327
} else if (itemsize == 2) {
330328
atomic_compare_exchange_strong<uint16_t>(loc, expected, desired, i,
331-
std::memory_order(order));
329+
order);
332330
} else if (itemsize == 4) {
333331
atomic_compare_exchange_strong<uint32_t>(loc, expected, desired, i,
334-
std::memory_order(order));
332+
order);
335333
} else if (itemsize == 8) {
336334
atomic_compare_exchange_strong<uint64_t>(loc, expected, desired, i,
337-
std::memory_order(order));
335+
order);
338336
} else {
339337
throw std::invalid_argument("Invalid byte size");
340338
}
@@ -361,7 +359,7 @@ template <RMWOp Op> struct OpCreator {
361359
void *ret;
362360
const bool *mask;
363361
size_t numel;
364-
int order;
362+
std::memory_order order;
365363
std::unique_ptr<AtomicOp> &atomic_op;
366364

367365
template <typename T> void create() {
@@ -375,7 +373,8 @@ template <RMWOp Op> struct OpCreator {
375373
template <RMWOp Op, typename... SupportedDTypes>
376374
std::unique_ptr<AtomicOp>
377375
makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val,
378-
void *ret, const bool *mask, size_t numel, int order) {
376+
void *ret, const bool *mask, size_t numel,
377+
std::memory_order order) {
379378
// Iterate over all supported data types, make one that matches, and return
380379
std::unique_ptr<AtomicOp> atomic_op;
381380
OpCreator<Op> try_make_op{dtype, ptr, val, ret,
@@ -453,7 +452,7 @@ void init_triton_interpreter(py::module &&m) {
453452
m.def("atomic_rmw",
454453
[](RMWOp rmw_op, py::array_t<uint64_t> ptr, py::array val,
455454
py::array_t<bool> mask, MemSemantic sem) -> py::array {
456-
int order = mem_semantic_map[sem];
455+
std::memory_order order = mem_semantic_map[sem];
457456
int numel = ptr.size();
458457
auto shape =
459458
std::vector<ptrdiff_t>(ptr.shape(), ptr.shape() + ptr.ndim());
@@ -500,7 +499,7 @@ void init_triton_interpreter(py::module &&m) {
500499
m.def("atomic_cas",
501500
[](py::array_t<uint64_t> ptr, py::array &cmp, py::array &val,
502501
MemSemantic sem) -> py::array {
503-
int order = mem_semantic_map[sem];
502+
std::memory_order order = mem_semantic_map[sem];
504503
int numel = ptr.size();
505504
auto shape =
506505
std::vector<ptrdiff_t>(ptr.shape(), ptr.shape() + ptr.ndim());

0 commit comments

Comments
 (0)