How to define a transpose rule for a Primitive with the saved inputs? #13723
-
I want to define a `Primitive` for a customized computation. In this computation, I want to save some of the input data for the backward gradient computation later on. However, I cannot find a way to save the input data. Here is my code:

```python
import jax.numpy as jnp
from jax.core import Primitive
from jax.interpreters import ad, batching


def _heaviside_abstract(x, *, alpha):
    return [x]

def _heaviside_imp(x, *, alpha):
    z = jnp.asarray(x >= 0, dtype=x.dtype)
    return [z]

def _heaviside_batching(args, axes, *, alpha):
    return heaviside_p.bind(*args, alpha=alpha), axes

def _heaviside_jvp(primals, tangents, *, alpha):
    x, = primals
    xt, = tangents
    primal_outs = heaviside_p.bind(x, alpha=alpha)
    tangent_outs = heaviside_p.bind(xt, alpha=alpha)
    return primal_outs, tangent_outs

def _heaviside_transpose(ct, x, *, alpha):
    # QUESTION here!!
    # how to save `x` for gradient computation?
    dE_dx = ct[0] / (alpha * jnp.abs(x.aval) + 1.0) ** 2
    return [dE_dx]

heaviside_p = Primitive('heaviside_p')
heaviside_p.multiple_results = True
heaviside_p.def_abstract_eval(_heaviside_abstract)
heaviside_p.def_impl(_heaviside_imp)
batching.primitive_batchers[heaviside_p] = _heaviside_batching
ad.primitive_jvps[heaviside_p] = _heaviside_jvp
ad.primitive_transposes[heaviside_p] = _heaviside_transpose
```
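(A small usage sketch for context; it is not part of the original question and assumes the definitions above have been executed.)

```python
x = jnp.array([-1.0, 0.5, 2.0])
z, = heaviside_p.bind(x, alpha=1.0)   # forward evaluation works: [0., 1., 1.]

# Reverse-mode differentiation, e.g.
#   jax.grad(lambda x: heaviside_p.bind(x, alpha=1.0)[0].sum())(x)
# eventually reaches _heaviside_transpose, where the `x` argument is an
# undefined primal (only x.aval is available), which is the crux of the
# question above.
```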
-
It's not clear to me what you mean by "saving x". At which point in the code would you like to retrieve the saved value? Can you edit your code to show an example of what you have in mind?
-
By definition, the transpose rule cannot depend on the value `x`. Given that you're implementing a Heaviside function, I think there are some problems with the way you've defined the derivative rules. First of all, the gradient of a Heaviside function is zero everywhere. The only point where this might come into question is at `x = 0`, where one might argue that the gradient is infinite (`jnp.inf`) or perhaps undefined (`jnp.nan`); however, zero is a reasonable result here, for reasons discussed in Why are gradients zero for functions based on sort order?. So your JVP rule might look like this:

```python
def _heaviside_jvp(primals, tangents, *, alpha):
    x, = primals
    xt, = tangents
    primal_outs = heaviside_p.bind(x, alpha=alpha)
    tangent_outs = [jnp.zeros_like(xt)]
    return primal_outs, tangent_outs
```

That's actually enough to solve the problem for your particular example, since the primitive is no longer used in a tangent computation and thus the transpose is not going to be called. But if you want to define your transpose function, I believe it would have to look something like this:

```python
def _heaviside_transpose(ct, x, *, alpha):
    if type(ct) is ad.Zero:
        return [ad.Zero(x.aval)]
    else:
        return [ct[0] * ad.zeros_like_aval(x.aval)]
```

The reason is that the tangents of your function are essentially zero everywhere.
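(As an aside, a minimal sketch, not from the thread: the point that a transpose rule may not depend on the value of `x` is easiest to see with a genuinely linear primitive. The `scale_p` primitive below is a made-up example; its transpose uses only the cotangent and the static `alpha`, while `x` arrives as an `ad.UndefinedPrimal` whose value is never available.)

```python
import jax
from jax.core import Primitive  # newer JAX versions expose jax.extend.core.Primitive
from jax.interpreters import ad

scale_p = Primitive('scale_p')  # computes alpha * x, which is linear in x

def _scale_impl(x, *, alpha):
    return alpha * x

def _scale_abstract(x, *, alpha):
    return x  # output has the same shape/dtype as the input

def _scale_jvp(primals, tangents, *, alpha):
    x, = primals
    xt, = tangents
    # A linear function is its own tangent map.  (A fuller rule would also
    # handle symbolic ad.Zero tangents.)
    return scale_p.bind(x, alpha=alpha), scale_p.bind(xt, alpha=alpha)

def _scale_transpose(ct, x, *, alpha):
    # `x` is an ad.UndefinedPrimal here: only x.aval is known, never its value.
    if type(ct) is ad.Zero:
        return [ad.Zero(x.aval)]
    return [alpha * ct]  # depends only on ct and the static alpha

scale_p.def_impl(_scale_impl)
scale_p.def_abstract_eval(_scale_abstract)
ad.primitive_jvps[scale_p] = _scale_jvp
ad.primitive_transposes[scale_p] = _scale_transpose

print(jax.grad(lambda x: scale_p.bind(x, alpha=3.0))(2.0))  # 3.0
```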
-
Dear @jakevdp, thanks for your detailed answer. What I am looking for is to implement surrogate gradient functions [1] for training spiking neural networks (SNNs) in JAX. The neuron spike is binary (the Heaviside function I used above), so its gradient is zero everywhere. The surrogate gradient approach addresses this by computing the Heaviside function in the forward pass while using the gradient of a sigmoid-like function (which has a similar shape to the Heaviside function) in the backward pass. This is very similar to the tutorial you referenced, Why are gradients zero for functions based on sort order?. However, we cannot directly replace the discrete spiking function with a continuous one (like the sigmoid), because the latter produces continuous values while the former produces a binary 0/1 pattern (this is what a spike means). Currently, we can implement a surrogate gradient function through `jax.custom_gradient`:

```python
import jax
import jax.numpy as jnp

@jax.custom_gradient
def surrogate(x, alpha=1.):
    z = jnp.asarray(x >= 0, dtype=x.dtype)
    def grad(dz):
        dx = dz / (alpha * jnp.abs(x) + 1.0) ** 2
        return dx, None
    return z, grad
```

However, I am wondering how to do the same when I implement a dedicated operator for SNN computing in C++/CUDA and bind it to a JAX `Primitive`. Specifically, when I define the transpose rule above, how can I access the input `x`?

Thanks again for your careful explanation.

[1] E. O. Neftci, H. Mostafa and F. Zenke, "Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-Based Optimization to Spiking Neural Networks," IEEE Signal Processing Magazine, vol. 36, no. 6, pp. 51-63, Nov. 2019, doi: 10.1109/MSP.2019.2931595
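(Not from the thread, but related to the question of "saving `x`": with `jax.custom_vjp`, the residuals saved for the backward pass are explicit. A minimal sketch; the `make_surrogate` factory closing over `alpha` is my own construction.)

```python
import jax
import jax.numpy as jnp

def make_surrogate(alpha=1.0):
    @jax.custom_vjp
    def surrogate(x):
        # Forward pass: the binary spike (Heaviside step).
        return jnp.asarray(x >= 0, dtype=x.dtype)

    def surrogate_fwd(x):
        # Explicitly save `x` as a residual for the backward pass.
        return surrogate(x), x

    def surrogate_bwd(x, dz):
        # Sigmoid-like surrogate gradient, evaluated at the saved input.
        return (dz / (alpha * jnp.abs(x) + 1.0) ** 2,)

    surrogate.defvjp(surrogate_fwd, surrogate_bwd)
    return surrogate

spike = make_surrogate(alpha=1.0)
x = jnp.array([-1.0, 0.5, 2.0])
print(jax.grad(lambda x: spike(x).sum())(x))  # nonzero surrogate gradient
```

Here the forward rule returns `(output, residuals)`, so whatever is needed in the backward pass, such as `x`, is saved explicitly rather than looked up later.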