@@ -203,36 +203,41 @@ let parallel_for ?(chunk_size=0) ~start ~finish ~body pool =
203203
204204let parallel_scan pool op elements =
205205 let pd = get_pool_data pool in
206+ let n = Array. length elements in
207+ let p = min (n - 1 ) ((Array. length pd.domains) + 1 ) in
208+ let prefix_s = Array. copy elements in
206209 let scan_part op elements prefix_sum start finish =
207210 assert (Array. length elements > (finish - start));
208211 for i = (start + 1 ) to finish do
209212 prefix_sum.(i) < - op prefix_sum.(i - 1 ) elements.(i)
210213 done
211214 in
215+ if p < 2 then begin
216+ (* Do a sequential scan when the number of domains is less than 2 or the
217+ array's length is less than 3, i.e. when p < 2 *)
218+ scan_part op elements prefix_s 0 (n - 1 );
219+ prefix_s
220+ end
221+ else begin
212222 let add_offset op prefix_sum offset start finish =
213223 assert (Array. length prefix_sum > (finish - start));
214224 for i = start to finish do
215225 prefix_sum.(i) < - op offset prefix_sum.(i)
216226 done
217227 in
218- let n = Array. length elements in
219- let p = (Array. length pd.domains) + 1 in
220- let prefix_s = Array. copy elements in
221228
222229 parallel_for pool ~chunk_size: 1 ~start: 0 ~finish: (p - 1 )
223230 ~body: (fun i ->
224231 let s = (i * n) / (p ) in
225232 let e = (i + 1 ) * n / (p ) - 1 in
226233 scan_part op elements prefix_s s e);
227234
228- if (p > 2 ) then begin
229235 let x = ref prefix_s.(n/ p - 1 ) in
230236 for i = 2 to p do
231237 let ind = i * n / p - 1 in
232238 x := op prefix_s.(ind) ! x;
233239 prefix_s.(ind) < - ! x
234- done
235- end ;
240+ done ;
236241
237242 parallel_for pool ~chunk_size: 1 ~start: 1 ~finish: (p - 1 )
238243 ~body: ( fun i ->
@@ -243,3 +248,4 @@ let parallel_scan pool op elements =
243248 );
244249
245250 prefix_s
251+ end
0 commit comments