Skip to content

Commit f315b8f

Browse files
committed
[cilksan] Fix logic for combining reducer views to handle cases where reducers are registered and used in nested Cilk contexts, such as in nested parallel loops. Fix miscellaneous typos.
1 parent 0541b1d commit f315b8f

File tree

3 files changed

+227
-6
lines changed

3 files changed

+227
-6
lines changed

cilksan/reducers.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "cilksan_internal.h"
22
#include "debug_util.h"
33
#include "driver.h"
4+
#include "vector.h"
45
#include <cstdarg>
56
#include <cstdio>
67
#include <cstdlib>
@@ -17,6 +18,9 @@ static void reducer_register(const csi_id_t call_id, unsigned MAAP_count,
1718
reducer_views->insert((hyper_table::bucket){
1819
.key = (uintptr_t)key,
1920
.value = {.view = key, .reduce_fn = (__cilk_reduce_fn)reduce_ptr}});
21+
DBG_TRACE(REDUCER,
22+
"reducer_register: registered %p, reducer_views %p, occupancy %d\n",
23+
key, reducer_views, reducer_views->occupancy);
2024
}
2125

2226
if (!is_execution_parallel())
@@ -58,6 +62,10 @@ CILKSAN_API void __csan_llvm_reducer_unregister(const csi_id_t call_id,
5862

5963
// Remove this reducer from the table.
6064
if (hyper_table *reducer_views = CilkSanImpl.get_reducer_views()) {
65+
DBG_TRACE(
66+
REDUCER,
67+
"reducer_unregister: unregistering %p, reducer_views %p, occupancy %d\n",
68+
key, reducer_views, reducer_views->occupancy);
6169
reducer_views->remove((uintptr_t)key);
6270
}
6371

@@ -139,10 +147,16 @@ void CilkSanImpl_t::reduce_local_views() {
139147
// Reduce every reducer view in the table with its leftmost view.
140148
int32_t capacity = reducer_views->capacity;
141149
hyper_table::bucket *buckets = reducer_views->buckets;
150+
bool holdsLeftmostViews = false;
151+
Vector_t<int32_t> keysToRemove;
142152
for (int32_t i = 0; i < capacity; ++i) {
143153
hyper_table::bucket b = buckets[i];
144154
if (!is_valid(b.key))
145155
continue;
156+
if (b.key == (uintptr_t)(b.value.view)) {
157+
holdsLeftmostViews = true;
158+
continue;
159+
}
146160

147161
DBG_TRACE(REDUCER,
148162
"reduce_local_views: found view to reduce at %d: %p -> %p\n", i,
@@ -154,14 +168,20 @@ void CilkSanImpl_t::reduce_local_views() {
154168
// Delete the right view.
155169
free(rb.view);
156170
mark_free(rb.view);
171+
keysToRemove.push_back(b.key);
157172
}
158173
enable_checking();
159174

160-
// Delete the table of local reducer views
161-
DBG_TRACE(REDUCER, "reduce_local_views: delete reducer_views %p\n",
162-
reducer_views);
163-
delete reducer_views;
164-
f->reducer_views = nullptr;
175+
if (!holdsLeftmostViews) {
176+
// Delete the table of local reducer views
177+
DBG_TRACE(REDUCER, "reduce_local_views: delete reducer_views %p\n",
178+
reducer_views);
179+
delete reducer_views;
180+
f->reducer_views = nullptr;
181+
} else {
182+
for (int32_t i = 0; i < keysToRemove.size(); ++i)
183+
reducer_views->remove(buckets[keysToRemove[i]].key);
184+
}
165185
}
166186

167187
hyper_table *

test/cilksan/TestCases/alloctypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ void global_test() {
9393

9494
// CHECK-GLOBAL: Race detected on location [[GLOBAL]]
9595
// CHECK-GLOBAL: * Write {{[0-9a-f]+}} global_test
96-
// CHECK-GLOBALOB: * Write {{[0-9a-f]+}} global_test
96+
// CHECK-GLOBAL: * Write {{[0-9a-f]+}} global_test
9797
// CHECK-GLOBAL: Common calling context
9898
// CHECK-GLOBAL-NEXT: Parfor
9999

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// RUN: %clangxx_cilksan -std=c++20 -fopencilk -O3 -g %s -o %t
2+
// RUN: %run %t 2>&1 | FileCheck %s --check-prefixes=CHECK,CILKSAN
3+
#include <array>
4+
#include <cstddef>
5+
#include <cstdint>
6+
#include <cstdio>
7+
#include <random>
8+
9+
#include <cilk/cilk.h>
10+
11+
struct Scalar {
12+
constexpr static uint64_t PRIME = 0xffffffff00000001ull;
13+
14+
// Scalar() { /* Deliberately skip initialization of _raw to allow more
15+
// aggressive optimization */
16+
// }
17+
18+
Scalar() = default;
19+
20+
explicit Scalar(uint64_t raw) : _raw{raw} {}
21+
22+
explicit operator uint64_t() const { return _raw; }
23+
24+
auto operator+(Scalar other) const -> Scalar {
25+
const uint64_t sum = _raw + other._raw;
26+
const Scalar ret{
27+
sum < _raw || sum < other._raw || sum >= PRIME ? sum - PRIME : sum};
28+
return ret;
29+
}
30+
31+
auto operator-(Scalar other) const -> Scalar {
32+
const uint64_t diff = _raw - other._raw;
33+
const Scalar ret{(diff > _raw) ? diff + PRIME : diff};
34+
return ret;
35+
}
36+
37+
auto operator*(Scalar other) const -> Scalar {
38+
// Start by carrying out an ordinary 64x64->128 bit multiplication
39+
const uint32_t a0 = _raw, a1 = _raw >> 32;
40+
const uint32_t b0 = other._raw, b1 = other._raw >> 32;
41+
const uint64_t p0 = static_cast<uint64_t>(a0) * b0,
42+
p1 = static_cast<uint64_t>(a0) * b1,
43+
p2 = static_cast<uint64_t>(a1) * b0,
44+
p3 = static_cast<uint64_t>(a1) * b1;
45+
const uint32_t cy = ((p0 >> 32) + static_cast<uint32_t>(p1) +
46+
static_cast<uint32_t>(p2)) >>
47+
32;
48+
const uint64_t x = p0 + (p1 << 32) + (p2 << 32),
49+
y = p3 + (p1 >> 32) + (p2 >> 32) + cy;
50+
// Store result in 4 32-bit words
51+
const uint32_t c0 = x, c1 = x >> 32, c2 = y, c3 = y >> 32;
52+
// Now perform reduction: modulus is phi^2 - phi + 1 where phi = 2^32
53+
// ab = c0 + c1*phi + c2*phi^2 + c3*phi^3
54+
// Exploit phi^2 = phi-1 and phi^3 = phi * (phi-1) = (phi-1) - phi = -1
55+
// ab = c0 + c1*phi + c2*(phi-1) - c3
56+
// = (c0-c2-c3) + (c1+c2)*phi
57+
const Scalar ret = (Scalar(c0) - Scalar(c2) - Scalar(c3)) +
58+
(Scalar(static_cast<uint64_t>(c1) << 32) +
59+
Scalar(static_cast<uint64_t>(c2) << 32));
60+
return ret;
61+
}
62+
63+
auto operator==(Scalar other) const -> bool { return _raw == other._raw; }
64+
65+
auto operator!=(Scalar other) const -> bool { return _raw != other._raw; }
66+
67+
// No comparison operators as they make no sense for finite fields
68+
69+
auto operator+=(Scalar other) -> Scalar & { return *this = *this + other; }
70+
71+
auto operator-=(Scalar other) -> Scalar & { return *this = *this - other; }
72+
73+
auto operator*=(Scalar other) -> Scalar & { return *this = *this * other; }
74+
75+
template <typename RNG> inline static auto random(RNG &rng) -> Scalar {
76+
static std::uniform_int_distribution<uint64_t> dist(0, PRIME - 1);
77+
return Scalar{dist(rng)};
78+
}
79+
80+
auto is_valid() const -> bool { return _raw < PRIME; }
81+
82+
private:
83+
uint64_t _raw;
84+
};
85+
86+
static inline void zero_scalar(void *view) {
87+
*reinterpret_cast<Scalar *>(view) = Scalar{0};
88+
}
89+
90+
static inline void add_scalar(void *left, void *right) {
91+
*reinterpret_cast<Scalar *>(left) += *reinterpret_cast<Scalar *>(right);
92+
}
93+
94+
using ScalarAddReducer = Scalar cilk_reducer(zero_scalar, add_scalar);
95+
96+
auto reduce_with_cilk(Scalar **as, Scalar **bs, Scalar *c, Scalar *coeffs,
97+
size_t n, size_t m) -> std::array<Scalar, 3> {
98+
ScalarAddReducer p0{0}, p2{0}, p3{0};
99+
cilk_for (size_t i = 0; i < n; i++) {
100+
// Obtain dense representations of the polynomials
101+
const Scalar *a = as[i];
102+
const Scalar *b = bs[i];
103+
const size_t half = m / 2;
104+
105+
ScalarAddReducer lp0{0}, lp2{0}, lp3{0};
106+
cilk_for (size_t j = 0; j < half; j++) {
107+
lp0 += a[j] * b[j] * c[j];
108+
const Scalar a2 = a[j + half] + a[j + half] - a[j],
109+
b2 = b[j + half] + b[j + half] - b[j],
110+
c2 = c[j + half] + c[j + half] - c[j];
111+
lp2 += a2 * b2 * c2;
112+
const Scalar a3 = a2 + a[j + half] - a[j],
113+
b3 = b2 + b[j + half] - b[j],
114+
c3 = c2 + c[j + half] - c[j];
115+
lp3 += a3 * b3 * c3;
116+
}
117+
p0 += coeffs[i] * lp0;
118+
p2 += coeffs[i] * lp2;
119+
p3 += coeffs[i] * lp3;
120+
}
121+
return {p0, p2, p3};
122+
}
123+
124+
auto reduce_serial(Scalar **as, Scalar **bs, Scalar *c, Scalar *coeffs,
125+
size_t n, size_t m) -> std::array<Scalar, 3> {
126+
Scalar p0{0}, p2{0}, p3{0};
127+
for (size_t i = 0; i < n; i++) {
128+
// Obtain dense representations of the polynomials
129+
const Scalar *a = as[i];
130+
const Scalar *b = bs[i];
131+
const size_t half = m / 2;
132+
133+
Scalar lp0{0}, lp2{0}, lp3{0};
134+
for (size_t j = 0; j < half; j++) {
135+
lp0 += a[j] * b[j] * c[j];
136+
const Scalar a2 = a[j + half] + a[j + half] - a[j],
137+
b2 = b[j + half] + b[j + half] - b[j],
138+
c2 = c[j + half] + c[j + half] - c[j];
139+
lp2 += a2 * b2 * c2;
140+
const Scalar a3 = a2 + a[j + half] - a[j],
141+
b3 = b2 + b[j + half] - b[j],
142+
c3 = c2 + c[j + half] - c[j];
143+
lp3 += a3 * b3 * c3;
144+
}
145+
146+
p0 += coeffs[i] * lp0;
147+
p2 += coeffs[i] * lp2;
148+
p3 += coeffs[i] * lp3;
149+
}
150+
return {p0, p2, p3};
151+
}
152+
153+
auto main() -> int {
154+
const size_t N = 12;
155+
const size_t M = 128;
156+
157+
std::mt19937_64 rng{42};
158+
159+
Scalar *as[N];
160+
Scalar *bs[N];
161+
Scalar c[M];
162+
Scalar coeffs[N];
163+
164+
for (size_t i = 0; i < N; i++) {
165+
as[i] = new Scalar[M];
166+
bs[i] = new Scalar[M];
167+
coeffs[i] = Scalar::random(rng);
168+
for (size_t j = 0; j < M; j++) {
169+
as[i][j] = Scalar::random(rng);
170+
bs[i][j] = Scalar::random(rng);
171+
}
172+
}
173+
for (size_t i = 0; i < M; i++) {
174+
c[i] = Scalar::random(rng);
175+
}
176+
177+
const auto res_cilk = reduce_with_cilk(as, bs, c, coeffs, N, M);
178+
const auto res_serial = reduce_serial(as, bs, c, coeffs, N, M);
179+
if (res_cilk != res_serial) {
180+
printf("res_cilk = %lu, %lu, %lu\n", static_cast<uint64_t>(res_cilk[0]),
181+
static_cast<uint64_t>(res_cilk[1]),
182+
static_cast<uint64_t>(res_cilk[2]));
183+
printf("res_serial = %lu, %lu, %lu\n",
184+
static_cast<uint64_t>(res_serial[0]),
185+
static_cast<uint64_t>(res_serial[1]),
186+
static_cast<uint64_t>(res_serial[2]));
187+
}
188+
for (size_t i = 0; i < N; i++) {
189+
delete[] as[i];
190+
delete[] bs[i];
191+
}
192+
return 0;
193+
}
194+
195+
// NOLINTEND
196+
197+
// CHECK-NOT: res_cilk =
198+
// CHECK-NOT: res_serial =
199+
200+
// CILKSAN: Cilksan detected 0 distinct races.
201+
// CILKSAN-NEXT: Cilksan suppressed 0 duplicate race reports.

0 commit comments

Comments
 (0)