@@ -50,7 +50,7 @@ def _get_grad_prim():
5050 grad_prim .prim_type = "higher_order"
5151
5252 @grad_prim .def_impl
53- def _ (* args , argnums , jaxpr , n_consts , method , h ):
53+ def _grad_def_impl (* args , argnums , jaxpr , n_consts , method , h ):
5454 if method or h : # pragma: no cover
5555 raise ValueError (f"Invalid values '{ method = } ' and '{ h = } ' without QJIT." )
5656 consts = args [:n_consts ]
@@ -63,7 +63,7 @@ def func(*inner_args):
6363
6464 # pylint: disable=unused-argument
6565 @grad_prim .def_abstract_eval
66- def _ (* args , argnums , jaxpr , n_consts , method , h ):
66+ def _grad_abstract_eval (* args , argnums , jaxpr , n_consts , method , h ):
6767 if len (jaxpr .outvars ) != 1 or jaxpr .outvars [0 ].aval .shape != ():
6868 raise TypeError ("Grad only applies to scalar-output functions. Try jacobian." )
6969 return tuple (args [i + n_consts ] for i in argnums )
@@ -90,7 +90,7 @@ def _get_jacobian_prim():
9090 jacobian_prim .prim_type = "higher_order"
9191
9292 @jacobian_prim .def_impl
93- def _ (* args , argnums , jaxpr , n_consts , method , h ):
93+ def _jacobian_def_impl (* args , argnums , jaxpr , n_consts , method , h ):
9494 if method or h : # pragma: no cover
9595 raise ValueError (f"Invalid values '{ method = } ' and '{ h = } ' without QJIT." )
9696 consts = args [:n_consts ]
@@ -103,7 +103,7 @@ def func(*inner_args):
103103
104104 # pylint: disable=unused-argument
105105 @jacobian_prim .def_abstract_eval
106- def _ (* args , argnums , jaxpr , n_consts , method , h ):
106+ def _jacobian_abstract_eval (* args , argnums , jaxpr , n_consts , method , h ):
107107 in_avals = tuple (args [i + n_consts ] for i in argnums )
108108 out_shapes = tuple (outvar .aval .shape for outvar in jaxpr .outvars )
109109 return [
0 commit comments