Skip to content

Commit b685776

Browse files
committed
Use std atomic for 'interpreter.cc'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 61fd54d commit b685776

File tree

1 file changed

+76
-58
lines changed

1 file changed

+76
-58
lines changed

python/src/interpreter.cc

Lines changed: 76 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <atomic>
12
#include <iostream>
23
#include <map>
34
#include <memory>
@@ -14,55 +15,46 @@ enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };
1415
enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
1516

1617
std::map<MemSemantic, int> mem_semantic_map = {
17-
{MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL},
18-
{MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE},
19-
{MemSemantic::RELEASE, __ATOMIC_RELEASE},
20-
{MemSemantic::RELAXED, __ATOMIC_RELAXED},
18+
{MemSemantic::ACQUIRE_RELEASE, static_cast<int>(std::memory_order_acq_rel)},
19+
{MemSemantic::ACQUIRE, static_cast<int>(std::memory_order_acquire)},
20+
{MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
21+
{MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
2122
};
22-
2323
// Use compiler builtin atomics instead of std::atomic which requires
2424
// each variable to be declared as atomic.
2525
// Currently works for clang and gcc.
26-
template <bool is_min, typename T> T atomic_cmp(T *ptr, T val, int order) {
26+
template <bool is_min, typename T>
27+
T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
2728
auto cmp = [](T old, T val) {
2829
if constexpr (is_min) {
2930
return old > val;
3031
} else {
3132
return old < val;
3233
}
3334
};
35+
3436
// First load
35-
T old_val = __atomic_load_n(ptr, order);
37+
T old_val = ptr->load(order);
3638
while (cmp(old_val, val)) {
37-
if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) {
39+
if (ptr->compare_exchange_weak(old_val, val, order, order)) {
3840
break;
3941
}
4042
}
4143
return old_val;
4244
}
4345

44-
template <typename T> T atomic_fadd(T *ptr, T val, int order) {
45-
T old_val;
46-
T new_val;
47-
// First load
48-
// Load ptr as if uint32_t or uint64_t and then memcpy to T
49-
if constexpr (sizeof(T) == 4) {
50-
uint32_t tmp = __atomic_load_n(reinterpret_cast<uint32_t *>(ptr), order);
51-
std::memcpy(&old_val, &tmp, sizeof(T));
52-
} else if constexpr (sizeof(T) == 8) {
53-
uint64_t tmp = __atomic_load_n(reinterpret_cast<uint64_t *>(ptr), order);
54-
std::memcpy(&old_val, &tmp, sizeof(T));
55-
} else {
56-
throw std::invalid_argument("Unsupported data type");
57-
}
58-
while (true) {
59-
new_val = old_val + val;
60-
if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order,
61-
order)) {
62-
break;
63-
}
64-
}
65-
return old_val;
46+
template <typename T>
47+
T atomic_fadd(std::atomic<T> *loc, T value, std::memory_order order) {
48+
static_assert(std::is_floating_point<T>::value,
49+
"T must be a floating-point type");
50+
51+
T old_value = loc->load(order);
52+
T new_value;
53+
do {
54+
new_value = old_value + value;
55+
} while (!loc->compare_exchange_weak(old_value, new_value, order, order));
56+
57+
return old_value;
6658
}
6759

6860
class AtomicOp {
@@ -95,13 +87,15 @@ template <typename DType> class AtomicRMWOpBase : public AtomicOp {
9587
protected:
9688
void applyAt(void *loc, size_t i) override final {
9789
if (mask[i]) {
90+
std::atomic<DType> *atomic_ptr = static_cast<std::atomic<DType> *>(loc);
9891
*(static_cast<DType *>(ret) + i) =
99-
applyAtMasked(static_cast<DType *>(loc),
100-
*(static_cast<const DType *>(val) + i), order);
92+
applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
93+
std::memory_order(order));
10194
}
10295
}
10396

104-
virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0;
97+
virtual DType applyAtMasked(std::atomic<DType> *loc, const DType value,
98+
std::memory_order order) = 0;
10599

106100
const void *val;
107101
void *ret;
@@ -121,8 +115,9 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::ADD>>
121115
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
122116

123117
protected:
124-
DType applyAtMasked(DType *loc, const DType value, int order) override {
125-
return __atomic_fetch_add(loc, value, order);
118+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
119+
std::memory_order order) override {
120+
return std::atomic_fetch_add(loc, value);
126121
}
127122
};
128123

@@ -133,7 +128,9 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::FADD>>
133128
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
134129

135130
protected:
136-
DType applyAtMasked(DType *loc, const DType value, int order) override {
131+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
132+
std::memory_order order) override {
133+
137134
return atomic_fadd(loc, value, order);
138135
}
139136
};
@@ -145,8 +142,9 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::AND>>
145142
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
146143

