(Closed) Getting intermediate training state and outputs #7173
Answered by pulkitgopalani
asked this question in Q&A
Hi! I'm relatively new to JAX, and was trying to train a neural net that generates images from latent vectors. I have the following function as the main training step:

```python
def train_step(optimizer, image, latent):
    def loss_fn(params):
        pred, new_state = Generator().apply(
            params, latent, mutable=["moving_stats"]
        )
        return l1_loss(pred, image)

    loss_grad = jax.value_and_grad(
        loss_fn,
        # has_aux=True,
    )
    loss, grad = loss_grad(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return loss, optimizer
```

I tried using the returned optimizer to initialize a new `Generator()` and generate images, but the outputs do not reflect the lower training loss from each epoch. Basically, I get the same output for all epochs, even though the training error has dropped a lot. Am I missing something basic here? Thanks!
Answered by pulkitgopalani on Jul 5, 2021
Replies: 1 comment
Closed.
0 replies
Answer selected by pulkitgopalani