diff --git a/mnist/main.py b/mnist/main.py
index dee5a384cb..e7618c606d 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -90,8 +90,12 @@ def main():
                         help='random seed (default: 1)')
     parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                         help='how many batches to wait before logging training status')
     parser.add_argument('--save-model', action='store_true',
                         help='For Saving the current Model')
+    parser.add_argument('--model-path', type=str, default='mnist_cnn.pt',
+                        help='path to save the trained model')
+    parser.add_argument('--load-model', type=str, default=None,
+                        help='Path to load a pre-trained model')
     args = parser.parse_args()
 
     use_accel = not args.no_accel and torch.accelerator.is_available()
@@ -125,16 +129,31 @@ def main():
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
 
     model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
 
-    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
-    for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+    if args.load_model:
+        # Evaluation-only path: restore saved weights and run one test pass.
+        # No training happens here, so --save-model has no effect.
+        print(f"Loading model from {args.load_model}")
+        # weights_only=True limits unpickling to tensors and plain containers,
+        # which is all a state_dict needs, and avoids running arbitrary
+        # pickled code from an untrusted checkpoint file.
+        state_dict = torch.load(args.load_model, map_location=device,
+                                weights_only=True)
+        model.load_state_dict(state_dict)
         test(model, device, test_loader)
-        scheduler.step()
+    else:
+        # Training path: the original loop, now nested under `else`.
+        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+        for epoch in range(1, args.epochs + 1):
+            train(args, model, device, train_loader, optimizer, epoch)
+            test(model, device, test_loader)
+            scheduler.step()
 
-    if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        if args.save_model:
+            # Write to the user-supplied path (defaults to 'mnist_cnn.pt').
+            torch.save(model.state_dict(), args.model_path)
 
 
 if __name__ == '__main__':