Skip to content

Commit 9f6040b

Browse files
committed
fix op register bug
1 parent 627af9b commit 9f6040b

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

brainpy/math/operators/op_register.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def __init__(
5757
gpu_func: Callable = None,
5858
apply_cpu_func_to_gpu: bool = False,
5959
name: str = None,
60+
batching_translation: Callable = None,
61+
jvp_translation: Callable = None,
62+
transpose_translation: Callable = None,
63+
multiple_results: bool = False,
6064
):
6165
_check_brainpylib(register_op.__name__)
6266
super(XLACustomOp, self).__init__(name=name)
@@ -77,11 +81,17 @@ def __init__(
7781
gpu_func = None
7882

7983
# register OP
80-
self.op = brainpylib.register_op_with_numba(self.name,
81-
cpu_func=cpu_func,
82-
gpu_func=gpu_func,
83-
out_shapes=eval_shape,
84-
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
84+
self.op = brainpylib.register_op_with_numba(
85+
self.name,
86+
cpu_func=cpu_func,
87+
gpu_func_translation=gpu_func,
88+
out_shapes=eval_shape,
89+
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu,
90+
batching_translation=batching_translation,
91+
jvp_translation=jvp_translation,
92+
transpose_translation=transpose_translation,
93+
multiple_results=multiple_results,
94+
)
8595

8696
def __call__(self, *args, **kwargs):
8797
args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,

brainpy/math/operators/pre_syn_post.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def pre2post_event_sum(events: Array,
9797
out: JaxArray, jax.numpy.ndarray
9898
A tensor with the shape of ``post_num``.
9999
"""
100-
_check_brainpylib('pre2post_event_sum')
100+
_check_brainpylib('event_csr_matvec')
101101
indices, idnptr = pre2post
102102
events = as_jax(events)
103103
indices = as_jax(indices)

0 commit comments

Comments
 (0)