
I believe Flax's `TrainState` (`flax.training.train_state.TrainState`) includes a `step` field, which is an integer. If you're taking the gradient with respect to the whole state, that integer field gets differentiated along with everything else, and that would cause this error. Perhaps you meant to take the gradient with respect to just `state.params`?
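To illustrate, here is a minimal sketch that uses only JAX. `ToyState` is a hypothetical stand-in for a train state (an integer `step` plus float `params`), not the real Flax class, and `loss_fn` and the toy data are made up for the example. It assumes the error in question is the `TypeError` that `jax.grad` raises for integer inputs.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a train state: integer step plus float params."""
    step: int
    params: dict


def loss_fn(params, x, y):
    # Simple linear model with a mean squared error loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


state = ToyState(step=0, params={"w": jnp.ones((3,)), "b": jnp.zeros(())})
x = jnp.ones((4, 3))
y = jnp.zeros((4,))

# Differentiating with respect to the whole state also covers the integer
# `step` leaf, which jax.grad rejects.
try:
    jax.grad(lambda s: loss_fn(s.params, x, y))(state)
    whole_state_error = None
except TypeError as err:
    whole_state_error = err

# Differentiating with respect to the params alone only touches float leaves.
grads = jax.grad(loss_fn)(state.params, x, y)
```

With a real `TrainState` the same idea should apply: pass `state.params` as the argument being differentiated and keep the rest of the state out of the `jax.grad` call.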

Answer selected by pabloduque0