Skip to content

Commit a6260f3

Browse files
committed
利用py文件训练 (Train using standalone .py scripts)
1 parent 8c83c13 commit a6260f3

File tree

4 files changed

+356
-111
lines changed

4 files changed

+356
-111
lines changed

CIFAR10_code/dataloader.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import torchvision
3+
import torchvision.transforms as transforms
4+
5+
# Data
6+
def get_training_dataloader(batch_size = 64, num_workers = 4, shuffle = True):
    """Build the CIFAR-10 training DataLoader with standard augmentation.

    Args:
        batch_size: samples per mini-batch.
        num_workers: worker processes used by the DataLoader.
        shuffle: whether to reshuffle the data every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the augmented CIFAR-10 train split
        (downloaded to ``./data`` on first use).
    """
    print('==> Preparing Train data..')
    # Random crop with padding + horizontal flip are the usual CIFAR-10
    # augmentations; the constants are per-channel train-set mean/std.
    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=augment)
    return torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers)
20+
def get_test_dataloader(batch_size = 64, num_workers = 4, shuffle = True):
    """Build the CIFAR-10 test DataLoader (no augmentation, normalize only).

    Args:
        batch_size: samples per mini-batch.
        num_workers: worker processes used by the DataLoader.
        shuffle: whether to reshuffle the data every epoch (callers typically
            pass ``False`` for evaluation).

    Returns:
        A ``torch.utils.data.DataLoader`` over the CIFAR-10 test split
        (downloaded to ``./data`` on first use).
    """
    print('==> Preparing Test data..')
    # Same normalization constants as the training pipeline, no augmentation.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=preprocess)
    return torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers)