147144
protected:
148-
DType applyAtMasked(DType *loc, const DType value, int order) override {
149-
return __atomic_fetch_and(loc, value, order);
145+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
146+
std::memory_order order) override {
147+
return std::atomic_fetch_and(loc, value);
150148
}
151149
};
152150

@@ -157,8 +155,9 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::OR>>
157155
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
158156

159157
protected:
160-
DType applyAtMasked(DType *loc, const DType value, int order) override {
161-
return __atomic_fetch_or(loc, value, order);
158+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
159+
std::memory_order order) override {
160+
return std::atomic_fetch_or(loc, value);
162161
}
163162
};
164163

@@ -169,8 +168,9 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::XOR>>
169168
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
170169

171170
protected:
172-
DType applyAtMasked(DType *loc, const DType value, int order) override {
173-
return __atomic_fetch_xor(loc, value, order);
171+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
172+
std::memory_order order) override {
173+
return std::atomic_fetch_xor(loc, value);
174174
}
175175
};
176176

@@ -182,7 +182,8 @@ class AtomicRMWOp<DType, Op,
182182
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
183183

184184
protected:
185-
DType applyAtMasked(DType *loc, const DType value, int order) override {
185+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
186+
std::memory_order order) override {
186187
return atomic_cmp</*is_min=*/false>(loc, value, order);
187188
}
188189
};
@@ -195,7 +196,8 @@ class AtomicRMWOp<DType, Op,
195196
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
196197

197198
protected:
198-
DType applyAtMasked(DType *loc, const DType value, int order) override {
199+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
200+
std::memory_order order) override {
199201
return atomic_cmp</*is_min=*/true>(loc, value, order);
200202
}
201203
};
@@ -207,8 +209,9 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::XCHG>>
207209
using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
208210

209211
protected:
210-
DType applyAtMasked(DType *loc, const DType value, int order) override {
211-
return __atomic_exchange_n(loc, value, order);
212+
DType applyAtMasked(std::atomic<DType> *loc, const DType value,
213+
std::memory_order order) override {
214+
return loc->exchange(value, order);
212215
}
213216
};
214217

@@ -224,25 +227,40 @@ class AtomicCASOp : public AtomicOp {
224227
// Atomic operations perform bitwise comparison, so it's safe to
225228
// use the number of bytes (itemsize) to determine the type of pointers
226229
if (itemsize == 1) {
230+
std::atomic<uint8_t> *atomic_loc =
231+
reinterpret_cast<std::atomic<uint8_t> *>(loc);
227232
uint8_t desired_val = *(static_cast<const uint8_t *>(desired) + i);
228-
__atomic_compare_exchange_n(static_cast<uint8_t *>(loc),
229-
static_cast<uint8_t *>(expected) + i,
230-
desired_val, false, order, order);
233+
uint8_t *expected_uint = static_cast<uint8_t *>(expected);
234+
// Perform the compare and exchange operation
235+
atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
236+
std::memory_order(order),
237+
std::memory_order(order));
238+
231239
} else if (itemsize == 2) {
240+
std::atomic<uint16_t> *atomic_loc =
241+
reinterpret_cast<std::atomic<uint16_t> *>(loc);
232242
uint16_t desired_val = *(static_cast<const uint16_t *>(desired) + i);
233-
__atomic_compare_exchange_n(static_cast<uint16_t *>(loc),
234-
static_cast<uint16_t *>(expected) + i,
235-
desired_val, false, order, order);
243+
uint16_t *expected_uint = static_cast<uint16_t *>(expected);
244+
atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
245+
std::memory_order(order),
246+
std::memory_order(order));
236247
} else if (itemsize == 4) {
248+
std::atomic<uint32_t> *atomic_loc =
249+
reinterpret_cast<std::atomic<uint32_t> *>(loc);
237250
uint32_t desired_val = *(static_cast<const uint32_t *>(desired) + i);
238-
__atomic_compare_exchange_n(static_cast<uint32_t *>(loc),
239-
static_cast<uint32_t *>(expected) + i,
240-
desired_val, false, order, order);
251+
uint32_t *expected_uint = static_cast<uint32_t *>(expected);
252+
atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
253+
std::memory_order(order),
254+
std::memory_order(order));
241255
} else if (itemsize == 8) {
242256
uint64_t desired_val = *(static_cast<const uint64_t *>(desired) + i);
243-
__atomic_compare_exchange_n(static_cast<uint64_t *>(loc),
244-
static_cast<uint64_t *>(expected) + i,
245-
desired_val, false, order, order);
257+
std::atomic<uint64_t> *atomic_loc =
258+
static_cast<std::atomic<uint64_t> *>(loc);
259+
uint64_t *expected_uint = static_cast<uint64_t *>(expected);
260+
atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
261+
std::memory_order(order),
262+
std::memory_order(order));
263+
246264
} else {
247265
// The '__atomic' builtins can be used with any integral scalar or pointer
248266
// type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are

0 commit comments

Comments
 (0)