@@ -57,6 +57,10 @@ def __init__(
5757 gpu_func : Callable = None ,
5858 apply_cpu_func_to_gpu : bool = False ,
5959 name : str = None ,
60+ batching_translation : Callable = None ,
61+ jvp_translation : Callable = None ,
62+ transpose_translation : Callable = None ,
63+ multiple_results : bool = False ,
6064 ):
6165 _check_brainpylib (register_op .__name__ )
6266 super (XLACustomOp , self ).__init__ (name = name )
@@ -77,11 +81,17 @@ def __init__(
7781 gpu_func = None
7882
7983 # register OP
80- self .op = brainpylib .register_op_with_numba (self .name ,
81- cpu_func = cpu_func ,
82- gpu_func = gpu_func ,
83- out_shapes = eval_shape ,
84- apply_cpu_func_to_gpu = apply_cpu_func_to_gpu )
84+ self .op = brainpylib .register_op_with_numba (
85+ self .name ,
86+ cpu_func = cpu_func ,
87+ gpu_func_translation = gpu_func ,
88+ out_shapes = eval_shape ,
89+ apply_cpu_func_to_gpu = apply_cpu_func_to_gpu ,
90+ batching_translation = batching_translation ,
91+ jvp_translation = jvp_translation ,
92+ transpose_translation = transpose_translation ,
93+ multiple_results = multiple_results ,
94+ )
8595
8696 def __call__ (self , * args , ** kwargs ):
8797 args = tree_map (lambda a : a .value if isinstance (a , JaxArray ) else a ,
0 commit comments