I have a problem in the context of neural networks in which my loss function (to be optimized over the hidden parameters) calls a helper function that I do not want differentiated. To give some more technical context: the hidden parameters (weights and biases) of the NN are treated as optimizable variables, while the parameters of the activation layer are computed from the hidden parameters and should not be viewed as part of the gradient or Hessian, nor as optimization variables. It's somewhat similar to this work: Robust Training and Initialization of Deep Neural Networks, where they solve a linear least squares system involving the hidden parameters in order to obtain the activation parameters.

In my current implementation, JAX views the helper function as differentiable, and I believe there is numerical instability due to trying to differentiate through a linear least squares solver. Either way, my intention, like I said, is to have the helper function not be differentiated. I tried doing something like passing a copy of the input to the helper function, but that did not help.

To present an MWE, I tried to simplify things as much as possible. I know this example is probably nonsensical, but it computes the minimum of the function x^2 + a(x), where a(x) = (x-1)^4, using BFGS. The current example returns a solution of x=0.410245, which is the correct location of the minimum of x^2 + (x-1)^4. I would like the optimizer to not view a(x) as a function of x, i.e., to treat the helper function's output as a constant when differentiating.
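A minimal sketch of what such an MWE might look like (assuming jax.scipy.optimize.minimize with method="BFGS"; helperFunc is a placeholder name for the function whose gradient should be ignored):

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def helperFunc(x):
    # Stand-in for the activation-parameter computation; here simply (x - 1)^4.
    return jnp.power(x - 1, 4)

def f(x):
    # Full objective: x^2 + a(x). As written, JAX also differentiates through
    # helperFunc, so BFGS finds the minimum of x^2 + (x - 1)^4.
    a = helperFunc(x)
    return jnp.squeeze(jnp.power(x, 2) + a)

result = minimize(f, jnp.array([0.0]), method="BFGS")
print(result.x)  # approximately 0.410245
```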
Replies: 1 comment
I suspect what you're looking for is `jax.lax.stop_gradient`:

```python
from jax import lax

def f(x):
    # stop_gradient makes JAX treat a as a constant with respect to x:
    # it contributes to the value of f but not to its gradient.
    a = lax.stop_gradient(helperFunc(x))
    return jnp.squeeze(jnp.power(x, 2) + a)
```
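As a quick sanity check (a sketch using the same placeholder helperFunc as in the MWE above, not code from the original thread), jax.grad shows that the helper term drops out of the gradient once it is wrapped in stop_gradient:

```python
import jax
import jax.numpy as jnp
from jax import lax

def helperFunc(x):
    return jnp.power(x - 1, 4)

def f_plain(x):
    return jnp.power(x, 2) + helperFunc(x)

def f_stopped(x):
    return jnp.power(x, 2) + lax.stop_gradient(helperFunc(x))

print(jax.grad(f_plain)(2.0))    # 2x + 4(x-1)^3 = 8.0
print(jax.grad(f_stopped)(2.0))  # 2x only = 4.0
```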