We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5481ec4 commit 4515c8cCopy full SHA for 4515c8c
brainpy/math/operators/op_register.py
@@ -132,11 +132,11 @@ def register_op(
132
A jitable JAX function.
133
"""
134
_check_brainpylib(register_op.__name__)
135
- f = brainpylib.register_op(name,
136
- cpu_func=cpu_func,
137
- gpu_func_translation=gpu_func,
138
- out_shapes=eval_shape,
139
- apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
+ f = brainpylib.register_op_with_numba(name,
+ cpu_func=cpu_func,
+ gpu_func_translation=gpu_func,
+ out_shapes=eval_shape,
+ apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
140
141
def fixed_op(*inputs, **info):
142
inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
0 commit comments