@@ -8,9 +8,13 @@ type task_msg =
88 Task : 'a task * 'a promise -> task_msg
99| Quit : task_msg
1010
11- type pool =
12- {domains : unit Domain .t array ;
13- task_chan : task_msg Multi_channel .t }
11+ type pool_data = {
12+ domains : unit Domain .t array ;
13+ task_chan : task_msg Multi_channel .t ;
14+ name : string option
15+ }
16+
17+ type pool = pool_data option Atomic .t
1418
1519let do_task f p =
1620 try
@@ -24,7 +28,7 @@ let do_task f p =
2428
2529let named_pools = Hashtbl. create 8
2630
27- let named_pool_mutex = Mutex. create ()
31+ let named_pools_mutex = Mutex. create ()
2832
2933let setup_pool ?name ~num_additional_domains () =
3034 if num_additional_domains < 0 then
@@ -40,27 +44,34 @@ let setup_pool ?name ~num_additional_domains () =
4044 worker ()
4145 in
4246 let domains = Array. init num_additional_domains (fun _ -> Domain. spawn worker) in
43- let p = {domains; task_chan} in
44- let _ = match name with
45- | Some x ->
46- Mutex. lock named_pool_mutex;
47- Hashtbl. add named_pools x p;
48- Mutex. unlock named_pool_mutex
47+ let p = Atomic. make (Some {domains; task_chan; name}) in
48+ begin match name with
4949 | None -> ()
50- in
50+ | Some x ->
51+ Mutex. lock named_pools_mutex;
52+ Hashtbl. add named_pools x p;
53+ Mutex. unlock named_pools_mutex
54+ end ;
5155 p
5256
57+ let get_pool_data p =
58+ match Atomic. get p with
59+ | None -> raise (Invalid_argument " pool already torn down" )
60+ | Some p -> p
61+
5362let async pool task =
63+ let pd = get_pool_data pool in
5464 let p = Atomic. make None in
55- Multi_channel. send pool .task_chan (Task (task,p));
65+ Multi_channel. send pd .task_chan (Task (task,p));
5666 p
5767
5868let rec await pool promise =
69+ let pd = get_pool_data pool in
5970 match Atomic. get promise with
6071 | None ->
6172 begin
6273 try
63- match Multi_channel. recv_poll pool .task_chan with
74+ match Multi_channel. recv_poll pd .task_chan with
6475 | Task (t , p ) -> do_task t p
6576 | Quit -> raise TasksActive
6677 with
@@ -71,24 +82,37 @@ let rec await pool promise =
7182 | Some (Error e ) -> raise e
7283
7384let teardown_pool pool =
74- for _i= 1 to Array. length pool.domains do
75- Multi_channel. send pool.task_chan Quit
85+ let pd = get_pool_data pool in
86+ for _i= 1 to Array. length pd.domains do
87+ Multi_channel. send pd.task_chan Quit
7688 done ;
7789 Multi_channel. clear_local_state () ;
78- Array. iter Domain. join pool.domains
90+ Array. iter Domain. join pd.domains;
91+ (* Remove the pool from the table *)
92+ begin match pd.name with
93+ | None -> ()
94+ | Some n ->
95+ Mutex. lock named_pools_mutex;
96+ Hashtbl. remove named_pools n;
97+ Mutex. unlock named_pools_mutex
98+ end ;
99+ Atomic. set pool None
79100
80101let lookup_pool name =
81- Mutex. lock named_pool_mutex ;
102+ Mutex. lock named_pools_mutex ;
82103 let p = Hashtbl. find_opt named_pools name in
83- Mutex. unlock named_pool_mutex ;
104+ Mutex. unlock named_pools_mutex ;
84105 p
85106
86- let get_num_domains pool = (Array. length pool.domains + 1 )
107+ let get_num_domains pool =
108+ let pd = get_pool_data pool in
109+ Array. length pd.domains + 1
87110
88111let parallel_for_reduce ?(chunk_size =0 ) ~start ~finish ~body pool reduce_fun init =
112+ let pd = get_pool_data pool in
89113 let chunk_size = if chunk_size > 0 then chunk_size
90114 else begin
91- let n_domains = (Array. length pool .domains) + 1 in
115+ let n_domains = (Array. length pd .domains) + 1 in
92116 let n_tasks = finish - start + 1 in
93117 if n_domains = 1 then n_tasks
94118 else max 1 (n_tasks/ (8 * n_domains))
@@ -112,9 +136,10 @@ let parallel_for_reduce ?(chunk_size=0) ~start ~finish ~body pool reduce_fun ini
112136 reduce_fun init (work start finish)
113137
114138let parallel_for ?(chunk_size =0 ) ~start ~finish ~body pool =
139+ let pd = get_pool_data pool in
115140 let chunk_size = if chunk_size > 0 then chunk_size
116141 else begin
117- let n_domains = (Array. length pool .domains) + 1 in
142+ let n_domains = (Array. length pd .domains) + 1 in
118143 let n_tasks = finish - start + 1 in
119144 if n_domains = 1 then n_tasks
120145 else max 1 (n_tasks/ (8 * n_domains))
@@ -133,7 +158,7 @@ let parallel_for ?(chunk_size=0) ~start ~finish ~body pool =
133158 work pool body start finish
134159
135160let parallel_scan pool op elements =
136-
161+ let pd = get_pool_data pool in
137162 let scan_part op elements prefix_sum start finish =
138163 assert (Array. length elements > (finish - start));
139164 for i = (start + 1 ) to finish do
@@ -147,7 +172,7 @@ let parallel_scan pool op elements =
147172 done
148173 in
149174 let n = Array. length elements in
150- let p = (Array. length pool .domains) + 1 in
175+ let p = (Array. length pd .domains) + 1 in
151176 let prefix_s = Array. copy elements in
152177
153178 parallel_for pool ~chunk_size: 1 ~start: 0 ~finish: (p - 1 )
0 commit comments