'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import os
import argparse
from utils import get_acc, EarlyStopping
from dataloader import get_test_dataloader, get_training_dataloader
from tqdm import tqdm


classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
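# CIFAR-10 class names, kept for reference; they are not used directly by the
# training loop below.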

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--cuda', action='store_true', default=False, help='use GPU?')
    parser.add_argument('--batch-size', default=64, type=int, help='batch size for training')
    parser.add_argument('--num-workers', default=2, type=int, help='number of dataloader workers')
    parser.add_argument('--net', type=str, default='MobileNetv1', help='net type')
    parser.add_argument('--epochs', type=int, default=20, help='number of training epochs')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--patience', '-p', type=int, default=7, help='patience for early stopping')
    args = parser.parse_args()

    print(args)
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Data: CIFAR-10 train/test loaders from dataloader.py
    trainloader = get_training_dataloader(batch_size=args.batch_size, num_workers=args.num_workers)
    testloader = get_test_dataloader(batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    # Model
    print('==> Building model..')
    if args.net == 'VGG16':
        from nets.VGG import VGG
        net = VGG('VGG16')
    elif args.net == 'VGG19':
        from nets.VGG import VGG
        net = VGG('VGG19')
    elif args.net == 'ResNet18':
        from nets.ResNet import ResNet18
        net = ResNet18()
    elif args.net == 'ResNet34':
        from nets.ResNet import ResNet34
        net = ResNet34()
    elif args.net == 'LeNet':
        from nets.LeNet5 import LeNet5
        net = LeNet5()
    elif args.net == 'AlexNet':
        from nets.AlexNet import AlexNet
        net = AlexNet()
    elif args.net == 'DenseNet':
        from nets.DenseNet import densenet_cifar
        net = densenet_cifar()
    elif args.net == 'MobileNetv1':
        from nets.MobileNetv1 import MobileNet
        net = MobileNet()
    elif args.net == 'MobileNetv2':
        from nets.MobileNetv2 import MobileNetV2
        net = MobileNetV2()
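    else:
        # Guard against an unrecognized --net value: fail here with a clear error
        # instead of hitting a NameError when `net` is first used below.
        raise ValueError('Unknown --net: {}'.format(args.net))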

    if args.cuda and torch.cuda.is_available():
        device = 'cuda'
        net = net.to(device)
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    else:
        device = 'cpu'

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/{}_ckpt.pth'.format(args.net), map_location=device)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        args.lr = checkpoint['lr']

    early_stopping = EarlyStopping(patience=args.patience, verbose=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
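    # Note: T_max=200 means the cosine schedule reaches its minimum only after 200
    # scheduler steps (one per epoch); with the default 20 epochs, the learning rate
    # traverses only the early part of the cosine curve.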

    epochs = args.epochs
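
    # One full pass over the training set: per batch, forward pass, loss, backward
    # pass and optimizer step, with a tqdm bar showing running loss, accuracy and lr.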
    def train(epoch):
        epoch_step = len(trainloader)
        if epoch_step == 0:
            raise ValueError("The dataset is too small to train on; enlarge the dataset or reduce the batch size.")
        net.train()
        train_loss = 0
        train_acc = 0
        print('Start Train')
        with tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{epochs}', postfix=dict(), mininterval=0.3) as pbar:
            for step, (im, label) in enumerate(trainloader, start=0):
                im = im.to(device)
                label = label.to(device)
                #----------------------#
                #   Free cached GPU memory
                #----------------------#
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                #----------------------#
                #   Zero the gradients
                #----------------------#
                optimizer.zero_grad()
                #----------------------#
                #   Forward pass
                #----------------------#
                outputs = net(im)
                #----------------------#
                #   Compute the loss
                #----------------------#
                loss = criterion(outputs, label)
                train_loss += loss.item()
                train_acc += get_acc(outputs, label)
                #----------------------#
                #   Backward pass
                #----------------------#
                loss.backward()
                # Update the parameters
                optimizer.step()
                lr = optimizer.param_groups[0]['lr']
                pbar.set_postfix(**{'Train Loss': train_loss / (step + 1),
                                    'Train Acc': train_acc.item() / (step + 1),
                                    'Lr': lr})
                pbar.update(1)
        print('Finish Train')
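
    # Evaluate on the test set without gradient tracking, checkpoint the model
    # whenever the test accuracy improves, and feed the test loss to EarlyStopping
    # (patience set via --patience) to decide when to stop training.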
    def test(epoch):
        global best_acc
        epoch_step_test = len(testloader)
        if epoch_step_test == 0:
            raise ValueError("The test set is too small to evaluate on; enlarge the dataset or reduce the batch size.")

        net.eval()
        test_loss = 0
        test_acc = 0
        print('Start Test')
        #--------------------------------#
        #   Same procedure as in train()
        #--------------------------------#
        with tqdm(total=epoch_step_test, desc=f'Epoch {epoch + 1}/{epochs}', postfix=dict(), mininterval=0.3) as pbar2:
            for step, (im, label) in enumerate(testloader, start=0):
                im = im.to(device)
                label = label.to(device)
                with torch.no_grad():
                    if step >= epoch_step_test:
                        break

                    # Free cached GPU memory
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    #----------------------#
                    #   Forward pass
                    #----------------------#
                    outputs = net(im)
                    loss = criterion(outputs, label)
                    test_loss += loss.item()
                    test_acc += get_acc(outputs, label)

                    pbar2.set_postfix(**{'Test Acc': test_acc.item() / (step + 1),
                                         'Test Loss': test_loss / (step + 1)})
                    pbar2.update(1)
        lr = optimizer.param_groups[0]['lr']
        test_acc = test_acc.item() * 100 / len(testloader)
        # Save checkpoint.
        if test_acc > best_acc:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'acc': test_acc,
                'epoch': epoch,
                'lr': lr,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/{}_ckpt.pth'.format(args.net))
            best_acc = test_acc

        print('Finish Test')

        early_stopping(test_loss, net)
        # Stop training once the early-stopping criterion is met
        if early_stopping.early_stop:
            print("Early stopping")
            exit()

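    # Main loop: one training pass and one evaluation pass per epoch, then advance
    # the cosine learning-rate schedule; checkpointing and early stopping happen
    # inside test().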
    for epoch in range(start_epoch, epochs):
        train(epoch)
        test(epoch)
        scheduler.step()
    torch.cuda.empty_cache()