@@ -8,9 +8,13 @@ type task_msg =
88 Task : 'a task * 'a promise -> task_msg
99| Quit : task_msg
1010
11- type pool =
12- {domains : unit Domain .t array ;
13- task_chan : task_msg Multi_channel .t }
11+ type pool_data = {
12+ domains : unit Domain .t array ;
13+ task_chan : task_msg Multi_channel .t ;
14+ name : string option
15+ }
16+
17+ type pool = pool_data option Atomic .t
1418
1519let do_task f p =
1620 try
@@ -22,7 +26,15 @@ let do_task f p =
2226 | TasksActive -> raise e
2327 | _ -> ()
2428
25- let setup_pool ~num_additional_domains =
29+ let named_pools = Hashtbl. create 8
30+
31+ let named_pools_mutex = Mutex. create ()
32+
33+ let setup_pool ?name ~num_additional_domains () =
34+ if num_additional_domains < 0 then
35+ raise (Invalid_argument
36+ " Task.setup_pool: num_additional_domains must be at least 0" )
37+ else
2638 let task_chan = Multi_channel. make (num_additional_domains+ 1 ) in
2739 let rec worker () =
2840 match Multi_channel. recv task_chan with
@@ -32,19 +44,34 @@ let setup_pool ~num_additional_domains =
3244 worker ()
3345 in
3446 let domains = Array. init num_additional_domains (fun _ -> Domain. spawn worker) in
35- {domains; task_chan}
47+ let p = Atomic. make (Some {domains; task_chan; name}) in
48+ begin match name with
49+ | None -> ()
50+ | Some x ->
51+ Mutex. lock named_pools_mutex;
52+ Hashtbl. add named_pools x p;
53+ Mutex. unlock named_pools_mutex
54+ end ;
55+ p
56+
57+ let get_pool_data p =
58+ match Atomic. get p with
59+ | None -> raise (Invalid_argument " pool already torn down" )
60+ | Some p -> p
3661
3762let async pool task =
63+ let pd = get_pool_data pool in
3864 let p = Atomic. make None in
39- Multi_channel. send pool .task_chan (Task (task,p));
65+ Multi_channel. send pd .task_chan (Task (task,p));
4066 p
4167
4268let rec await pool promise =
69+ let pd = get_pool_data pool in
4370 match Atomic. get promise with
4471 | None ->
4572 begin
4673 try
47- match Multi_channel. recv_poll pool .task_chan with
74+ match Multi_channel. recv_poll pd .task_chan with
4875 | Task (t , p ) -> do_task t p
4976 | Quit -> raise TasksActive
5077 with
@@ -55,16 +82,37 @@ let rec await pool promise =
5582 | Some (Error e ) -> raise e
5683
5784let teardown_pool pool =
58- for _i= 1 to Array. length pool.domains do
59- Multi_channel. send pool.task_chan Quit
85+ let pd = get_pool_data pool in
86+ for _i= 1 to Array. length pd.domains do
87+ Multi_channel. send pd.task_chan Quit
6088 done ;
6189 Multi_channel. clear_local_state () ;
62- Array. iter Domain. join pool.domains
90+ Array. iter Domain. join pd.domains;
91+ (* Remove the pool from the table *)
92+ begin match pd.name with
93+ | None -> ()
94+ | Some n ->
95+ Mutex. lock named_pools_mutex;
96+ Hashtbl. remove named_pools n;
97+ Mutex. unlock named_pools_mutex
98+ end ;
99+ Atomic. set pool None
100+
101+ let lookup_pool name =
102+ Mutex. lock named_pools_mutex;
103+ let p = Hashtbl. find_opt named_pools name in
104+ Mutex. unlock named_pools_mutex;
105+ p
106+
107+ let get_num_domains pool =
108+ let pd = get_pool_data pool in
109+ Array. length pd.domains + 1
63110
64111let parallel_for_reduce ?(chunk_size =0 ) ~start ~finish ~body pool reduce_fun init =
112+ let pd = get_pool_data pool in
65113 let chunk_size = if chunk_size > 0 then chunk_size
66114 else begin
67- let n_domains = (Array. length pool .domains) + 1 in
115+ let n_domains = (Array. length pd .domains) + 1 in
68116 let n_tasks = finish - start + 1 in
69117 if n_domains = 1 then n_tasks
70118 else max 1 (n_tasks/ (8 * n_domains))
@@ -88,9 +136,10 @@ let parallel_for_reduce ?(chunk_size=0) ~start ~finish ~body pool reduce_fun ini
88136 reduce_fun init (work start finish)
89137
90138let parallel_for ?(chunk_size =0 ) ~start ~finish ~body pool =
139+ let pd = get_pool_data pool in
91140 let chunk_size = if chunk_size > 0 then chunk_size
92141 else begin
93- let n_domains = (Array. length pool .domains) + 1 in
142+ let n_domains = (Array. length pd .domains) + 1 in
94143 let n_tasks = finish - start + 1 in
95144 if n_domains = 1 then n_tasks
96145 else max 1 (n_tasks/ (8 * n_domains))
@@ -109,7 +158,7 @@ let parallel_for ?(chunk_size=0) ~start ~finish ~body pool =
109158 work pool body start finish
110159
111160let parallel_scan pool op elements =
112-
161+ let pd = get_pool_data pool in
113162 let scan_part op elements prefix_sum start finish =
114163 assert (Array. length elements > (finish - start));
115164 for i = (start + 1 ) to finish do
@@ -123,7 +172,7 @@ let parallel_scan pool op elements =
123172 done
124173 in
125174 let n = Array. length elements in
126- let p = (Array. length pool .domains) + 1 in
175+ let p = (Array. length pd .domains) + 1 in
127176 let prefix_s = Array. copy elements in
128177
129178 parallel_for pool ~chunk_size: 1 ~start: 0 ~finish: (p - 1 )
0 commit comments