case with a (very) small and simple Feedforward Network training on MNIST
dataset with a learning rate scheduler. In this case ReduceLROnPlateau
scheduler is used, but can easily be changed to any of the other schedulers
- available.
-
- Video explanation: https://youtu.be/P31hB37g4Ak
- Got any questions leave a comment on youtube :)
+ available. I think simply reducing the LR by 1/10 or so when the loss
+ plateaus is a good default.

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-10 Initial programming
+ * 2022-12-19 Updated comments, made sure it works with latest PyTorch

"""
# Hyperparameters
num_classes = 10
- learning_rate = 0.1
+ learning_rate = (
+     0.1  # way too high learning rate, but we want to see the scheduler in action
+ )
batch_size = 128
num_epochs = 100

# Define Scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-     optimizer, factor=0.1, patience=5, verbose=True
+     optimizer, factor=0.1, patience=10, verbose=True
)
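To get a feel for what factor and patience mean here, a small standalone sketch (separate from the training script) that feeds the scheduler a loss that never improves, so the LR should be cut by 10x each time the patience runs out:

import torch

params = [torch.zeros(1, requires_grad=True)]  # stand-in for model.parameters()
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)

for epoch in range(25):
    scheduler.step(1.0)  # constant "loss": every epoch counts as a bad epoch
    print(epoch, optimizer.param_groups[0]["lr"])  # expect 0.1 -> 0.01 -> 0.001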
# Train Network

        losses.append(loss.item())

        # backward
+         optimizer.zero_grad()
        loss.backward()
-
-         # gradient descent or adam step
-         # scheduler.step(loss)
        optimizer.step()
-         optimizer.zero_grad()

    mean_loss = sum(losses) / len(losses)
+     mean_loss = round(mean_loss, 2)  # we should see difference in loss at 2 decimals
    # After each epoch do scheduler.step, note in this scheduler we need to send
-     # in loss for that epoch!
+     # in loss for that epoch! This could also be the validation loss, and the step
+     # could even be done on each batch in the inner loop, but then we might need
+     # to adjust the patience parameter.
    scheduler.step(mean_loss)
-     print(f"Cost at epoch {epoch} is {mean_loss}")
+     print(f"Average loss for epoch {epoch} was {mean_loss}")


# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
@@ -90,6 +91,7 @@ def check_accuracy(loader, model):
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
+             x = x.reshape(x.shape[0], -1)
            y = y.to(device=device)

            scores = model(x)
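The reshape added above flattens each image so the fully connected network can consume it: torchvision's MNIST loader yields batches of shape (N, 1, 28, 28), while the first linear layer of this feedforward net expects a flat vector of 28 * 28 = 784 features. A tiny standalone sketch of the shape change:

import torch

x = torch.randn(64, 1, 28, 28)  # a batch as it comes out of the MNIST DataLoader
x = x.reshape(x.shape[0], -1)   # flatten everything except the batch dimension
print(x.shape)                  # torch.Size([64, 784])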