Hello,

I am working on some NeuralODE code and I am facing an issue that, I hope, can be understood from the minimal example below.

Assume I have the following function `f` and I want to define a `custom_vjp` rule:
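For illustration, here is a minimal sketch of what that looks like (the real `f` is an ODE solver; the toy body `p * x` and the usual `f_fwd`/`f_bwd` pairing below are stand-ins just so the example runs):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the ODE solver: the real f returns the solution at T
# time steps; here it is just an elementwise product so the sketch runs.
@jax.custom_vjp
def f(p, x):
    return p * x

def f_fwd(p, x):
    # Save the primals as residuals for the backward pass.
    return f(p, x), (p, x)

def f_bwd(res, g):
    p, x = res
    # g is the incoming cotangent. This is where 1 is added to g --
    # ideally only to the entries indexed from s onward, but s is not
    # visible here, which is exactly the problem.
    g = g + 1
    return (jnp.vdot(g, x), g * p)

f.defvjp(f_fwd, f_bwd)
```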
I compute some loss on it as follows:

```python
def L(p, x, s):
    """
    Args:
        p: A scalar parameter.
        x: A vector.
        s: A scalar start time.
    """
    return jnp.sum(f(p, x)[s:] - 5) ** 2
```
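Differentiating this with `jax.grad` then routes the cotangent through `f_bwd`; for example, with made-up toy values:

```python
p, x, s = 2.0, jnp.arange(6.0), 3
dLdp = jax.grad(L)(p, x, s)  # differentiates w.r.t. p; f_bwd sees a length-6 cotangent
```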
So in the loss function I am only interested in the elements indexed from `s`, and because of the slicing `[s:]`, the cotangent `g` will be of the form `[0, ..., 0, x, ..., y]`, where the values are `0` up to the index `s`. Now notice that in the `return` statement of `f_bwd` I am adding `1` to the vector `g`. What I actually want is to add `1` only to the elements indexed from `s`. How can I detect the index `s` in `f_bwd`?
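To make that structure concrete, it can be checked by pulling the loss back onto a stand-in for the solver output with `jax.vjp` (toy length-6 values and `s = 3`, chosen only for illustration):

```python
y = jnp.arange(6.0)  # stands in for the length-T solver output f(p, x)
s = 3
_, pullback = jax.vjp(lambda out: jnp.sum(out[s:] - 5) ** 2, y)
(g,) = pullback(1.0)
print(g)  # [ 0.  0.  0. -6. -6. -6.]  -- zeros until index s
```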
Thank you very much for any help!

PS: The title is maybe a bit misleading; if anyone has a better suggestion, I would be happy to change it!
PPS: To give a bit more background on my Neural ODE problem: the function `f` represents an ODE solver, and I want to compute the gradients of the parameters using the adjoint method. The issue is that the solver outputs `T` time steps, so the cotangent vector in the backward pass will be of length `T`. But only the last `s` time steps of the solution enter my loss, hence the question: how do I detect those last `s` values in the cotangent?