3
3
and modifies this to train on the CIFAR10 dataset. The same method generalizes
4
4
well to other datasets, but the modifications to the network may need to be changed.
5
5
6
- Video explanation: https://youtu.be/U4bHxEhMGNk
7
- Got any questions leave a comment on youtube :)
8
-
9
6
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
10
7
* 2020-04-08 Initial coding
8
+ * 2022-12-19 Updated comments, minor code changes, made sure it works with latest PyTorch
11
9
12
10
"""
13
11
22
20
) # Gives easier dataset managment and creates mini batches
23
21
import torchvision .datasets as datasets # Has standard datasets we can import in a nice way
24
22
import torchvision .transforms as transforms # Transformations we can perform on our dataset
23
+ from tqdm import tqdm
25
24
26
- # Set device
27
25
device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
28
26
29
27
# Hyperparameters
32
30
batch_size = 1024
33
31
num_epochs = 5
34
32
35
- # Simple Identity class that let's input pass without changes
36
- class Identity (nn .Module ):
37
- def __init__ (self ):
38
- super (Identity , self ).__init__ ()
39
-
40
- def forward (self , x ):
41
- return x
42
-
43
-
44
33
# Load pretrain model & modify it
45
- model = torchvision .models .vgg16 (pretrained = True )
34
+ model = torchvision .models .vgg16 (weights = "DEFAULT" )
46
35
47
36
# If you want to do finetuning then set requires_grad = False
48
37
# Remove these two lines if you want to train entire model,
49
38
# and only want to load the pretrain weights.
50
39
for param in model .parameters ():
51
40
param .requires_grad = False
52
41
53
- model .avgpool = Identity ()
42
+ model .avgpool = nn . Identity ()
54
43
model .classifier = nn .Sequential (
55
44
nn .Linear (512 , 100 ), nn .ReLU (), nn .Linear (100 , num_classes )
56
45
)
@@ -71,7 +60,7 @@ def forward(self, x):
71
60
for epoch in range (num_epochs ):
72
61
losses = []
73
62
74
- for batch_idx , (data , targets ) in enumerate (train_loader ):
63
+ for batch_idx , (data , targets ) in enumerate (tqdm ( train_loader ) ):
75
64
# Get data to cuda if possible
76
65
data = data .to (device = device )
77
66
targets = targets .to (device = device )
0 commit comments