diff --git a/mnist/main.py b/mnist/main.py index dee5a384cb..e3a5c44d17 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -54,7 +54,7 @@ def test(model, device, test_loader): model.eval() test_loss = 0 correct = 0 - with torch.no_grad(): + with torch.inference_mode(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data)