22
33from typing import Union , Sequence , Callable
44
5- from jax .abstract_arrays import ShapedArray
5+ from jax .core import ShapedArray
66from jax .tree_util import tree_map
77
88from brainpy .base import Base
@@ -57,6 +57,10 @@ def __init__(
5757 gpu_func : Callable = None ,
5858 apply_cpu_func_to_gpu : bool = False ,
5959 name : str = None ,
60+ batching_translation : Callable = None ,
61+ jvp_translation : Callable = None ,
62+ transpose_translation : Callable = None ,
63+ multiple_results : bool = False ,
6064 ):
6165 _check_brainpylib (register_op .__name__ )
6266 super (XLACustomOp , self ).__init__ (name = name )
@@ -77,19 +81,25 @@ def __init__(
7781 gpu_func = None
7882
7983 # register OP
80- self .op = brainpylib .register_op (self .name ,
81- cpu_func = cpu_func ,
82- gpu_func = gpu_func ,
83- out_shapes = eval_shape ,
84- apply_cpu_func_to_gpu = apply_cpu_func_to_gpu )
84+ self .op = brainpylib .register_op_with_numba (
85+ self .name ,
86+ cpu_func = cpu_func ,
87+ gpu_func_translation = gpu_func ,
88+ out_shapes = eval_shape ,
89+ apply_cpu_func_to_gpu = apply_cpu_func_to_gpu ,
90+ batching_translation = batching_translation ,
91+ jvp_translation = jvp_translation ,
92+ transpose_translation = transpose_translation ,
93+ multiple_results = multiple_results ,
94+ )
8595
8696 def __call__ (self , * args , ** kwargs ):
8797 args = tree_map (lambda a : a .value if isinstance (a , JaxArray ) else a ,
8898 args , is_leaf = lambda a : isinstance (a , JaxArray ))
8999 kwargs = tree_map (lambda a : a .value if isinstance (a , JaxArray ) else a ,
90100 kwargs , is_leaf = lambda a : isinstance (a , JaxArray ))
91101 res = self .op .bind (* args , ** kwargs )
92- return res [ 0 ] if len ( res ) == 1 else res
102+ return res
93103
94104
95105def register_op (
@@ -122,15 +132,15 @@ def register_op(
122132 A jitable JAX function.
123133 """
124134 _check_brainpylib (register_op .__name__ )
125- f = brainpylib .register_op (name ,
126- cpu_func = cpu_func ,
127- gpu_func = gpu_func ,
128- out_shapes = eval_shape ,
129- apply_cpu_func_to_gpu = apply_cpu_func_to_gpu )
135+ f = brainpylib .register_op_with_numba (name ,
136+ cpu_func = cpu_func ,
137+ gpu_func_translation = gpu_func ,
138+ out_shapes = eval_shape ,
139+ apply_cpu_func_to_gpu = apply_cpu_func_to_gpu )
130140
131141 def fixed_op (* inputs , ** info ):
132142 inputs = tuple ([i .value if isinstance (i , JaxArray ) else i for i in inputs ])
133143 res = f .bind (* inputs , ** info )
134- return res [ 0 ] if len ( res ) == 1 else res
144+ return res
135145
136146 return fixed_op
0 commit comments