Skip to content

Commit bb63a74

Browse files
committed
mnist: support --model-path and --load-model
If specify --load-model will only run test(), if specify --model-path=A, the saved model will be renamed to A. Example: python main.py --save-model --model-path=mnist_cnn.pt python main.py --load-model=mnist_cnn.pt Signed-off-by: Rong Tao <[email protected]>
1 parent 77f55b9 commit bb63a74

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

mnist/main.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,12 @@ def main():
9090
help='random seed (default: 1)')
9191
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
9292
help='how many batches to wait before logging training status')
93-
parser.add_argument('--save-model', action='store_true',
93+
parser.add_argument('--save-model', action='store_true',
9494
help='For Saving the current Model')
95+
parser.add_argument('--model-path', type=str, default='mnist_cnn.pt',
96+
help='path to save the trained model')
97+
parser.add_argument('--load-model', type=str, default=None,
98+
help='Path to load a pre-trained model')
9599
args = parser.parse_args()
96100

97101
use_accel = not args.no_accel and torch.accelerator.is_available()
@@ -125,16 +129,22 @@ def main():
125129
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
126130

127131
model = Net().to(device)
128-
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
129132

130-
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
131-
for epoch in range(1, args.epochs + 1):
132-
train(args, model, device, train_loader, optimizer, epoch)
133+
if args.load_model:
134+
print(f"Loading model from {args.load_model}")
135+
model.load_state_dict(torch.load(args.load_model, map_location=device))
133136
test(model, device, test_loader)
134-
scheduler.step()
137+
else:
138+
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
139+
140+
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
141+
for epoch in range(1, args.epochs + 1):
142+
train(args, model, device, train_loader, optimizer, epoch)
143+
test(model, device, test_loader)
144+
scheduler.step()
135145

136-
if args.save_model:
137-
torch.save(model.state_dict(), "mnist_cnn.pt")
146+
if args.save_model:
147+
torch.save(model.state_dict(), args.model_path)
138148

139149

140150
if __name__ == '__main__':

0 commit comments

Comments
 (0)