#include <iostream>
#include <map>
#include <memory>
+ #include <mutex>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <type_traits>
@@ -12,6 +13,16 @@ namespace {

enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };

+ std::mutex atomic_op_guard;
+
+ template <typename T>
+ constexpr bool is_reinterpret_cast_to_atomic_safe =
+     std::is_trivially_copyable_v<T> &&
+     std::is_trivially_copyable_v<std::atomic<T>> &&
+     std::is_standard_layout_v<T> && std::is_standard_layout_v<std::atomic<T>> &&
+     sizeof(T) == sizeof(std::atomic<T>) &&
+     alignof(T) == alignof(std::atomic<T>);
+
enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };

std::map<MemSemantic, int> mem_semantic_map = {
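A note on the new trait: the standard does not guarantee that `std::atomic<T>` shares `T`'s representation, so the cast-based fast path has to be gated per type. A minimal sketch of what the trait typically evaluates to (illustrative only; the results are implementation-defined):

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>

int main() {
  // Usually true on x86-64/AArch64: the atomic wrapper of a small scalar
  // is lock-free and adds no padding, so the reinterpret_cast path is taken.
  std::cout << (sizeof(std::atomic<int32_t>) == sizeof(int32_t)) << '\n';
  // The same typically holds for double, so FADD also gets the fast path.
  std::cout << (sizeof(std::atomic<double>) == sizeof(double)) << '\n';
}
```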
@@ -20,11 +31,9 @@ std::map<MemSemantic, int> mem_semantic_map = {
    {MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
    {MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
};
- // Use compiler builtin atomics instead of std::atomic which requires
- // each variable to be declared as atomic.
- // Currently work for clang and gcc.
+
template <bool is_min, typename T>
- T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
+ T atomic_cmp(T *ptr, T val, std::memory_order order) {
  auto cmp = [](T old, T val) {
    if constexpr (is_min) {
      return old > val;
@@ -33,26 +42,43 @@ T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
    }
  };

-   // First load
-   T old_val = ptr->load(order);
-   while (cmp(old_val, val)) {
-     if (ptr->compare_exchange_weak(old_val, val, order, order)) {
-       break;
+   T old_val;
+   if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+     std::atomic<T> *atomic_ptr = reinterpret_cast<std::atomic<T> *>(ptr);
+     old_val = atomic_ptr->load(order);
+     while (cmp(old_val, val)) {
+       if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) {
+         break;
+       }
+     }
+   } else {
+     const std::lock_guard<std::mutex> lock(atomic_op_guard);
+     old_val = *ptr;
+     if (cmp(old_val, val)) {
+       *ptr = val;
    }
  }
  return old_val;
}

- template <typename T>
- T atomic_fadd(std::atomic<T> *loc, T value, std::memory_order order) {
+ template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
  static_assert(std::is_floating_point<T>::value,
                "T must be a floating-point type");
-
-   T old_value = loc->load(order);
-   T new_value;
-   do {
-     new_value = old_value + value;
-   } while (!loc->compare_exchange_weak(old_value, new_value, order, order));
+   T old_value;
+
+   if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+     T new_value;
+     std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
+     old_value = atomic_loc->load(order);
+     do {
+       new_value = old_value + value;
+     } while (
+         !atomic_loc->compare_exchange_weak(old_value, new_value, order, order));
+   } else {
+     const std::lock_guard<std::mutex> lock(atomic_op_guard);
+     old_value = *loc;
+     *loc = old_value + value;
+   }

  return old_value;
}
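The loops above are the usual compare-exchange idiom for read-modify-write operations that `std::atomic` lacks natively (float add, min/max). A standalone sketch of the same pattern, using a hypothetical `fetch_add_float` helper name:

```cpp
#include <atomic>
#include <cassert>

// Returns the value held before the add, matching atomic_fadd above.
// compare_exchange_weak reloads `old` on failure, so the sum is recomputed
// against the freshest value on every retry.
float fetch_add_float(std::atomic<float> &a, float v) {
  float old = a.load(std::memory_order_relaxed);
  while (!a.compare_exchange_weak(old, old + v, std::memory_order_relaxed))
    ;
  return old;
}

int main() {
  std::atomic<float> acc{1.0f};
  assert(fetch_add_float(acc, 2.5f) == 1.0f); // the old value comes back
  assert(acc.load() == 3.5f);                 // 1.0f + 2.5f is exact
}
```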
@@ -87,14 +113,14 @@ template <typename DType> class AtomicRMWOpBase : public AtomicOp {
protected:
  void applyAt(void *loc, size_t i) override final {
    if (mask[i]) {
-       std::atomic<DType> *atomic_ptr = static_cast<std::atomic<DType> *>(loc);
+       DType *atomic_ptr = static_cast<DType *>(loc);
      *(static_cast<DType *>(ret) + i) =
          applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
                        std::memory_order(order));
    }
  }

-   virtual DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   virtual DType applyAtMasked(DType *loc, const DType value,
                              std::memory_order order) = 0;

  const void *val;
@@ -115,9 +141,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::ADD>>
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
-     return std::atomic_fetch_add(loc, value);
+     DType old_val;
+     if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+       std::atomic<DType> *atomic_loc =
+           reinterpret_cast<std::atomic<DType> *>(loc);
+       old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order);
+     } else {
+       const std::lock_guard<std::mutex> lock(atomic_op_guard);
+       old_val = *loc;
+       *loc = *loc + value;
+     }
+     return old_val;
  }
};

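The same two-branch shape repeats for AND, OR, and XOR below: `std::atomic_fetch_<op>_explicit` on the lock-free path, and a load-modify-store under `atomic_op_guard` otherwise. The mutex fallback is only atomic with respect to other guarded callers, which suffices here because every access to a non-castable type routes through the same lock. Reduced to its essentials (hypothetical helper name, not part of the patch):

```cpp
#include <mutex>

std::mutex atomic_op_guard; // shared guard, as in the patch

// Fallback fetch-add: correct for any copyable T, at the cost of a lock.
template <typename T> T locked_fetch_add(T *loc, T value) {
  const std::lock_guard<std::mutex> lock(atomic_op_guard);
  T old_val = *loc;       // snapshot the previous value under the lock
  *loc = old_val + value; // no other guarded caller can interleave here
  return old_val;         // fetch-style ops return the pre-update value
}
```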
@@ -128,9 +164,8 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::FADD>>
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
-
    return atomic_fadd(loc, value, order);
  }
};
@@ -142,9 +177,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::AND>>
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
-     return std::atomic_fetch_and(loc, value);
+     DType old_val;
+     if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+       std::atomic<DType> *atomic_loc =
+           reinterpret_cast<std::atomic<DType> *>(loc);
+       old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order);
+     } else {
+       const std::lock_guard<std::mutex> lock(atomic_op_guard);
+       old_val = *loc;
+       *loc = *loc & value;
+     }
+     return old_val;
  }
};

@@ -155,9 +200,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::OR>>
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
-     return std::atomic_fetch_or(loc, value);
+     DType old_val;
+     if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+       std::atomic<DType> *atomic_loc =
+           reinterpret_cast<std::atomic<DType> *>(loc);
+       old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order);
+     } else {
+       const std::lock_guard<std::mutex> lock(atomic_op_guard);
+       old_val = *loc;
+       *loc = *loc | value;
+     }
+     return old_val;
  }
};

@@ -168,9 +223,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::XOR>>
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
-     return std::atomic_fetch_xor(loc, value);
+     DType old_val;
+     if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+       std::atomic<DType> *atomic_loc =
+           reinterpret_cast<std::atomic<DType> *>(loc);
+       old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order);
+     } else {
+       const std::lock_guard<std::mutex> lock(atomic_op_guard);
+       old_val = *loc;
+       *loc = *loc ^ value;
+     }
+     return old_val;
  }
};

@@ -182,7 +247,7 @@ class AtomicRMWOp<DType, Op,
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
    return atomic_cmp</*is_min=*/false>(loc, value, order);
  }
@@ -196,7 +261,7 @@ class AtomicRMWOp<DType, Op,
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
    return atomic_cmp</*is_min=*/true>(loc, value, order);
  }
@@ -209,12 +274,44 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::XCHG>>
  using AtomicRMWOpBase<DType>::AtomicRMWOpBase;

protected:
-   DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+   DType applyAtMasked(DType *loc, const DType value,
                      std::memory_order order) override {
-     return loc->exchange(value, order);
+     DType old_val;
+     if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+       std::atomic<DType> *atomic_loc =
+           reinterpret_cast<std::atomic<DType> *>(loc);
+       old_val = atomic_loc->exchange(value, order);
+     } else {
+       const std::lock_guard<std::mutex> lock(atomic_op_guard);
+       old_val = *loc;
+       *loc = value;
+     }
+     return old_val;
  }
};

+ template <typename T>
+ void atomic_compare_exchange_strong(void *loc, void *expected,
+                                     const void *desired, size_t i,
+                                     std::memory_order order) {
+   T desired_val = *(static_cast<const T *>(desired) + i);
+   T *expected_uint = static_cast<T *>(expected);
+
+   if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+     std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
+     atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                         order, order);
+   } else {
+     const std::lock_guard<std::mutex> lock(atomic_op_guard);
+     T *atomic_loc = static_cast<T *>(loc);
+     if (*atomic_loc == *(expected_uint + i)) {
+       *atomic_loc = desired_val;
+     } else {
+       *(expected_uint + i) = *atomic_loc;
+     }
+   }
+ }
+
class AtomicCASOp : public AtomicOp {
public:
  AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired,
@@ -227,46 +324,18 @@ class AtomicCASOp : public AtomicOp {
    // Atomic operations perform bitwise comparison, so it's safe to
    // use number of bytes (itemsize) to determine the type of pointers
    if (itemsize == 1) {
-       std::atomic<uint8_t> *atomic_loc =
-           reinterpret_cast<std::atomic<uint8_t> *>(loc);
-       uint8_t desired_val = *(static_cast<const uint8_t *>(desired) + i);
-       uint8_t *expected_uint = static_cast<uint8_t *>(expected);
-       // Perform the compare and exchange operation
-       atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                           std::memory_order(order),
-                                           std::memory_order(order));
-
+       atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
    } else if (itemsize == 2) {
-       std::atomic<uint16_t> *atomic_loc =
-           reinterpret_cast<std::atomic<uint16_t> *>(loc);
-       uint16_t desired_val = *(static_cast<const uint16_t *>(desired) + i);
-       uint16_t *expected_uint = static_cast<uint16_t *>(expected);
-       atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                           std::memory_order(order),
-                                           std::memory_order(order));
+       atomic_compare_exchange_strong<uint16_t>(loc, expected, desired, i,
+                                                std::memory_order(order));
    } else if (itemsize == 4) {
-       std::atomic<uint32_t> *atomic_loc =
-           reinterpret_cast<std::atomic<uint32_t> *>(loc);
-       uint32_t desired_val = *(static_cast<const uint32_t *>(desired) + i);
-       uint32_t *expected_uint = static_cast<uint32_t *>(expected);
-       atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                           std::memory_order(order),
-                                           std::memory_order(order));
+       atomic_compare_exchange_strong<uint32_t>(loc, expected, desired, i,
+                                                std::memory_order(order));
    } else if (itemsize == 8) {
-       uint64_t desired_val = *(static_cast<const uint64_t *>(desired) + i);
-       std::atomic<uint64_t> *atomic_loc =
-           static_cast<std::atomic<uint64_t> *>(loc);
-       uint64_t *expected_uint = static_cast<uint64_t *>(expected);
-       atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                           std::memory_order(order),
-                                           std::memory_order(order));
-
+       atomic_compare_exchange_strong<uint64_t>(loc, expected, desired, i,
+                                                std::memory_order(order));
    } else {
-       // The __atomic builtins can be used with any integral scalar or pointer
-       // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are
-       // also allowed if __int128 (see 128-bit Integers) is supported by the
-       // architecture.
-       // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html
      throw std::invalid_argument("Invalid byte size");
    }
  }
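Since CAS compares raw bit patterns, dispatching on `itemsize` to the unsigned integer type of the same width is sound even when the stored element is floating point. A sketch of why that works (illustrative only, not part of the patch):

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // A 4-byte float viewed through uint32_t, as the itemsize == 4 branch does.
  float initial = 1.5f, replacement = 2.5f;
  uint32_t bits, desired;
  std::memcpy(&bits, &initial, sizeof(bits));
  std::memcpy(&desired, &replacement, sizeof(desired));

  std::atomic<uint32_t> cell{bits};
  uint32_t expected = bits;
  // Succeeds iff the stored bits equal `expected`; the CAS never needs to
  // know that those bits encode a float.
  assert(cell.compare_exchange_strong(expected, desired));

  float out;
  uint32_t raw = cell.load();
  std::memcpy(&out, &raw, sizeof(out));
  assert(out == 2.5f);
}
```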