@@ -13,6 +13,18 @@ use futures::future::join_all;
1313const MAX_CACHE_LINE_SIZE : usize = 64 ; // bytes on x86; adjust if needed
1414const SCALARS_PER_CACHE_LINE : usize = MAX_CACHE_LINE_SIZE / std:: mem:: size_of :: < f32 > ( ) ;
1515
16+ /// Wrapper that opt-in implements `Send`/`Sync` for a raw pointer.
17+ ///
18+ /// Why? Rust deliberately withholds these auto-traits from raw pointers,
19+ /// so we must promise the compiler that the pointed-to memory really is
20+ /// shared-mutable and properly synchronized (which it is here: every thread
21+ /// writes its own unique slot).
22+ #[ derive( Copy , Clone ) ]
23+ struct ChillOut < T > ( * mut T ) ;
24+
25+ unsafe impl < T > Send for ChillOut < T > { }
26+ unsafe impl < T > Sync for ChillOut < T > { }
27+
1628#[ inline( always) ]
1729fn divide_round_up ( value : usize , divisor : usize ) -> usize {
1830 ( value + divisor - 1 ) / divisor
@@ -89,19 +101,18 @@ pub fn prepare_input() -> Vec<f32> {
89101pub fn sum_fork_union ( pool : & fork_union:: ForkUnion , data : & [ f32 ] , partial_sums : & mut [ f64 ] ) -> f64 {
90102 let cores = pool. thread_count ( ) ;
91103 let chunk_size = scalars_per_core ( data. len ( ) , cores) ;
92- let partial_sums_ptr = partial_sums. as_mut_ptr ( ) as usize ;
104+ let partial_sums_ptr = ChillOut ( partial_sums. as_mut_ptr ( ) ) ;
93105
94- pool. for_each_thread ( |thread_index| unsafe {
106+ pool. for_each_thread ( move |thread_index| unsafe {
95107 let start = thread_index * chunk_size;
96108 if start >= data. len ( ) {
97109 return ;
98110 }
99111 let stop = usize:: min ( start + chunk_size, data. len ( ) ) ;
100112 let partial_sum = sum_unrolled ( & data[ start..stop] ) ;
101- ptr:: write (
102- ( partial_sums_ptr as * mut f64 ) . add ( thread_index) ,
103- partial_sum,
104- ) ;
113+ // Ensure the entire wrapper, not just `.0`, is moved into the closure
114+ let partial_sums_ptr = partial_sums_ptr;
115+ ptr:: write ( partial_sums_ptr. 0 . add ( thread_index) , partial_sum) ;
105116 } ) ;
106117
107118 partial_sums. iter ( ) . copied ( ) . sum ( )
@@ -111,21 +122,20 @@ pub fn sum_fork_union(pool: &fork_union::ForkUnion, data: &[f32], partial_sums:
111122pub fn sum_rayon ( pool : & rayon:: ThreadPool , data : & [ f32 ] , partial_sums : & mut [ f64 ] ) -> f64 {
112123 let cores = pool. current_num_threads ( ) ;
113124 let chunk_size = scalars_per_core ( data. len ( ) , cores) ;
114- let partial_sums_ptr = partial_sums. as_mut_ptr ( ) as usize ;
125+ let partial_sums_ptr = ChillOut ( partial_sums. as_mut_ptr ( ) ) ;
115126
116- pool. broadcast ( |context : rayon:: BroadcastContext < ' _ > | {
127+ pool. broadcast ( move |context : rayon:: BroadcastContext < ' _ > | {
117128 let thread_index = context. index ( ) ;
118129 let start = thread_index * chunk_size;
119130 if start >= data. len ( ) {
120131 return ;
121132 }
122133 let stop = std:: cmp:: min ( start + chunk_size, data. len ( ) ) ;
123134 let partial_sum = sum_unrolled ( & data[ start..stop] ) ;
135+ // Ensure the entire wrapper, not just `.0`, is moved into the closure
136+ let partial_sums_ptr = partial_sums_ptr;
124137 unsafe {
125- ptr:: write (
126- ( partial_sums_ptr as * mut f64 ) . add ( thread_index) ,
127- partial_sum,
128- ) ;
138+ ptr:: write ( partial_sums_ptr. 0 . add ( thread_index) , partial_sum) ;
129139 }
130140 } ) ;
131141
@@ -137,7 +147,7 @@ pub fn sum_rayon(pool: &rayon::ThreadPool, data: &[f32], partial_sums: &mut [f64
137147pub fn sum_tokio ( pool : & tokio:: runtime:: Runtime , data : & [ f32 ] , partial_sums : & mut [ f64 ] ) -> f64 {
138148 let cores = num_cpus:: get ( ) ;
139149 let chunk_size = scalars_per_core ( data. len ( ) , cores) ;
140- let partial_sums_ptr = partial_sums. as_mut_ptr ( ) as usize ;
150+ let partial_sums_ptr = ChillOut ( partial_sums. as_mut_ptr ( ) ) ;
141151
142152 // Raw parts of the slice – immutable, lives as long as `data`.
143153 let ptr = data. as_ptr ( ) as usize ;
@@ -154,10 +164,9 @@ pub fn sum_tokio(pool: &tokio::runtime::Runtime, data: &[f32], partial_sums: &mu
154164 let stop = std:: cmp:: min ( start + chunk_size, len) ;
155165 let slice = std:: slice:: from_raw_parts ( ( ptr as * mut f32 ) . add ( start) , stop - start) ;
156166 let partial_sum = sum_unrolled ( slice) ;
157- ptr:: write (
158- ( partial_sums_ptr as * mut f64 ) . add ( thread_index) ,
159- partial_sum,
160- ) ;
167+ // Ensure the entire wrapper, not just `.0`, is moved into the closure
168+ let partial_sums_ptr = partial_sums_ptr;
169+ ptr:: write ( partial_sums_ptr. 0 . add ( thread_index) , partial_sum) ;
161170 } ) ;
162171 handles. push ( handle) ;
163172 }
@@ -171,7 +180,7 @@ pub fn sum_tokio(pool: &tokio::runtime::Runtime, data: &[f32], partial_sums: &mu
171180pub fn sum_smol ( pool : & async_executor:: Executor , data : & [ f32 ] , partial_sums : & mut [ f64 ] ) -> f64 {
172181 let cores = num_cpus:: get ( ) ;
173182 let chunk_size = scalars_per_core ( data. len ( ) , cores) ;
174- let partial_sums_ptr = partial_sums. as_mut_ptr ( ) as usize ;
183+ let partial_sums_ptr = ChillOut ( partial_sums. as_mut_ptr ( ) ) ;
175184
176185 let ptr = data. as_ptr ( ) as usize ;
177186 let len = data. len ( ) ;
@@ -189,10 +198,9 @@ pub fn sum_smol(pool: &async_executor::Executor, data: &[f32], partial_sums: &mu
189198 let slice =
190199 std:: slice:: from_raw_parts ( ( ptr as * mut f32 ) . add ( start) , stop - start) ;
191200 let partial_sum = sum_unrolled ( slice) ;
192- ptr:: write (
193- ( partial_sums_ptr as * mut f64 ) . add ( thread_index) ,
194- partial_sum,
195- ) ;
201+ // Ensure the entire wrapper, not just `.0`, is moved into the closure
202+ let partial_sums_ptr = partial_sums_ptr;
203+ ptr:: write ( partial_sums_ptr. 0 . add ( thread_index) , partial_sum) ;
196204 }
197205 } ) ) ;
198206 }
0 commit comments