Target size (torch.Size([800])) must be the same as input size (torch.Size([800, 4])) #547
Mubassir1820 asked this question in Q&A (answered)
Error: `Target size (torch.Size([800])) must be the same as input size (torch.Size([800, 4]))`

I have tried the exact steps in the training loop (I also tried copying and pasting from the PDF), but for some reason this error keeps showing up.

## My code

```python
import torch

# Fit the multi-class model to the data
# (model_4, loss_fn, optimizer, accuracy_fn, device and the X_blob/y_blob
#  tensors are defined earlier in the notebook)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Set number of epochs
epochs = 100

# Put the data on the target device
X_blob_train, y_blob_train = X_blob_train.to(device), y_blob_train.to(device)
X_blob_test, y_blob_test = X_blob_test.to(device), y_blob_test.to(device)

# Loop through data
for epoch in range(epochs):
    ### Training
    model_4.train()
    y_logits = model_4(X_blob_train)
    y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)
    print(y_logits.shape, y_blob_train.shape)  # debug: compare logit and label shapes
    loss = loss_fn(y_logits, y_blob_train)
    acc = accuracy_fn(y_true=y_blob_train, y_pred=y_pred)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    ### Testing
    model_4.eval()
    with torch.inference_mode():
        test_logits = model_4(X_blob_test)
        test_preds = torch.softmax(test_logits, dim=1).argmax(dim=1)
        test_loss = loss_fn(test_logits, y_blob_test)
        test_acc = accuracy_fn(y_true=y_blob_test, y_pred=test_preds)

    # Print out what's happening
    if epoch % 10 == 0:
        print(f"Epoch: {epoch} | Loss: {loss:.4f}, Acc: {acc:.2f}% | Test Loss: {test_loss:.4f}, Test acc: {test_acc:.2f}%")
```
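The loop above relies on `model_4`, `loss_fn`, `optimizer`, `accuracy_fn`, `device` and the blob tensors being defined earlier in the notebook. For comparison, here is a minimal self-contained sketch of the same loop. The synthetic 4-class data, the small `nn.Sequential` model, the Adam optimizer and this `accuracy_fn` are all hypothetical stand-ins, and the sketch assumes the loss is `nn.CrossEntropyLoss`, which takes `[N, C]` logits and `[N]` class indices.

```python
import torch
from torch import nn

# Hypothetical stand-ins for names defined elsewhere in the notebook
torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES, NUM_FEATURES = 4, 2

# Synthetic 4-class data: points scattered around four centres
centres = torch.tensor([[-4.0, -4.0], [-4.0, 4.0], [4.0, -4.0], [4.0, 4.0]])
y_all = torch.randint(0, NUM_CLASSES, (1000,))
X_all = centres[y_all] + torch.randn(1000, NUM_FEATURES)
X_blob_train, y_blob_train = X_all[:800].to(device), y_all[:800].to(device)
X_blob_test, y_blob_test = X_all[800:].to(device), y_all[800:].to(device)

model_4 = nn.Sequential(
    nn.Linear(NUM_FEATURES, 8), nn.ReLU(), nn.Linear(8, NUM_CLASSES)
).to(device)
loss_fn = nn.CrossEntropyLoss()  # expects [N, C] logits and [N] class indices
optimizer = torch.optim.Adam(model_4.parameters(), lr=0.01)

def accuracy_fn(y_true, y_pred):
    # Percentage of predictions that equal the true labels (hypothetical helper)
    return (y_true == y_pred).float().mean().item() * 100

losses = []
for epoch in range(100):
    model_4.train()
    y_logits = model_4(X_blob_train)         # shape [800, 4]
    loss = loss_fn(y_logits, y_blob_train)   # y_blob_train has shape [800]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

model_4.eval()
with torch.inference_mode():
    # argmax over raw logits picks the same class as argmax over softmax
    test_preds = model_4(X_blob_test).argmax(dim=1)
test_acc = accuracy_fn(y_blob_test, test_preds)
print(f"Final loss: {losses[-1]:.4f} | Test acc: {test_acc:.2f}%")
```

Because softmax is monotonic, `argmax` over raw logits selects the same class as `argmax` over softmax probabilities, so the sketch skips the softmax when predicting.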
## What I've tried so far
I have tried using `squeeze()` as in the binary classification problem, but the same error still occurs.
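`squeeze()` only removes dimensions of size 1. In the binary case the model outputs one logit per sample (shape `[N, 1]`), so squeezing gives `[N]` and matches the labels. Here the logits have shape `[800, 4]`, so there is nothing to squeeze. A minimal sketch with random tensors of the shapes from the error message (the tensors are hypothetical):

```python
import torch

# Hypothetical tensors matching the shapes in the error message
binary_logits = torch.randn(800, 1)   # binary model: one logit per sample
multi_logits = torch.randn(800, 4)    # 4-class model: one logit per class
labels = torch.randint(0, 4, (800,))  # class indices

# squeeze() drops size-1 dimensions: [800, 1] -> [800]
print(binary_logits.squeeze().shape)  # torch.Size([800])

# No dimension of [800, 4] has size 1, so squeeze() changes nothing
print(multi_logits.squeeze().shape)   # torch.Size([800, 4])
print(labels.shape)                   # torch.Size([800])
```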
Answered by Mubassir1820 on Jul 18, 2023
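So, I finally figured it out. I had accidentally named the loss function `loss_dn` instead of `loss_fn`, and then referred to `loss_fn` in the training loop, so `loss_fn` still pointed to an earlier loss function. The error message is the one raised by `nn.BCEWithLogitsLoss` (the binary classification loss), which requires the logits and the labels to have exactly the same shape. A multi-class loss such as `nn.CrossEntropyLoss` instead accepts logits of shape `[800, 4]` together with class-index labels of shape `[800]`. That is why the error kept popping up.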
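To see the difference between the two losses, here is a minimal sketch with random tensors of the shapes from the question. It assumes the stale `loss_fn` was an `nn.BCEWithLogitsLoss` left over from the binary section; the tensors themselves are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(42)
y_logits = torch.randn(800, 4)              # raw logits for 4 classes
y_blob_train = torch.randint(0, 4, (800,))  # class-index targets

# Binary loss: input and target shapes must match, so this raises ValueError
bce_loss = nn.BCEWithLogitsLoss()
try:
    bce_loss(y_logits, y_blob_train.float())
    bce_error = None
except ValueError as err:
    bce_error = str(err)
print(bce_error)  # the "Target size ... must be the same as input size ..." message

# Multi-class loss: [800, 4] logits with [800] class indices is the expected pairing
ce_loss = nn.CrossEntropyLoss()
loss = ce_loss(y_logits, y_blob_train)
print(loss.shape)  # torch.Size([]), i.e. a scalar
```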