Commit 36d139c

try to make reinterpret_cast<atomic> more safe
Signed-off-by: Anatoly Myachev <[email protected]>
Parent: b685776

1 file changed: python/src/interpreter.cc (138 additions, 69 deletions)
@@ -2,6 +2,7 @@
 #include <iostream>
 #include <map>
 #include <memory>
+#include <mutex>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <type_traits>
@@ -12,6 +13,16 @@ namespace {
 
 enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED };
 
+std::mutex atomic_op_guard;
+
+template <typename T>
+constexpr bool is_reinterpret_cast_to_atomic_safe =
+    std::is_trivially_copyable_v<T> &&
+    std::is_trivially_copyable_v<std::atomic<T>> &&
+    std::is_standard_layout_v<T> && std::is_standard_layout_v<std::atomic<T>> &&
+    sizeof(T) == sizeof(std::atomic<T>) &&
+    alignof(T) == alignof(std::atomic<T>);
+
 enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX };
 
 std::map<MemSemantic, int> mem_semantic_map = {
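
The trait above is the core of the change: casting a plain T * to std::atomic<T> * is only reasonable when std::atomic<T> is layout-compatible with T, which the size, alignment, trivial-copyability, and standard-layout checks approximate at compile time. (Strictly speaking the cast remains outside what the C++ standard guarantees; the checks make it sound in practice on mainstream ABIs.) A minimal standalone sketch of the same dispatch pattern follows; the names guard, is_cast_safe, and guarded_fetch_add are illustrative, not part of the patch:

#include <atomic>
#include <iostream>
#include <mutex>
#include <type_traits>

std::mutex guard;

// Same layout checks as is_reinterpret_cast_to_atomic_safe in the patch.
template <typename T>
constexpr bool is_cast_safe =
    std::is_trivially_copyable_v<T> &&
    std::is_trivially_copyable_v<std::atomic<T>> &&
    std::is_standard_layout_v<T> && std::is_standard_layout_v<std::atomic<T>> &&
    sizeof(T) == sizeof(std::atomic<T>) && alignof(T) == alignof(std::atomic<T>);

// Fetch-add on raw storage: lock-free when the layout checks pass,
// serialized behind a single global mutex otherwise.
template <typename T> T guarded_fetch_add(T *p, T v) {
  if constexpr (is_cast_safe<T>) {
    return std::atomic_fetch_add_explicit(
        reinterpret_cast<std::atomic<T> *>(p), v, std::memory_order_seq_cst);
  } else {
    const std::lock_guard<std::mutex> lock(guard);
    T old = *p;
    *p = old + v;
    return old;
  }
}

int main() {
  int x = 5;
  std::cout << guarded_fetch_add(&x, 3) << ' ' << x << '\n'; // prints "5 8"
}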
@@ -20,11 +31,9 @@ std::map<MemSemantic, int> mem_semantic_map = {
     {MemSemantic::RELEASE, static_cast<int>(std::memory_order_release)},
     {MemSemantic::RELAXED, static_cast<int>(std::memory_order_relaxed)},
 };
-// Use compiler builtin atomics instead of std::atomic which requires
-// each variable to be declared as atomic.
-// Currently work for clang and gcc.
+
 template <bool is_min, typename T>
-T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
+T atomic_cmp(T *ptr, T val, std::memory_order order) {
   auto cmp = [](T old, T val) {
     if constexpr (is_min) {
       return old > val;
@@ -33,26 +42,43 @@ T atomic_cmp(std::atomic<T> *ptr, T val, std::memory_order order) {
     }
   };
 
-  // First load
-  T old_val = ptr->load(order);
-  while (cmp(old_val, val)) {
-    if (ptr->compare_exchange_weak(old_val, val, order, order)) {
-      break;
+  T old_val;
+  if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+    std::atomic<T> *atomic_ptr = reinterpret_cast<std::atomic<T> *>(ptr);
+    old_val = atomic_ptr->load(order);
+    while (cmp(old_val, val)) {
+      if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) {
+        break;
+      }
+    }
+  } else {
+    const std::lock_guard<std::mutex> lock(atomic_op_guard);
+    old_val = *ptr;
+    if (cmp(old_val, val)) {
+      *ptr = val;
     }
   }
   return old_val;
 }
 
-template <typename T>
-T atomic_fadd(std::atomic<T> *loc, T value, std::memory_order order) {
+template <typename T> T atomic_fadd(T *loc, T value, std::memory_order order) {
   static_assert(std::is_floating_point<T>::value,
                 "T must be a floating-point type");
-
-  T old_value = loc->load(order);
-  T new_value;
-  do {
-    new_value = old_value + value;
-  } while (!loc->compare_exchange_weak(old_value, new_value, order, order));
+  T old_value;
+
+  if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+    T new_value;
+    std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
+    old_value = atomic_loc->load(order);
+    do {
+      new_value = old_value + value;
+    } while (
+        !atomic_loc->compare_exchange_weak(old_value, new_value, order, order));
+  } else {
+    const std::lock_guard<std::mutex> lock(atomic_op_guard);
+    old_value = *loc;
+    *loc = old_value + value;
+  }
 
   return old_value;
 }
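
atomic_fadd keeps a compare-exchange loop because std::atomic<float> gains fetch_add only in C++20; on failure, compare_exchange_weak writes the value it actually found back into old_value, so each retry adds against fresh data. A self-contained sketch of that retry contract, separate from the patch:

#include <atomic>
#include <cassert>

// Floating-point fetch-add via a CAS loop: compare_exchange_weak updates
// `old` with the current stored value whenever it fails, so the sum is
// recomputed until it applies cleanly.
float fetch_add_float(std::atomic<float> &a, float v) {
  float old = a.load(std::memory_order_relaxed);
  while (!a.compare_exchange_weak(old, old + v, std::memory_order_relaxed))
    ;
  return old;
}

int main() {
  std::atomic<float> a{1.0f};
  assert(fetch_add_float(a, 2.5f) == 1.0f); // returns the previous value
  assert(a.load() == 3.5f);
}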
@@ -87,14 +113,14 @@ template <typename DType> class AtomicRMWOpBase : public AtomicOp {
 protected:
   void applyAt(void *loc, size_t i) override final {
     if (mask[i]) {
-      std::atomic<DType> *atomic_ptr = static_cast<std::atomic<DType> *>(loc);
+      DType *atomic_ptr = static_cast<DType *>(loc);
       *(static_cast<DType *>(ret) + i) =
           applyAtMasked(atomic_ptr, *(static_cast<const DType *>(val) + i),
                         std::memory_order(order));
     }
   }
 
-  virtual DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  virtual DType applyAtMasked(DType *loc, const DType value,
                               std::memory_order order) = 0;
 
   const void *val;
@@ -115,9 +141,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::ADD>>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_add(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc + value;
+    }
+    return old_val;
   }
 };
 
@@ -128,9 +164,8 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::FADD>>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-
     return atomic_fadd(loc, value, order);
   }
 };
@@ -142,9 +177,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::AND>>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_and(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc & value;
+    }
+    return old_val;
   }
 };
 
@@ -155,9 +200,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::OR>>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_or(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc | value;
+    }
+    return old_val;
   }
 };
 
@@ -168,9 +223,19 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::XOR>>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return std::atomic_fetch_xor(loc, value);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = *loc ^ value;
+    }
+    return old_val;
   }
 };
 
@@ -182,7 +247,7 @@ class AtomicRMWOp<DType, Op,
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
     return atomic_cmp</*is_min=*/false>(loc, value, order);
   }
@@ -196,7 +261,7 @@ class AtomicRMWOp<DType, Op,
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
     return atomic_cmp</*is_min=*/true>(loc, value, order);
   }
@@ -209,12 +274,44 @@ class AtomicRMWOp<DType, Op, std::enable_if_t<Op == RMWOp::XCHG>>
   using AtomicRMWOpBase<DType>::AtomicRMWOpBase;
 
 protected:
-  DType applyAtMasked(std::atomic<DType> *loc, const DType value,
+  DType applyAtMasked(DType *loc, const DType value,
                       std::memory_order order) override {
-    return loc->exchange(value, order);
+    DType old_val;
+    if constexpr (is_reinterpret_cast_to_atomic_safe<DType>) {
+      std::atomic<DType> *atomic_loc =
+          reinterpret_cast<std::atomic<DType> *>(loc);
+      old_val = atomic_loc->exchange(value, order);
+    } else {
+      const std::lock_guard<std::mutex> lock(atomic_op_guard);
+      old_val = *loc;
+      *loc = value;
+    }
+    return old_val;
   }
 };
 
+template <typename T>
+void atomic_compare_exchange_strong(void *loc, void *expected,
+                                    const void *desired, size_t i,
+                                    std::memory_order order) {
+  T desired_val = *(static_cast<const T *>(desired) + i);
+  T *expected_uint = static_cast<T *>(expected);
+
+  if constexpr (is_reinterpret_cast_to_atomic_safe<T>) {
+    std::atomic<T> *atomic_loc = reinterpret_cast<std::atomic<T> *>(loc);
+    atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
+                                        order, order);
+  } else {
+    const std::lock_guard<std::mutex> lock(atomic_op_guard);
+    T *atomic_loc = static_cast<T *>(loc);
+    if (*atomic_loc == *(expected_uint + i)) {
+      *atomic_loc = desired_val;
+    } else {
+      *(expected_uint + i) = *atomic_loc;
+    }
+  }
+}
+
 class AtomicCASOp : public AtomicOp {
 public:
   AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired,
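
The mutex branch of the helper just added reproduces the full compare_exchange_strong contract, including the part that is easy to miss: on a failed comparison, the expected slot is overwritten with the value actually observed. A small demonstration of that contract (standalone, not from the patch):

#include <atomic>
#include <cassert>

int main() {
  std::atomic<int> slot{10};
  int expected = 10;
  // Match: slot becomes 42, `expected` is left untouched.
  assert(slot.compare_exchange_strong(expected, 42));

  expected = 10; // stale guess
  // Mismatch: slot keeps 42 and `expected` is rewritten to 42,
  // which is exactly what the locked else-branch above re-implements.
  assert(!slot.compare_exchange_strong(expected, 7));
  assert(expected == 42);
}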
@@ -227,46 +324,18 @@ class AtomicCASOp : public AtomicOp {
     // Atomic operations perform bitwise comparison, so it's safe to
     // use number of bytes (itemsize) to determine the type of pointers
     if (itemsize == 1) {
-      std::atomic<uint8_t> *atomic_loc =
-          reinterpret_cast<std::atomic<uint8_t> *>(loc);
-      uint8_t desired_val = *(static_cast<const uint8_t *>(desired) + i);
-      uint8_t *expected_uint = static_cast<uint8_t *>(expected);
-      // Perform the compare and exchange operation
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
-
+      atomic_compare_exchange_strong<uint8_t>(loc, expected, desired, i,
+                                              std::memory_order(order));
     } else if (itemsize == 2) {
-      std::atomic<uint16_t> *atomic_loc =
-          reinterpret_cast<std::atomic<uint16_t> *>(loc);
-      uint16_t desired_val = *(static_cast<const uint16_t *>(desired) + i);
-      uint16_t *expected_uint = static_cast<uint16_t *>(expected);
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
+      atomic_compare_exchange_strong<uint16_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
     } else if (itemsize == 4) {
-      std::atomic<uint32_t> *atomic_loc =
-          reinterpret_cast<std::atomic<uint32_t> *>(loc);
-      uint32_t desired_val = *(static_cast<const uint32_t *>(desired) + i);
-      uint32_t *expected_uint = static_cast<uint32_t *>(expected);
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
+      atomic_compare_exchange_strong<uint32_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
     } else if (itemsize == 8) {
-      uint64_t desired_val = *(static_cast<const uint64_t *>(desired) + i);
-      std::atomic<uint64_t> *atomic_loc =
-          static_cast<std::atomic<uint64_t> *>(loc);
-      uint64_t *expected_uint = static_cast<uint64_t *>(expected);
-      atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val,
-                                          std::memory_order(order),
-                                          std::memory_order(order));
-
+      atomic_compare_exchange_strong<uint64_t>(loc, expected, desired, i,
+                                               std::memory_order(order));
     } else {
-      // The ‘__atomic’ builtins can be used with any integral scalar or pointer
-      // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are
-      // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the
-      // architecture.
-      // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html
       throw std::invalid_argument("Invalid byte size");
     }
   }
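
The itemsize dispatch is valid because compare-and-swap compares raw bit patterns, so any 1-, 2-, 4-, or 8-byte element can be routed through the matching unsigned integer type regardless of its real element type. A sketch of that trick applied to a float through uint32_t; the function cas_float_bits is illustrative, not part of the file:

#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>

// CAS a float by its 4-byte bit pattern: the comparison is bitwise,
// which is the property the itemsize dispatch above relies on.
bool cas_float_bits(std::atomic<uint32_t> &slot, float &expected,
                    float desired) {
  uint32_t exp_bits, des_bits;
  std::memcpy(&exp_bits, &expected, sizeof exp_bits);
  std::memcpy(&des_bits, &desired, sizeof des_bits);
  bool ok = slot.compare_exchange_strong(exp_bits, des_bits);
  std::memcpy(&expected, &exp_bits, sizeof expected); // observed value on failure
  return ok;
}

int main() {
  float init = 1.5f;
  uint32_t bits;
  std::memcpy(&bits, &init, sizeof bits);
  std::atomic<uint32_t> slot{bits};

  float expected = 1.5f;
  assert(cas_float_bits(slot, expected, 2.0f)); // bit patterns match, swap succeeds
}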