CIFAR10_code/nets/LeNet5.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
'''
2+
LetNet in Pytorch
3+
'''
4+
import torch
5+
import torch.nn as nn
6+
7+
class LeNet5(nn.Module):
    """Classic LeNet-5 adapted to 3-channel 32x32 inputs (CIFAR-10)."""

    def __init__(self, num_classes = 10, init_weights=True):
        super(LeNet5, self).__init__()
        # Stage 1: 32x32x3 --conv5--> 28x28x6 --maxpool2--> 14x14x6
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: 14x14x6 --conv5--> 10x10x16 --maxpool2--> 5x5x16
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head: 400 -> 120 -> 84 -> num_classes
        self.fc = nn.Sequential(
            nn.Linear(5 * 5 * 16, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Compute logits for a batch x of shape (N, 3, 32, 32)."""
        out = self.conv2(self.conv1(x))
        # Flatten everything except the batch dimension before the FC head.
        out = out.view(out.size(0), -1)
        return self.fc(out)

    def _initialize_weights(self):
        # Kaiming init for convs (ReLU network), small-normal for linear
        # layers, identity-style init for any BatchNorm layers present.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
56+
def test():
    """Smoke test: forward a dummy batch through LeNet5 and print a summary."""
    model = LeNet5()
    dummy = torch.randn(2, 3, 32, 32)
    out = model(dummy)
    print(out.size())
    # torchinfo is only needed for this smoke test, hence the lazy import.
    from torchinfo import summary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    summary(model, (2, 3, 32, 32))
65+

CIFAR10_code/train.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
'''Train CIFAR10 with PyTorch.'''
# NOTE: the previous `import imp` was removed — it was unused, and the `imp`
# module is deprecated and deleted in Python 3.12.
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

from dataloader import get_test_dataloader, get_training_dataloader
from utils import EarlyStopping, get_acc

# CIFAR-10 class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--cuda', action='store_true', default=False, help=' use GPU?')
    parser.add_argument('--batch-size', default=64, type=int, help="Batch Size for Training")
    parser.add_argument('--num-workers', default=2, type=int, help='num-workers')
    parser.add_argument('--net', type=str, default='MobileNetv1', help='net type')
    parser.add_argument('--epochs', type=int, default=20, help='Epochs')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--patience', '-p', type=int, default=7, help='patience for Early stop')
    args = parser.parse_args()

    print(args)
    best_acc = 0      # best test accuracy seen so far, in percent
    start_epoch = 0   # start from epoch 0 or the last checkpoint epoch

    # Data
    trainloader = get_training_dataloader(
        batch_size=args.batch_size, num_workers=args.num_workers)
    testloader = get_test_dataloader(
        batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

    # Model — architectures are imported lazily so only the selected one loads.
    print('==> Building model..')
    if args.net == 'VGG16':
        from nets.VGG import VGG
        net = VGG('VGG16')
    elif args.net == 'VGG19':
        from nets.VGG import VGG
        net = VGG('VGG19')
    elif args.net == 'ResNet18':
        from nets.ResNet import ResNet18
        net = ResNet18()
    elif args.net == 'ResNet34':
        from nets.ResNet import ResNet34
        net = ResNet34()
    elif args.net == 'LeNet':
        from nets.LeNet5 import LeNet5
        net = LeNet5()
    elif args.net == 'AlexNet':
        from nets.AlexNet import AlexNet
        net = AlexNet()
    elif args.net == 'DenseNet':
        from nets.DenseNet import densenet_cifar
        net = densenet_cifar()
    elif args.net == 'MobileNetv1':
        from nets.MobileNetv1 import MobileNet
        net = MobileNet()
    elif args.net == 'MobileNetv2':
        from nets.MobileNetv2 import MobileNetV2
        net = MobileNetV2()
    else:
        # Fail fast instead of hitting a NameError on `net` further down.
        raise ValueError('Unknown --net type: {}'.format(args.net))

    if args.cuda and torch.cuda.is_available():
        device = 'cuda'
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    else:
        device = 'cpu'
    # BUGFIX: the model itself must live on the selected device; previously
    # only the batches were moved, which breaks GPU training (DataParallel
    # requires the module on device_ids[0]).
    net = net.to(device)

    if args.resume:
        # Load checkpoint (saved below by test()).
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/{}_ckpt.pth'.format(args.net))
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        args.lr = checkpoint['lr']

    early_stopping = EarlyStopping(patience=args.patience, verbose=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    epochs = args.epochs

    def train(epoch):
        """Run one training epoch over `trainloader`, reporting running
        loss/accuracy/lr in a tqdm progress bar."""
        epoch_step = len(trainloader)
        if epoch_step == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集,或者减小batchsize")
        net.train()
        train_loss = 0.0
        train_acc = 0.0
        print('Start Train')
        with tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{epochs}',
                  postfix=dict, mininterval=0.3) as pbar:
            for step, (im, label) in enumerate(trainloader):
                im = im.to(device)
                label = label.to(device)
                optimizer.zero_grad()
                outputs = net(im)
                loss = criterion(outputs, label)
                loss.backward()
                optimizer.step()
                # Accumulate plain Python floats (`.item()`), not tensors:
                # the old `loss.data` form is deprecated and keeps device
                # tensors alive across the whole epoch.
                train_loss += loss.item()
                train_acc += float(get_acc(outputs, label))
                # NOTE: the per-step torch.cuda.empty_cache() was removed — it
                # forces a device sync every batch and only slows training.
                lr = optimizer.param_groups[0]['lr']
                pbar.set_postfix(**{'Train Loss': train_loss / (step + 1),
                                    'Train Acc': train_acc / (step + 1),
                                    'Lr': lr})
                pbar.update(1)
        print('Finish Train')

    def test(epoch):
        """Evaluate on `testloader`, checkpoint when test accuracy improves,
        and feed the mean test loss to the early-stopping monitor."""
        global best_acc
        epoch_step_test = len(testloader)
        if epoch_step_test == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集,或者减小batchsize")

        net.eval()
        test_loss = 0.0
        test_acc = 0.0
        print('Start Test')
        with tqdm(total=epoch_step_test, desc=f'Epoch {epoch + 1}/{epochs}',
                  postfix=dict, mininterval=0.3) as pbar2:
            # Gradients are never needed during evaluation.
            with torch.no_grad():
                for step, (im, label) in enumerate(testloader):
                    im = im.to(device)
                    label = label.to(device)
                    outputs = net(im)
                    loss = criterion(outputs, label)
                    test_loss += loss.item()
                    test_acc += float(get_acc(outputs, label))
                    pbar2.set_postfix(**{'Test Acc': test_acc / (step + 1),
                                         'Test Loss': test_loss / (step + 1)})
                    pbar2.update(1)
        lr = optimizer.param_groups[0]['lr']
        test_acc = test_acc * 100 / epoch_step_test  # mean accuracy, percent
        # Save checkpoint whenever test accuracy improves.
        if test_acc > best_acc:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'acc': test_acc,
                'epoch': epoch,
                'lr': lr,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/{}_ckpt.pth'.format(args.net))
            best_acc = test_acc
        print('Finish Test')
        # Monitor the MEAN test loss so the early-stop delta does not scale
        # with the number of test batches.
        early_stopping(test_loss / epoch_step_test, net)

    for epoch in range(start_epoch, epochs):
        train(epoch)
        test(epoch)
        # Stop cleanly from the main loop instead of calling exit() inside test().
        if early_stopping.early_stop:
            print("Early stopping")
            break
        scheduler.step()
        torch.cuda.empty_cache()
205+

0 commit comments

Comments
 (0)