Skip to content

Commit df4afa2

Browse files
authored
Merge pull request #60 from ocaml-multicore/parallel_scan_bug_fix
Bug fix in `parallel_scan`
2 parents 59ee895 + f8cea3b commit df4afa2

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

lib/task.ml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,36 +203,41 @@ let parallel_for ?(chunk_size=0) ~start ~finish ~body pool =
203203

204204
let parallel_scan pool op elements =
205205
let pd = get_pool_data pool in
206+
let n = Array.length elements in
207+
let p = min (n - 1) ((Array.length pd.domains) + 1) in
208+
let prefix_s = Array.copy elements in
206209
let scan_part op elements prefix_sum start finish =
207210
assert (Array.length elements > (finish - start));
208211
for i = (start + 1) to finish do
209212
prefix_sum.(i) <- op prefix_sum.(i - 1) elements.(i)
210213
done
211214
in
215+
if p < 2 then begin
216+
(* Do a sequential scan when number of domains or array's length is less
217+
than 2 *)
218+
scan_part op elements prefix_s 0 (n - 1);
219+
prefix_s
220+
end
221+
else begin
212222
let add_offset op prefix_sum offset start finish =
213223
assert (Array.length prefix_sum > (finish - start));
214224
for i = start to finish do
215225
prefix_sum.(i) <- op offset prefix_sum.(i)
216226
done
217227
in
218-
let n = Array.length elements in
219-
let p = (Array.length pd.domains) + 1 in
220-
let prefix_s = Array.copy elements in
221228

222229
parallel_for pool ~chunk_size:1 ~start:0 ~finish:(p - 1)
223230
~body:(fun i ->
224231
let s = (i * n) / (p ) in
225232
let e = (i + 1) * n / (p ) - 1 in
226233
scan_part op elements prefix_s s e);
227234

228-
if (p > 2) then begin
229235
let x = ref prefix_s.(n/p - 1) in
230236
for i = 2 to p do
231237
let ind = i * n / p - 1 in
232238
x := op prefix_s.(ind) !x;
233239
prefix_s.(ind) <- !x
234-
done
235-
end;
240+
done;
236241

237242
parallel_for pool ~chunk_size:1 ~start:1 ~finish:(p - 1)
238243
~body:( fun i ->
@@ -243,3 +248,4 @@ let parallel_scan pool op elements =
243248
);
244249

245250
prefix_s
251+
end

test/dune

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,9 @@
9898
(libraries domainslib)
9999
(modules backtrace)
100100
(modes native))
101+
102+
(test
103+
(name off_by_one)
104+
(libraries domainslib)
105+
(modules off_by_one)
106+
(modes native))

test/off_by_one.ml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
open Domainslib
2+
3+
(* [print_array a] renders the int array [a] as a string of the form
   "[|e1; e2; ...; en; |]".  Every element, including the last one, is
   followed by "; ", so the empty array is rendered as "[||]". *)
let print_array a =
  let buf = Buffer.create 25 in
  Buffer.add_string buf "[|";
  for idx = 0 to Array.length a - 1 do
    (* Append the number and its separator directly to the buffer. *)
    Buffer.add_string buf (string_of_int a.(idx));
    Buffer.add_string buf "; "
  done;
  Buffer.add_string buf "|]";
  Buffer.contents buf
9+
10+
(* Reference result [|1; 2; ...; 20|] that the scan output is compared
   against. *)
let r = Array.init 20 succ
11+
12+
(* [scan_task num_doms] sets up a pool with [num_doms] additional domains,
   runs [Task.parallel_scan] with [(+)] over an array of twenty ones inside
   that pool, prints the result prefixed by [num_doms], and asserts that it
   equals [r].

   The pool is torn down in a [~finally] handler so that teardown also
   happens when the scan raises; on success the order of effects is
   unchanged: run, teardown, print, assert. *)
let scan_task num_doms =
  let pool = Task.setup_pool ~num_additional_domains:num_doms () in
  let a =
    Fun.protect
      ~finally:(fun () -> Task.teardown_pool pool)
      (fun () ->
        Task.run pool (fun () ->
          Task.parallel_scan pool (+) (Array.make 20 1)))
  in
  Printf.printf "%i: %s\n%!" num_doms (print_array a);
  assert (a = r)
18+
(* Run the scan check once for every additional-domain count from 0 to 21
   inclusive, which covers counts both below and above the 20-element
   array length used by [scan_task]. *)
let () =
  for extra_domains = 0 to 21 do
    scan_task extra_domains
  done

0 commit comments

Comments
 (0)