Skip to content

Commit f3e7d72

Browse files
authored
upgrade operators to match brainpylib>=0.1.0 (#284)
upgrade operators to match brainpylib>=0.1.0
2 parents 15828ca + c12d316 commit f3e7d72

File tree

8 files changed

+205
-95
lines changed

8 files changed

+205
-95
lines changed

brainpy/math/operators/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# -*- coding: utf-8 -*-
22

33

4-
from . import multiplication
4+
from . import sparse_matmul, event_matmul
55
from . import op_register
66
from . import pre_syn_post as pre_syn_post_module
77
from . import wrap_jax
88
from . import spikegrad
99

10-
__all__ = multiplication.__all__ + op_register.__all__
10+
__all__ = event_matmul.__all__ + sparse_matmul.__all__ + op_register.__all__
1111
__all__ += pre_syn_post_module.__all__ + wrap_jax.__all__ + spikegrad.__all__
1212

1313

14-
from .multiplication import *
14+
from .event_matmul import *
15+
from .sparse_matmul import *
1516
from .op_register import *
1617
from .pre_syn_post import *
1718
from .wrap_jax import *
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-

from typing import Tuple

from brainpy.math.numpy_ops import as_jax
from brainpy.types import Array
from .utils import _check_brainpylib

# ``brainpylib`` is an optional compiled backend; import lazily so that the
# module can still be imported (and raise a helpful error later via
# ``_check_brainpylib``) when the package is not installed.
try:
  import brainpylib
except ModuleNotFoundError:
  brainpylib = None

__all__ = [
  'event_csr_matvec',
]


def event_csr_matvec(values: Array,
                     indices: Array,
                     indptr: Array,
                     events: Array,
                     shape: Tuple[int, ...],
                     transpose: bool = False):
  """The pre-to-post event-driven synaptic summation with `CSR` synapse structure.

  Parameters
  ----------
  values: Array, float
    An array of shape ``(nse,)`` or a float.
  indices: Array
    An array of shape ``(nse,)``.
  indptr: Array
    An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
  events: Array
    The spike/event vector, an array of shape
    ``(shape[0] if transpose else shape[1],)``.
  shape: tuple of int
    A length-2 tuple representing the sparse matrix shape.
  transpose: bool
    A boolean specifying whether to transpose the sparse matrix
    before computing. Default is False.

  Returns
  -------
  out: Array
    A tensor with the shape of ``shape[1]`` if `transpose=True`,
    or ``shape[0]`` if `transpose=False`.
  """
  _check_brainpylib('event_csr_matvec')
  # Normalize every operand to a raw JAX array (unwrapping JaxArray etc.)
  # before handing off to the compiled brainpylib kernel.
  events = as_jax(events)
  indices = as_jax(indices)
  indptr = as_jax(indptr)
  values = as_jax(values)
  return brainpylib.event_csr_matvec(values, indices, indptr, events,
                                     shape=shape, transpose=transpose)

brainpy/math/operators/op_register.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Union, Sequence, Callable
44

5-
from jax.abstract_arrays import ShapedArray
5+
from jax.core import ShapedArray
66
from jax.tree_util import tree_map
77

88
from brainpy.base import Base
@@ -57,6 +57,10 @@ def __init__(
5757
gpu_func: Callable = None,
5858
apply_cpu_func_to_gpu: bool = False,
5959
name: str = None,
60+
batching_translation: Callable = None,
61+
jvp_translation: Callable = None,
62+
transpose_translation: Callable = None,
63+
multiple_results: bool = False,
6064
):
6165
_check_brainpylib(register_op.__name__)
6266
super(XLACustomOp, self).__init__(name=name)
@@ -77,19 +81,25 @@ def __init__(
7781
gpu_func = None
7882

7983
# register OP
80-
self.op = brainpylib.register_op(self.name,
81-
cpu_func=cpu_func,
82-
gpu_func=gpu_func,
83-
out_shapes=eval_shape,
84-
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
84+
self.op = brainpylib.register_op_with_numba(
85+
self.name,
86+
cpu_func=cpu_func,
87+
gpu_func_translation=gpu_func,
88+
out_shapes=eval_shape,
89+
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu,
90+
batching_translation=batching_translation,
91+
jvp_translation=jvp_translation,
92+
transpose_translation=transpose_translation,
93+
multiple_results=multiple_results,
94+
)
8595

8696
def __call__(self, *args, **kwargs):
8797
args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
8898
args, is_leaf=lambda a: isinstance(a, JaxArray))
8999
kwargs = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
90100
kwargs, is_leaf=lambda a: isinstance(a, JaxArray))
91101
res = self.op.bind(*args, **kwargs)
92-
return res[0] if len(res) == 1 else res
102+
return res
93103

94104

95105
def register_op(
@@ -122,15 +132,15 @@ def register_op(
122132
A jitable JAX function.
123133
"""
124134
_check_brainpylib(register_op.__name__)
125-
f = brainpylib.register_op(name,
126-
cpu_func=cpu_func,
127-
gpu_func=gpu_func,
128-
out_shapes=eval_shape,
129-
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
135+
f = brainpylib.register_op_with_numba(name,
136+
cpu_func=cpu_func,
137+
gpu_func_translation=gpu_func,
138+
out_shapes=eval_shape,
139+
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
130140

131141
def fixed_op(*inputs, **info):
132142
inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
133143
res = f.bind(*inputs, **info)
134-
return res[0] if len(res) == 1 else res
144+
return res
135145

136146
return fixed_op

0 commit comments

Comments
 (0)