I am currently using PyTorch but wonder whether I should switch to JAX. I want to do the following:
Let's say I have a network composed of layers: input -> layer 1 -> layer 2 -> ... -> layer n -> output
I need the Jacobian of each layer's output with respect to the network input, in a differentiable way, so that I can use it in the loss without it being too inefficient. So what I want to optimise is:
loss = criterion(network(input), targets) + alpha*sum_{layer i in layers} f(Jacobian_i(input))
I think the most straightforward and reasonably efficient way is to use forward-mode differentiation to record the Jacobians, but I am not sure how to do this: I don't see a way to record the interim Jacobians during forward propagation. A rough sketch of what I have in mind follows below.
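To make it concrete, here is a minimal sketch of what I'm after (my own placeholders: plain dense layers via `make_layer` / `layer_apply`, and a squared Frobenius norm as the choice of `f`). It accumulates each layer's Jacobian with respect to the input through the chain rule during the forward pass:

```python
import jax
import jax.numpy as jnp

def make_layer(key, d_in, d_out):
    # Placeholder dense-layer parameters, just for illustration.
    w = jax.random.normal(key, (d_out, d_in)) / jnp.sqrt(d_in)
    b = jnp.zeros(d_out)
    return (w, b)

def layer_apply(params, x):
    w, b = params
    return jnp.tanh(w @ x + b)

def forward_with_jacobians(all_params, x):
    # Forward pass that also accumulates the Jacobian of each intermediate
    # activation w.r.t. the network input via the chain rule:
    # J_i = J_layer_i(h_{i-1}) @ J_{i-1}
    h = x
    J = jnp.eye(x.shape[0])  # Jacobian of the input w.r.t. itself
    jacobians = []
    for params in all_params:
        # Per-layer Jacobian w.r.t. that layer's own input (forward mode).
        J_layer = jax.jacfwd(lambda z: layer_apply(params, z))(h)
        J = J_layer @ J      # Jacobian of this layer's output w.r.t. the input
        h = layer_apply(params, h)
        jacobians.append(J)
    return h, jacobians

def loss_fn(all_params, x, target, alpha=1e-2):
    out, jacobians = forward_with_jacobians(all_params, x)
    criterion = jnp.mean((out - target) ** 2)
    # f(J) is just a squared Frobenius norm here, as a stand-in.
    reg = sum(jnp.sum(J ** 2) for J in jacobians)
    return criterion + alpha * reg

keys = jax.random.split(jax.random.PRNGKey(0), 3)
dims = [4, 8, 8, 2]
params = [make_layer(k, dims[i], dims[i + 1]) for i, k in enumerate(keys)]
x = jnp.ones(4)
target = jnp.zeros(2)

# Everything above is ordinary traced JAX code, so the Jacobian terms are
# themselves differentiable and a single jax.grad call gives parameter
# gradients of the full regularised loss.
grads = jax.grad(loss_fn)(params, x, target)
```

This naive version calls `jax.jacfwd` once per layer, so I suspect it is not the most efficient way to do it, which is part of why I'm asking whether there is a way to record the interim Jacobians in a single forward pass.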