Skip to content

Commit a0986f2

Browse files
committed
Improve: ChillOut to send/sync pointers
Closes: ashvardanian/ForkUnion#7
1 parent 015a535 commit a0986f2

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

reduce_bench.rs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,18 @@ use futures::future::join_all;
1313
const MAX_CACHE_LINE_SIZE: usize = 64; // bytes on x86; adjust if needed
1414
const SCALARS_PER_CACHE_LINE: usize = MAX_CACHE_LINE_SIZE / std::mem::size_of::<f32>();
1515

16+
/// Wrapper that explicitly opts a raw pointer into `Send`/`Sync`.
17+
///
18+
/// Why? Rust deliberately withholds these auto-traits from raw pointers,
19+
/// so we must promise the compiler that the pointed-to memory really is
20+
/// shared-mutable and properly synchronized (which it is here: every thread
21+
/// writes its own unique slot).
22+
#[derive(Copy, Clone)]
23+
struct ChillOut<T>(*mut T);
24+
25+
unsafe impl<T> Send for ChillOut<T> {}
26+
unsafe impl<T> Sync for ChillOut<T> {}
27+
1628
#[inline(always)]
1729
fn divide_round_up(value: usize, divisor: usize) -> usize {
1830
(value + divisor - 1) / divisor
@@ -89,19 +101,18 @@ pub fn prepare_input() -> Vec<f32> {
89101
pub fn sum_fork_union(pool: &fork_union::ForkUnion, data: &[f32], partial_sums: &mut [f64]) -> f64 {
90102
let cores = pool.thread_count();
91103
let chunk_size = scalars_per_core(data.len(), cores);
92-
let partial_sums_ptr = partial_sums.as_mut_ptr() as usize;
104+
let partial_sums_ptr = ChillOut(partial_sums.as_mut_ptr());
93105

94-
pool.for_each_thread(|thread_index| unsafe {
106+
pool.for_each_thread(move |thread_index| unsafe {
95107
let start = thread_index * chunk_size;
96108
if start >= data.len() {
97109
return;
98110
}
99111
let stop = usize::min(start + chunk_size, data.len());
100112
let partial_sum = sum_unrolled(&data[start..stop]);
101-
ptr::write(
102-
(partial_sums_ptr as *mut f64).add(thread_index),
103-
partial_sum,
104-
);
113+
// Ensure the entire wrapper, not just `.0`, is moved into the closure
114+
let partial_sums_ptr = partial_sums_ptr;
115+
ptr::write(partial_sums_ptr.0.add(thread_index), partial_sum);
105116
});
106117

107118
partial_sums.iter().copied().sum()
@@ -111,21 +122,20 @@ pub fn sum_fork_union(pool: &fork_union::ForkUnion, data: &[f32], partial_sums:
111122
pub fn sum_rayon(pool: &rayon::ThreadPool, data: &[f32], partial_sums: &mut [f64]) -> f64 {
112123
let cores = pool.current_num_threads();
113124
let chunk_size = scalars_per_core(data.len(), cores);
114-
let partial_sums_ptr = partial_sums.as_mut_ptr() as usize;
125+
let partial_sums_ptr = ChillOut(partial_sums.as_mut_ptr());
115126

116-
pool.broadcast(|context: rayon::BroadcastContext<'_>| {
127+
pool.broadcast(move |context: rayon::BroadcastContext<'_>| {
117128
let thread_index = context.index();
118129
let start = thread_index * chunk_size;
119130
if start >= data.len() {
120131
return;
121132
}
122133
let stop = std::cmp::min(start + chunk_size, data.len());
123134
let partial_sum = sum_unrolled(&data[start..stop]);
135+
// Ensure the entire wrapper, not just `.0`, is moved into the closure
136+
let partial_sums_ptr = partial_sums_ptr;
124137
unsafe {
125-
ptr::write(
126-
(partial_sums_ptr as *mut f64).add(thread_index),
127-
partial_sum,
128-
);
138+
ptr::write(partial_sums_ptr.0.add(thread_index), partial_sum);
129139
}
130140
});
131141

@@ -137,7 +147,7 @@ pub fn sum_rayon(pool: &rayon::ThreadPool, data: &[f32], partial_sums: &mut [f64
137147
pub fn sum_tokio(pool: &tokio::runtime::Runtime, data: &[f32], partial_sums: &mut [f64]) -> f64 {
138148
let cores = num_cpus::get();
139149
let chunk_size = scalars_per_core(data.len(), cores);
140-
let partial_sums_ptr = partial_sums.as_mut_ptr() as usize;
150+
let partial_sums_ptr = ChillOut(partial_sums.as_mut_ptr());
141151

142152
// Raw parts of the slice – immutable, lives as long as `data`.
143153
let ptr = data.as_ptr() as usize;
@@ -154,10 +164,9 @@ pub fn sum_tokio(pool: &tokio::runtime::Runtime, data: &[f32], partial_sums: &mu
154164
let stop = std::cmp::min(start + chunk_size, len);
155165
let slice = std::slice::from_raw_parts((ptr as *mut f32).add(start), stop - start);
156166
let partial_sum = sum_unrolled(slice);
157-
ptr::write(
158-
(partial_sums_ptr as *mut f64).add(thread_index),
159-
partial_sum,
160-
);
167+
// Ensure the entire wrapper, not just `.0`, is moved into the closure
168+
let partial_sums_ptr = partial_sums_ptr;
169+
ptr::write(partial_sums_ptr.0.add(thread_index), partial_sum);
161170
});
162171
handles.push(handle);
163172
}
@@ -171,7 +180,7 @@ pub fn sum_tokio(pool: &tokio::runtime::Runtime, data: &[f32], partial_sums: &mu
171180
pub fn sum_smol(pool: &async_executor::Executor, data: &[f32], partial_sums: &mut [f64]) -> f64 {
172181
let cores = num_cpus::get();
173182
let chunk_size = scalars_per_core(data.len(), cores);
174-
let partial_sums_ptr = partial_sums.as_mut_ptr() as usize;
183+
let partial_sums_ptr = ChillOut(partial_sums.as_mut_ptr());
175184

176185
let ptr = data.as_ptr() as usize;
177186
let len = data.len();
@@ -189,10 +198,9 @@ pub fn sum_smol(pool: &async_executor::Executor, data: &[f32], partial_sums: &mu
189198
let slice =
190199
std::slice::from_raw_parts((ptr as *mut f32).add(start), stop - start);
191200
let partial_sum = sum_unrolled(slice);
192-
ptr::write(
193-
(partial_sums_ptr as *mut f64).add(thread_index),
194-
partial_sum,
195-
);
201+
// Ensure the entire wrapper, not just `.0`, is moved into the closure
202+
let partial_sums_ptr = partial_sums_ptr;
203+
ptr::write(partial_sums_ptr.0.add(thread_index), partial_sum);
196204
}
197205
}));
198206
}

0 commit comments

Comments
 (0)