
Commit 0ca7fb9

add multiprocessing functions for batch running of BrainPy functions (#298)
2 parents f4ab548 + 23f42c0 commit 0ca7fb9

File tree

4 files changed: +395 −21 lines changed

brainpy/running/__init__.py

Lines changed: 19 additions & 2 deletions
@@ -2,9 +2,26 @@
 
 
 """
-This module provides APIs for brain simulations.
+This module provides APIs for parallel brain simulations.
 """
 
-from .multiprocess import *
+from . import jax_multiprocessing
+from . import native_multiprocessing
+from . import pathos_multiprocessing
+from . import runner
+from . import constants
+
+
+__all__ = (native_multiprocessing.__all__ +
+           pathos_multiprocessing.__all__ +
+           jax_multiprocessing.__all__ +
+           runner.__all__ +
+           constants.__all__)
+
+
 from .runner import *
+from .jax_multiprocessing import *
+from .native_multiprocessing import *
+from .pathos_multiprocessing import *
 from .constants import *
+
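
With these re-exports in place, the batch-running helpers become importable directly from ``brainpy.running``. A minimal sketch of what this exposes (assuming the top-level ``brainpy`` package re-exports the ``running`` subpackage, as in released BrainPy versions; only names visible in this diff are listed):

import brainpy as bp

# JAX-based batch helpers added in this commit
bp.running.jax_vectorize_map
bp.running.jax_parallelize_map

# native multiprocessing helpers kept from the renamed module
bp.running.process_pool
bp.running.process_pool_lock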
brainpy/running/jax_multiprocessing.py (new file)

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-

from typing import Sequence, Dict, Union

import numpy as np
from jax import vmap, pmap
from jax.tree_util import tree_unflatten, tree_flatten

import brainpy.math as bm
from brainpy.types import Array

__all__ = [
  'jax_vectorize_map',
  'jax_parallelize_map',
]


def jax_vectorize_map(
    func: callable,
    arguments: Union[Dict[str, Array], Sequence[Array]],
    num_parallel: int,
    clear_buffer: bool = False
):
  """Perform a vectorized map of a function by using ``jax.vmap``.

  This function can be used on CPU or GPU backends, but it is especially
  suited to GPU backends, because ``jax.vmap`` parallelizes the mapped
  axis on GPU devices.

  Parameters
  ----------
  func: callable
    The function to be mapped.
  arguments: sequence, dict
    The function arguments, used to define tasks.
  num_parallel: int
    The batch size, i.e. how many tasks are mapped in each call.
  clear_buffer: bool
    Whether to clear the buffer memory after running each batch.

  Returns
  -------
  results: Any
    The running results.
  """
  if not isinstance(arguments, (dict, tuple, list)):
    raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
  elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.JaxArray))
  if clear_buffer:
    elements = [np.asarray(ele) for ele in elements]
  num_pars = [len(ele) for ele in elements]
  if len(np.unique(num_pars)) != 1:
    raise ValueError(f'All elements in parameters should have the same length. '
                     f'But we got {tree_unflatten(tree, num_pars)}')

  res_tree = None
  results = None
  vmap_func = vmap(func)
  for i in range(0, num_pars[0], num_parallel):
    run_f = vmap(func) if clear_buffer else vmap_func
    if isinstance(arguments, dict):
      r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
    else:
      r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
    res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.JaxArray))
    if results is None:
      results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
    else:
      for j, val in enumerate(res_values):
        results[j].append(np.asarray(val) if clear_buffer else val)
    if clear_buffer:
      bm.clear_buffer_memory()
  if res_tree is None:
    return None
  results = ([np.concatenate(res, axis=0) for res in results]
             if clear_buffer else
             [bm.concatenate(res, axis=0) for res in results])
  return tree_unflatten(res_tree, results)
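
A minimal usage sketch of the function above (the simulation function ``simulate`` and the swept parameter values are hypothetical, not part of this commit): 1000 parameter values are pushed through ``jax.vmap`` in chunks of 100.

import brainpy as bp
import brainpy.math as bm

def simulate(gamma):
  # hypothetical per-trial computation; one scalar ``gamma`` per mapped task
  return bm.sum(gamma ** 2)

# run 1000 tasks, vectorized 100 at a time
results = bp.running.jax_vectorize_map(
  simulate,
  arguments={'gamma': bm.linspace(0., 1., 1000)},
  num_parallel=100,
  clear_buffer=False,
)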

def jax_parallelize_map(
    func: callable,
    arguments: Union[Dict[str, Array], Sequence[Array]],
    num_parallel: int,
    clear_buffer: bool = False
):
  """Perform a parallelized map of a function by using ``jax.pmap``.

  This function can be used on multi-core CPU or multi-device GPU backends.
  If you are using it on a single CPU, please set the host device count with
  ``brainpy.math.set_host_device_count(n)`` beforehand.

  Parameters
  ----------
  func: callable
    The function to be mapped.
  arguments: sequence, dict
    The function arguments, used to define tasks.
  num_parallel: int
    The batch size, i.e. how many tasks are run in parallel per call.
  clear_buffer: bool
    Whether to clear the buffer memory after running each batch.

  Returns
  -------
  results: Any
    The running results.
  """
  if not isinstance(arguments, (dict, tuple, list)):
    raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
  elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.JaxArray))
  if clear_buffer:
    elements = [np.asarray(ele) for ele in elements]
  num_pars = [len(ele) for ele in elements]
  if len(np.unique(num_pars)) != 1:
    raise ValueError(f'All elements in parameters should have the same length. '
                     f'But we got {tree_unflatten(tree, num_pars)}')

  res_tree = None
  results = None
  vmap_func = pmap(func)
  for i in range(0, num_pars[0], num_parallel):
    run_f = pmap(func) if clear_buffer else vmap_func
    if isinstance(arguments, dict):
      r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
    else:
      r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
    res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.JaxArray))
    if results is None:
      results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
    else:
      for j, val in enumerate(res_values):
        results[j].append(np.asarray(val) if clear_buffer else val)
    if clear_buffer:
      bm.clear_buffer_memory()
  if res_tree is None:
    return None
  results = ([np.concatenate(res, axis=0) for res in results]
             if clear_buffer else
             [bm.concatenate(res, axis=0) for res in results])
  return tree_unflatten(res_tree, results)
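
A minimal usage sketch of ``jax_parallelize_map`` (the worker function ``simulate``, its inputs, and the device count of 4 are hypothetical, not part of this commit). On a machine without multiple accelerators, the host device count has to be raised first so that ``jax.pmap`` has several devices to map over, as the docstring notes.

import brainpy as bp
import brainpy.math as bm

# must run before any other JAX computation when only a single CPU is available
bm.set_host_device_count(4)

def simulate(a, b):
  # hypothetical per-trial computation; one (a, b) pair per device
  return a * b

results = bp.running.jax_parallelize_map(
  simulate,
  arguments=[bm.random.rand(20), bm.random.rand(20)],
  num_parallel=4,  # should not exceed the number of available devices
)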

brainpy/running/multiprocess.py renamed to brainpy/running/native_multiprocessing.py

Lines changed: 10 additions & 19 deletions
@@ -1,17 +1,17 @@
 # -*- coding: utf-8 -*-
 
+from typing import Union, Sequence, Dict
 import multiprocessing
 
-
 __all__ = [
   'process_pool',
   'process_pool_lock',
-  'vectorize_map',
-  'parallelize_map',
 ]
 
 
-def process_pool(func, all_params, num_process):
+def process_pool(func: callable,
+                 all_params: Union[Sequence, Dict],
+                 num_process: int):
   """Run multiple models in multi-processes.
 
   .. Note::
@@ -47,7 +47,9 @@ def process_pool(func, all_params, num_process):
   return [r.get() for r in results]
 
 
-def process_pool_lock(func, all_params, nb_process):
+def process_pool_lock(func: callable,
+                      all_params: Union[Sequence, Dict],
+                      num_process: int):
   """Run multiple models in multi-processes with lock.
 
   Sometimes, you want to synchronize the processes. For example,
@@ -71,11 +73,11 @@ def some_func(..., lock, ...):
 
   Parameters
   ----------
-  func : callable
+  func: callable
     The function to run model.
   all_params : list, tuple, dict
     The parameters of the function arguments.
-  nb_process : int
+  num_process : int
     The number of the processes.
 
   Returns
@@ -84,7 +86,7 @@ def some_func(..., lock, ...):
     Process results.
   """
   print('{} jobs total.'.format(len(all_params)))
-  pool = multiprocessing.Pool(processes=nb_process)
+  pool = multiprocessing.Pool(processes=num_process)
   m = multiprocessing.Manager()
   lock = m.Lock()
   results = []
@@ -99,14 +101,3 @@ def some_func(..., lock, ...):
   pool.close()
   pool.join()
   return [r.get() for r in results]
-
-
-def vectorize_map(func, all_params, num_thread):
-  pass
-
-
-def parallelize_map(func, all_params, num_process):
-  pass
-
-
-
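
A minimal usage sketch of the renamed native helper ``process_pool`` (the worker function ``run_model`` and its parameter tuples are hypothetical, and it is assumed that each entry of ``all_params`` is passed to ``func`` as its positional arguments; the code handling that is not shown in this diff):

import brainpy as bp

def run_model(par_a, par_b):
  # hypothetical single simulation; must be a picklable, module-level function
  return par_a + par_b

if __name__ == '__main__':
  # one tuple per task, assumed to be unpacked as the positional arguments of ``func``
  all_params = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
  results = bp.running.process_pool(run_model, all_params, num_process=3)
  print(results)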
