What is the difference between tree_map and an optax optimizer (opt)? #10514
Unanswered
marcuswang6 asked this question in Q&A
Reply: optax uses tree maps internally, but it also provides many gradient transformations (momentum, gradient clipping, Adam-style rescaling, and so on) that can be combined to build optimizers beyond plain SGD.
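Given parameters `p` and a training sample `x`, `y`, what is the difference between the following two ways of updating `p`?

```python
import jax
import optax

# Version 1: optax; `opt` and `state = opt.init(p)` are created beforehand
grads = jax.grad(loss)(p, x, y)
updates, state = opt.update(grads, state)
p = optax.apply_updates(p, updates)
```

```python
import jax

# Version 2: hand-written SGD step mapped over the parameter pytree
grads = jax.grad(loss)(p, x, y)
sgd_step = lambda g, w: w - alpha * g  # w is a parameter leaf, g its gradient
p = jax.tree_util.tree_map(sgd_step, grads, p)  # tree_multimap was removed from JAX; tree_map accepts several trees
```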