@@ -94,9 +94,9 @@ def __call__(self, *args, **kwargs):
9494
9595
9696def register_op (
97- op_name : str ,
97+ name : str ,
98+ eval_shape : Union [Callable , ShapedArray , Sequence [ShapedArray ]],
9899 cpu_func : Callable ,
99- out_shapes : Union [Callable , ShapedArray , Sequence [ShapedArray ]],
100100 gpu_func : Callable = None ,
101101 apply_cpu_func_to_gpu : bool = False
102102):
@@ -105,13 +105,13 @@ def register_op(
105105
106106 Parameters
107107 ----------
108- op_name : str
108+ name : str
109109 Name of the operators.
110110 cpu_func: Callble
111111 A callable numba-jitted function or pure function (can be lambda function) running on CPU.
112112 gpu_func: Callable, default = None
113113 A callable cuda-jitted kernel running on GPU.
114- out_shapes : Callable, ShapedArray, Sequence[ShapedArray], default = None
114+ eval_shape : Callable, ShapedArray, Sequence[ShapedArray], default = None
115115 Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or
116116 a sequence of `ShapedArray`. If it is a function, it takes as input the argument
117117 shapes and dtypes and should return correct output shapes of `ShapedArray`.
@@ -123,10 +123,10 @@ def register_op(
123123 A jitable JAX function.
124124 """
125125 _check_brainpylib (register_op .__name__ )
126- f = brainpylib .register_op (op_name ,
126+ f = brainpylib .register_op (name ,
127127 cpu_func = cpu_func ,
128128 gpu_func = gpu_func ,
129- out_shapes = out_shapes ,
129+ out_shapes = eval_shape ,
130130 apply_cpu_func_to_gpu = apply_cpu_func_to_gpu )
131131
132132 def fixed_op (* inputs ):
0 commit comments