Skip to content

Commit 340773a

Browse files
authored
Merge pull request #1 from ocaml-multicore/better_teardown
Clean up better at teardown.
2 parents b269930 + cd062f9 commit 340773a

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

lib/task.ml

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@ type task_msg =
88
Task : 'a task * 'a promise -> task_msg
99
| Quit : task_msg
1010

11-
type pool =
12-
{domains : unit Domain.t array;
13-
task_chan : task_msg Multi_channel.t}
11+
type pool_data = {
12+
domains : unit Domain.t array;
13+
task_chan : task_msg Multi_channel.t;
14+
name: string option
15+
}
16+
17+
type pool = pool_data option Atomic.t
1418

1519
let do_task f p =
1620
try
@@ -24,7 +28,7 @@ let do_task f p =
2428

2529
let named_pools = Hashtbl.create 8
2630

27-
let named_pool_mutex = Mutex.create ()
31+
let named_pools_mutex = Mutex.create ()
2832

2933
let setup_pool ?name ~num_additional_domains () =
3034
if num_additional_domains < 0 then
@@ -40,27 +44,34 @@ let setup_pool ?name ~num_additional_domains () =
4044
worker ()
4145
in
4246
let domains = Array.init num_additional_domains (fun _ -> Domain.spawn worker) in
43-
let p = {domains; task_chan} in
44-
let _ = match name with
45-
| Some x ->
46-
Mutex.lock named_pool_mutex;
47-
Hashtbl.add named_pools x p;
48-
Mutex.unlock named_pool_mutex
47+
let p = Atomic.make (Some {domains; task_chan; name}) in
48+
begin match name with
4949
| None -> ()
50-
in
50+
| Some x ->
51+
Mutex.lock named_pools_mutex;
52+
Hashtbl.add named_pools x p;
53+
Mutex.unlock named_pools_mutex
54+
end;
5155
p
5256

57+
let get_pool_data p =
58+
match Atomic.get p with
59+
| None -> raise (Invalid_argument "pool already torn down")
60+
| Some p -> p
61+
5362
let async pool task =
63+
let pd = get_pool_data pool in
5464
let p = Atomic.make None in
55-
Multi_channel.send pool.task_chan (Task(task,p));
65+
Multi_channel.send pd.task_chan (Task(task,p));
5666
p
5767

5868
let rec await pool promise =
69+
let pd = get_pool_data pool in
5970
match Atomic.get promise with
6071
| None ->
6172
begin
6273
try
63-
match Multi_channel.recv_poll pool.task_chan with
74+
match Multi_channel.recv_poll pd.task_chan with
6475
| Task (t, p) -> do_task t p
6576
| Quit -> raise TasksActive
6677
with
@@ -71,24 +82,37 @@ let rec await pool promise =
7182
| Some (Error e) -> raise e
7283

7384
let teardown_pool pool =
74-
for _i=1 to Array.length pool.domains do
75-
Multi_channel.send pool.task_chan Quit
85+
let pd = get_pool_data pool in
86+
for _i=1 to Array.length pd.domains do
87+
Multi_channel.send pd.task_chan Quit
7688
done;
7789
Multi_channel.clear_local_state ();
78-
Array.iter Domain.join pool.domains
90+
Array.iter Domain.join pd.domains;
91+
(* Remove the pool from the table *)
92+
begin match pd.name with
93+
| None -> ()
94+
| Some n ->
95+
Mutex.lock named_pools_mutex;
96+
Hashtbl.remove named_pools n;
97+
Mutex.unlock named_pools_mutex
98+
end;
99+
Atomic.set pool None
79100

80101
let lookup_pool name =
81-
Mutex.lock named_pool_mutex;
102+
Mutex.lock named_pools_mutex;
82103
let p = Hashtbl.find_opt named_pools name in
83-
Mutex.unlock named_pool_mutex;
104+
Mutex.unlock named_pools_mutex;
84105
p
85106

86-
let get_num_domains pool = (Array.length pool.domains + 1)
107+
let get_num_domains pool =
108+
let pd = get_pool_data pool in
109+
Array.length pd.domains + 1
87110

88111
let parallel_for_reduce ?(chunk_size=0) ~start ~finish ~body pool reduce_fun init =
112+
let pd = get_pool_data pool in
89113
let chunk_size = if chunk_size > 0 then chunk_size
90114
else begin
91-
let n_domains = (Array.length pool.domains) + 1 in
115+
let n_domains = (Array.length pd.domains) + 1 in
92116
let n_tasks = finish - start + 1 in
93117
if n_domains = 1 then n_tasks
94118
else max 1 (n_tasks/(8*n_domains))
@@ -112,9 +136,10 @@ let parallel_for_reduce ?(chunk_size=0) ~start ~finish ~body pool reduce_fun ini
112136
reduce_fun init (work start finish)
113137

114138
let parallel_for ?(chunk_size=0) ~start ~finish ~body pool =
139+
let pd = get_pool_data pool in
115140
let chunk_size = if chunk_size > 0 then chunk_size
116141
else begin
117-
let n_domains = (Array.length pool.domains) + 1 in
142+
let n_domains = (Array.length pd.domains) + 1 in
118143
let n_tasks = finish - start + 1 in
119144
if n_domains = 1 then n_tasks
120145
else max 1 (n_tasks/(8*n_domains))
@@ -133,7 +158,7 @@ let parallel_for ?(chunk_size=0) ~start ~finish ~body pool =
133158
work pool body start finish
134159

135160
let parallel_scan pool op elements =
136-
161+
let pd = get_pool_data pool in
137162
let scan_part op elements prefix_sum start finish =
138163
assert (Array.length elements > (finish - start));
139164
for i = (start + 1) to finish do
@@ -147,7 +172,7 @@ let parallel_scan pool op elements =
147172
done
148173
in
149174
let n = Array.length elements in
150-
let p = (Array.length pool.domains) + 1 in
175+
let p = (Array.length pd.domains) + 1 in
151176
let prefix_s = Array.copy elements in
152177

153178
parallel_for pool ~chunk_size:1 ~start:0 ~finish:(p - 1)

test/test_task.ml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,15 @@ let () =
5757
prefix_sum p2 ();
5858
Task.teardown_pool pool1;
5959
Task.teardown_pool pool2;
60+
61+
try
62+
sum_sequence pool2 0 0 ();
63+
assert false
64+
with Invalid_argument _ -> ();
65+
66+
assert (Task.lookup_pool "pool1" = None);
67+
6068
try
6169
let _ = Task.setup_pool ~num_additional_domains:(-1) () in ()
6270
with Invalid_argument _ -> ();
63-
print_endline "ok"
71+
print_endline "ok"

0 commit comments

Comments
 (0)