1
+ """
2
+ Example code of how to use mixed precision training with PyTorch. In this
3
+ case with a (very) small and simple CNN training on MNIST dataset. This
4
+ example is based on the official PyTorch documentation on mixed precision
5
+ training.
6
+
7
+ Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
8
+ * 2020-04-10 Initial programming
9
+ * 2022-12-19 Updated comments, made sure it works with latest PyTorch
10
+
11
+ """
12
+
1
13
# Imports
2
14
import torch
3
15
import torch .nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
4
16
import torch .optim as optim # For all Optimization algorithms, SGD, Adam, etc.
5
17
import torch .nn .functional as F # All functions that don't have any parameters
6
- from torch .utils .data import DataLoader # Gives easier dataset managment and creates mini batches
18
+ from torch .utils .data import (
19
+ DataLoader ,
20
+ ) # Gives easier dataset managment and creates mini batches
7
21
import torchvision .datasets as datasets # Has standard datasets we can import in a nice way
8
22
import torchvision .transforms as transforms # Transformations we can perform on our dataset
9
23
12
26
class CNN (nn .Module ):
13
27
def __init__ (self , in_channels = 1 , num_classes = 10 ):
14
28
super (CNN , self ).__init__ ()
15
- self .conv1 = nn .Conv2d (in_channels = 1 , out_channels = 420 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = (1 , 1 ))
29
+ self .conv1 = nn .Conv2d (
30
+ in_channels = 1 ,
31
+ out_channels = 420 ,
32
+ kernel_size = (3 , 3 ),
33
+ stride = (1 , 1 ),
34
+ padding = (1 , 1 ),
35
+ )
16
36
self .pool = nn .MaxPool2d (kernel_size = (2 , 2 ), stride = (2 , 2 ))
17
- self .conv2 = nn .Conv2d (in_channels = 420 , out_channels = 1000 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = (1 , 1 ))
37
+ self .conv2 = nn .Conv2d (
38
+ in_channels = 420 ,
39
+ out_channels = 1000 ,
40
+ kernel_size = (3 , 3 ),
41
+ stride = (1 , 1 ),
42
+ padding = (1 , 1 ),
43
+ )
18
44
self .fc1 = nn .Linear (1000 * 7 * 7 , num_classes )
19
45
20
46
def forward (self , x ):
@@ -29,7 +55,8 @@ def forward(self, x):
29
55
30
56
31
57
# Set device
32
- device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
58
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
59
+ assert device == "cuda" , "GPU not available"
33
60
34
61
# Hyperparameters
35
62
in_channel = 1
@@ -39,9 +66,13 @@ def forward(self, x):
39
66
num_epochs = 5
40
67
41
68
# Load Data
42
- train_dataset = datasets .MNIST (root = 'dataset/' , train = True , transform = transforms .ToTensor (), download = True )
69
+ train_dataset = datasets .MNIST (
70
+ root = "dataset/" , train = True , transform = transforms .ToTensor (), download = True
71
+ )
43
72
train_loader = DataLoader (dataset = train_dataset , batch_size = batch_size , shuffle = True )
44
- test_dataset = datasets .MNIST (root = 'dataset/' , train = False , transform = transforms .ToTensor (), download = True )
73
+ test_dataset = datasets .MNIST (
74
+ root = "dataset/" , train = False , transform = transforms .ToTensor (), download = True
75
+ )
45
76
test_loader = DataLoader (dataset = test_dataset , batch_size = batch_size , shuffle = True )
46
77
47
78
# Initialize network
@@ -89,10 +120,12 @@ def check_accuracy(loader, model):
89
120
num_correct += (predictions == y ).sum ()
90
121
num_samples += predictions .size (0 )
91
122
92
- print (f'Got { num_correct } / { num_samples } with accuracy { float (num_correct ) / float (num_samples ) * 100 :.2f} ' )
123
+ print (
124
+ f"Got { num_correct } / { num_samples } with accuracy { float (num_correct ) / float (num_samples ) * 100 :.2f} "
125
+ )
93
126
94
127
model .train ()
95
128
96
129
97
130
check_accuracy (train_loader , model )
98
- check_accuracy (test_loader , model )
131
+ check_accuracy (test_loader , model )
0 commit comments