-
#next (model_0.parameters())
#torch.manual_seed(53)
def eval_model(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, loss: torch.nn.Module, accuracy):
    """Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode (and left there).
        data_loader: iterable of ``(x, y)`` batches; ``len(data_loader)`` must be
            the number of batches, because it is the divisor for both averages.
        loss: loss callable, invoked as ``loss(y_pred, y)``; it must return a
            tensor, since ``.item()`` is called on the averaged value.
        accuracy: metric callable, invoked as
            ``accuracy(y_true=y, y_pred=y_pred.argmax(dim=1))``.

    Returns:
        ``{"MODEL_NAME", "MODEL_LOSS", "model_acc"}``: the model's class name,
        the mean of the per-batch losses (a Python float), and the mean of the
        per-batch accuracies.

    Raises:
        ZeroDivisionError: if ``len(data_loader)`` is 0.
    """
    Loss, acc = 0, 0  # running sums (local; unrelated to any caller-side `Loss`)
    model.eval()
    with torch.inference_mode():
        for x, y in data_loader:
            y_pred = model(x)
            Loss += loss(y_pred, y)
            # Sum per-batch accuracies so the division below yields their mean.
            # argmax over dim 1 presumably turns (batch, num_classes) logits into
            # class labels -- confirm against the model's output shape.
            acc += accuracy(y_true=y, y_pred=y_pred.argmax(dim=1))
        # Average inside inference_mode: the sums are inference tensors, and
        # in-place `/=` on them outside inference mode raises RuntimeError.
        Loss /= len(data_loader)
        acc /= len(data_loader)
    return {"MODEL_NAME": model.__class__.__name__, "MODEL_LOSS": Loss.item(), "model_acc": acc}
# NOTE(review): `loss:Loss` in the call below is a SyntaxError -- keyword
# arguments are passed with `=`, so it should read `loss=Loss`.
model_0_results=eval_model(model=model_0,data_loader=test_dataloader,loss:Loss,accuracy=acc)
model_0_results
It gives an error. What's wrong?
|
Beta Was this translation helpful? Give feedback.
Answered by
mrdbourke
Aug 4, 2023
Replies: 1 comment 2 replies
-
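Hi @Aditi4AI, I think Python is getting confused between your variables `Loss` and `loss`. You could try rewriting your function to rename `loss` to `loss_fn` and `Loss` to `test_loss` to avoid the confusion. Also, in the line that calls `eval_model`: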
You have written `loss:Loss`, but this should be `loss=Loss` (keyword arguments use `=`, not `:`). I have updated the whole code to reflect these changes:
#next (model_0.parameters())
#torch.manual_seed(53)
def eval_model(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, accuracy):
    """Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode (and left there).
        data_loader: iterable of ``(x, y)`` batches; ``len(data_loader)`` must be
            the number of batches, because it is the divisor for both averages.
        loss_fn: loss callable, invoked as ``loss_fn(y_pred, y)``; it must return
            a tensor, since ``.item()`` is called on the averaged value.
        accuracy: metric callable, invoked as
            ``accuracy(y_true=y, y_pred=y_pred.argmax(dim=1))``.

    Returns:
        ``{"MODEL_NAME", "MODEL_LOSS", "model_acc"}``: the model's class name,
        the mean of the per-batch losses (a Python float), and the mean of the
        per-batch accuracies.

    Raises:
        ZeroDivisionError: if ``len(data_loader)`` is 0.
    """
    test_loss, acc = 0, 0  # running sums over batches
    model.eval()
    with torch.inference_mode():
        for x, y in data_loader:
            y_pred = model(x)
            test_loss += loss_fn(y_pred, y)
            # Sum per-batch accuracies so the division below yields their mean.
            # argmax over dim 1 presumably turns (batch, num_classes) logits into
            # class labels -- confirm against the model's output shape.
            acc += accuracy(y_true=y, y_pred=y_pred.argmax(dim=1))
        # Average inside inference_mode: the sums are inference tensors, and
        # in-place `/=` on them outside inference mode raises RuntimeError.
        test_loss /= len(data_loader)
        acc /= len(data_loader)
    return {"MODEL_NAME": model.__class__.__name__, "MODEL_LOSS": test_loss.item(), "model_acc": acc}
# Evaluate model_0 on test_dataloader. `Loss` and `acc` are defined elsewhere
# (not shown here); they must be callables matching eval_model's
# `loss_fn(y_pred, y)` and `accuracy(y_true=..., y_pred=...)` call forms.
model_0_results = eval_model(
    model=model_0,
    data_loader=test_dataloader,
    loss_fn=Loss,
    accuracy=acc,
)
# Bare expression: presumably a notebook cell, where this displays the dict.
model_0_results
Beta Was this translation helpful? Give feedback.
2 replies
Answer selected by
mrdbourke
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @Aditi4AI ,
I think Python is getting confused between your variables `Loss` and `loss`. You could try rewriting your function to rename `loss` to `loss_fn` and `Loss` to `test_loss` to avoid the confusion. Also, in the line that calls `eval_model`:
You have written `loss:Loss`, but this should be `loss=Loss` (keyword arguments use `=`, not `:`). I have updated the whole code to reflect these changes: