Hey all,

I know that in TF there was the notion of a gradient tape that allowed efficient computation like:
""" Computing And Preserving Shared-Gradients """withtf.GradientTape(persistent=True) astape:
.... .... ...
returntraj""" Later on, use tape two times """withtape:
....
...
#Then Optimize first component using tape..self.opt(tape, loss, A)
""" Later, use tape again to optimize second component """withtape:
... ...
# Optimize the second componentself.opt(tape, loss, B)
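For concreteness, a minimal runnable version of this pattern might look like the following sketch (the variables `A` and `B` and the two losses are illustrative placeholders, not from my actual code):

```python
import tensorflow as tf

A = tf.Variable(2.0)
B = tf.Variable(3.0)

# persistent=True keeps the tape's recorded operations alive so the
# tape can be queried more than once.
with tf.GradientTape(persistent=True) as tape:
    shared = A * B            # shared computation, recorded once
    loss_A = shared ** 2      # first component's loss
    loss_B = tf.sin(shared)   # second component's loss

# The same tape is reused for two separate gradient queries.
grad_A = tape.gradient(loss_A, A)
grad_B = tape.gradient(loss_B, B)
del tape  # release the persistent tape's resources when done
```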
This allows computing the first part of the gradient once for both components. Is there anything similar in JAX?
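For comparison, the closest analog I'm aware of is `jax.vjp`: it has no tape object, but it runs the forward pass once, stores the residuals, and returns a pullback that can be applied to several cotangents without re-running the shared forward work. A minimal sketch, assuming that pullback reuse is the behavior wanted here (`shared_forward`, `params`, and the cotangents are hypothetical placeholders):

```python
import jax
import jax.numpy as jnp

def shared_forward(params, x):
    # Hypothetical shared computation whose gradient work we want to reuse
    return jnp.tanh(params["W"] @ x + params["b"])

params = {"W": jnp.ones((3, 3)), "b": jnp.zeros(3)}
x = jnp.arange(3.0)

# The forward pass runs once; its residuals are captured in `pullback`.
out, pullback = jax.vjp(lambda p: shared_forward(p, x), params)

# Two different cotangents (e.g., from two downstream losses) reuse the
# same stored linearization -- the "first part" of the gradient is shared.
grads_A, = pullback(jnp.ones_like(out))
grads_B, = pullback(2.0 * jnp.ones_like(out))
```

(`jax.linearize` provides the corresponding forward-mode reuse, if that direction fits better.)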