I'm a beginner with JAX and want to train on MNIST using an MSE loss. I know MSE is not a good choice for classification; I'm trying it to find out whether I'm doing anything wrong with MSE, because I've already trained this project successfully with cross-entropy, and I expect MSE to work too, just with lower performance. However, I always get zero gradients, whether the target is a single float value or a one-hot vector.
My code is below. Note also that if I change the multiplication to subtraction in the cross-entropy loss, it returns zero grads as well; I don't know whether that's a clue to the bug. Please kindly help me figure out what I did wrong when applying MSE in JAX. Thank you!
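For reference, here is a minimal standalone sanity check (it does not use my jax_net helpers; the linear model and jax.nn.one_hot are stand-ins) showing that an MSE loss of this exact form does produce nonzero gradients on random data:

import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    preds = x @ w  # plain linear model, standing in for batch_forward
    return jnp.square(preds - y).mean(axis=1).sum()

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (784, 10))
x = jax.random.normal(key, (100, 784))
y = jax.nn.one_hot(jnp.zeros(100, dtype=jnp.int32), 10)

g = jax.grad(mse_loss)(w, x, y)
print(jnp.abs(g).max())  # prints a nonzero value

So the loss formula itself differentiates fine in isolation; the problem must be somewhere in my setup below.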
What jax/jaxlib version are you using?
jax 0.2.19
jaxlib 0.1.70
Which accelerator(s) are you using?
CPU
Additional system info
Windows WSL
NVIDIA GPU info
No response
Code:
# Import some additional JAX and dataloader helpers
from jax.scipy.special import logsumexp
from jax.experimental import optimizers
from jax import grad, jit, vmap, value_and_grad
import torch
from torchvision import datasets, transforms
import time
import jax_net
import jax_net_c2
import jax.numpy as np
import jax.random as rd
batch_size = 100
num_epochs = 3
num_classes = 10
params = jax_net.initialize_mlp([784, 512, 512, 10], jax_net.key)
opt_state = jax_net.opt_init(params)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)
@jit
def update(params, x, y, opt_state):
    """ Compute the gradient for a batch and update the parameters """
    value, grads = value_and_grad(loss)(params, x, y)
    opt_state = jax_net.opt_update(0, grads, opt_state)
    return jax_net.get_params(opt_state), opt_state, value

def loss(params, in_arrays, targets):  # doesn't work: returns zero gradients
    """ Mean squared error against the one-hot targets """
    preds = jax_net.batch_forward(params, in_arrays)
    return np.square(preds - targets).mean(axis=1).sum()

# def loss(params, in_arrays, targets):  # works fine
#     """ Multi-class cross-entropy; assumes batch_forward returns log-probabilities """
#     preds = jax_net.batch_forward(params, in_arrays)
#     return -np.sum(preds * targets)
def run_mnist_training_loop(num_epochs, opt_state, params, net_type="MLP"):
    """ Implements a learning loop over epochs. """
    log_acc_train, log_acc_test, train_loss = [], [], []
    key = rd.PRNGKey(758493)
    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            if net_type == "MLP":
                x = np.array(data).reshape(data.size(0), 28 * 28)
            elif net_type == "CNN":
                x = np.array(data)
            y = jax_net.one_hot(np.array(target), num_classes)
            # Argument order must match update(params, x, y, opt_state);
            # use a name other than `loss` so the loss function isn't shadowed
            params, opt_state, loss_value = update(params, x, y, opt_state)
            train_loss.append(loss_value)
            print("loss: ", float(loss_value))
            if batch_idx > 1000:
                break
        epoch_time = time.time() - start_time
        train_acc = jax_net.accuracy(params, train_loader, num_classes)
        test_acc = jax_net.accuracy(params, test_loader, num_classes)
        log_acc_train.append(train_acc)
        log_acc_test.append(test_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(
            epoch + 1, epoch_time, train_acc, test_acc))
    return train_loss, log_acc_train, log_acc_test

train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state, params,
                                                          net_type="MLP")
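A quick way to narrow this down is to inspect the raw gradient magnitudes for a single batch, bypassing the optimizer update entirely. This is only a sketch reusing the loaders and helpers above; the variable names are mine:

import jax

# Debugging sketch: check whether the gradients are literally zero for one
# batch, independent of the optimizer update.
data, target = next(iter(train_loader))
x_dbg = np.array(data).reshape(data.size(0), 28 * 28)
y_dbg = jax_net.one_hot(np.array(target), num_classes)
value, grads = value_and_grad(loss)(params, x_dbg, y_dbg)
# Print the largest absolute entry of each gradient leaf; all-zero leaves
# point at the loss/forward pass rather than at the update step.
print(jax.tree_util.tree_map(lambda g: float(np.abs(g).max()), grads))

If these leaves are nonzero while training still stalls, the problem is in how update is called rather than in the MSE loss itself.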