|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +from typing import Sequence, Dict, Union |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from jax import vmap, pmap |
| 7 | +from jax.tree_util import tree_unflatten, tree_flatten |
| 8 | + |
| 9 | +import brainpy.math as bm |
| 10 | +from brainpy.types import Array |
| 11 | + |
| 12 | +__all__ = [ |
| 13 | + 'jax_vectorize_map', |
| 14 | + 'jax_parallelize_map', |
| 15 | +] |
| 16 | + |
| 17 | + |
| 18 | +def jax_vectorize_map( |
| 19 | + func: callable, |
| 20 | + arguments: Union[Dict[str, Array], Sequence[Array]], |
| 21 | + num_parallel: int, |
| 22 | + clear_buffer: bool = False |
| 23 | +): |
| 24 | + """Perform a vectorized map of a function by using ``jax.vmap``. |
| 25 | +
|
| 26 | + This function can be used in CPU or GPU backends. But it is highly |
| 27 | + suitable to be used in GPU backends. This is because ``jax.vmap`` |
| 28 | + can parallelize the mapped axis on GPU devices. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + func: callable, function |
| 33 | + The function to be mapped. |
| 34 | + arguments: sequence, dict |
| 35 | + The function arguments, used to define tasks. |
| 36 | + num_parallel: int |
| 37 | + The number of batch size. |
| 38 | + clear_buffer: bool |
| 39 | + Clear the buffer memory after running each batch data. |
| 40 | +
|
| 41 | + Returns |
| 42 | + ------- |
| 43 | + results: Any |
| 44 | + The running results. |
| 45 | + """ |
| 46 | + if not isinstance(arguments, (dict, tuple, list)): |
| 47 | + raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') |
| 48 | + elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.JaxArray)) |
| 49 | + if clear_buffer: |
| 50 | + elements = [np.asarray(ele) for ele in elements] |
| 51 | + num_pars = [len(ele) for ele in elements] |
| 52 | + if len(np.unique(num_pars)) != 1: |
| 53 | + raise ValueError(f'All elements in parameters should have the same length. ' |
| 54 | + f'But we got {tree_unflatten(tree, num_pars)}') |
| 55 | + |
| 56 | + res_tree = None |
| 57 | + results = None |
| 58 | + vmap_func = vmap(func) |
| 59 | + for i in range(0, num_pars[0], num_parallel): |
| 60 | + run_f = vmap(func) if clear_buffer else vmap_func |
| 61 | + if isinstance(arguments, dict): |
| 62 | + r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) |
| 63 | + else: |
| 64 | + r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) |
| 65 | + res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.JaxArray)) |
| 66 | + if results is None: |
| 67 | + results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) |
| 68 | + else: |
| 69 | + for j, val in enumerate(res_values): |
| 70 | + results[j].append(np.asarray(val) if clear_buffer else val) |
| 71 | + if clear_buffer: |
| 72 | + bm.clear_buffer_memory() |
| 73 | + if res_tree is None: |
| 74 | + return None |
| 75 | + results = ([np.concatenate(res, axis=0) for res in results] |
| 76 | + if clear_buffer else |
| 77 | + [bm.concatenate(res, axis=0) for res in results]) |
| 78 | + return tree_unflatten(res_tree, results) |
| 79 | + |
| 80 | + |
| 81 | +def jax_parallelize_map( |
| 82 | + func: callable, |
| 83 | + arguments: Union[Dict[str, Array], Sequence[Array]], |
| 84 | + num_parallel: int, |
| 85 | + clear_buffer: bool = False |
| 86 | +): |
| 87 | + """Perform a parallelized map of a function by using ``jax.pmap``. |
| 88 | +
|
| 89 | + This function can be used in multi- CPU or GPU backends. |
| 90 | + If you are using it in a single CPU, please set host device count |
| 91 | + by ``brainpy.math.set_host_device_count(n)`` before. |
| 92 | +
|
| 93 | + Parameters |
| 94 | + ---------- |
| 95 | + func: callable, function |
| 96 | + The function to be mapped. |
| 97 | + arguments: sequence, dict |
| 98 | + The function arguments, used to define tasks. |
| 99 | + num_parallel: int |
| 100 | + The number of batch size. |
| 101 | + clear_buffer: bool |
| 102 | + Clear the buffer memory after running each batch data. |
| 103 | +
|
| 104 | + Returns |
| 105 | + ------- |
| 106 | + results: Any |
| 107 | + The running results. |
| 108 | + """ |
| 109 | + if not isinstance(arguments, (dict, tuple, list)): |
| 110 | + raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') |
| 111 | + elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.JaxArray)) |
| 112 | + if clear_buffer: |
| 113 | + elements = [np.asarray(ele) for ele in elements] |
| 114 | + num_pars = [len(ele) for ele in elements] |
| 115 | + if len(np.unique(num_pars)) != 1: |
| 116 | + raise ValueError(f'All elements in parameters should have the same length. ' |
| 117 | + f'But we got {tree_unflatten(tree, num_pars)}') |
| 118 | + |
| 119 | + res_tree = None |
| 120 | + results = None |
| 121 | + vmap_func = pmap(func) |
| 122 | + for i in range(0, num_pars[0], num_parallel): |
| 123 | + run_f = pmap(func) if clear_buffer else vmap_func |
| 124 | + if isinstance(arguments, dict): |
| 125 | + r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) |
| 126 | + else: |
| 127 | + r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) |
| 128 | + res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.JaxArray)) |
| 129 | + if results is None: |
| 130 | + results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) |
| 131 | + else: |
| 132 | + for j, val in enumerate(res_values): |
| 133 | + results[j].append(np.asarray(val) if clear_buffer else val) |
| 134 | + if clear_buffer: |
| 135 | + bm.clear_buffer_memory() |
| 136 | + if res_tree is None: |
| 137 | + return None |
| 138 | + results = ([np.concatenate(res, axis=0) for res in results] |
| 139 | + if clear_buffer else |
| 140 | + [bm.concatenate(res, axis=0) for res in results]) |
| 141 | + return tree_unflatten(res_tree, results) |
0 commit comments