Skip to content

Commit 7e749ce

Browse files
committed
机器学习大作业 ConvNeXt篇
机器学习大作业 ConvNeXt篇
1 parent a6603e1 commit 7e749ce

7 files changed

+978
-0
lines changed

ConvNeXt/dataloader.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import torchvision
3+
import torchvision.transforms as transforms
4+
5+
# Data
6+
def get_training_dataloader(batch_size = 64, num_workers = 4, shuffle = True, resize = 32):
7+
print('==> Preparing Train data..')
8+
transform_train = transforms.Compose([
9+
transforms.Resize(resize),
10+
# transforms.RandomCrop(32, padding=4),
11+
# transforms.RandomHorizontalFlip(),
12+
transforms.ToTensor(),
13+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
14+
])
15+
trainset = torchvision.datasets.CIFAR10(
16+
root='./data', train=True, download=True, transform=transform_train)
17+
trainloader = torch.utils.data.DataLoader(
18+
trainset, batch_size=batch_size, shuffle=shuffle, num_workers= num_workers)
19+
return trainloader
20+
21+
def get_test_dataloader(batch_size = 64, num_workers = 4, shuffle = True, resize = 32):
22+
print('==> Preparing Test data..')
23+
transform_test = transforms.Compose([
24+
transforms.Resize(resize),
25+
transforms.ToTensor(),
26+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
27+
])
28+
29+
testset = torchvision.datasets.CIFAR10(
30+
root='./data', train=False, download=True, transform=transform_test)
31+
testloader = torch.utils.data.DataLoader(
32+
testset, batch_size=batch_size, shuffle=shuffle, num_workers= num_workers)
33+
return testloader
34+

ConvNeXt/eval.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
2+
import os
3+
import torch
4+
import torch.nn as nn
5+
import argparse
6+
import numpy as np
7+
from tqdm import tqdm
8+
from dataloader import get_test_dataloader
9+
import torch.backends.cudnn as cudnn
10+
11+
def eval_top1(outputs, label):
12+
total = outputs.shape[0]
13+
outputs = torch.softmax(outputs, dim=-1)
14+
_, pred_y = outputs.data.max(dim=1) # 得到概率
15+
correct = (pred_y == label).sum().data
16+
return correct / total
17+
18+
def eval_top5(outputs, label):
19+
total = outputs.shape[0]
20+
outputs = torch.softmax(outputs, dim=-1)
21+
pred_y = np.argsort(-outputs.cpu().numpy())
22+
pred_y_top5 = pred_y[:,:5]
23+
correct = 0
24+
for i in range(total):
25+
if label[i].cpu().numpy() in pred_y_top5[i]:
26+
correct += 1
27+
return correct / total
28+
29+
30+
if __name__ == '__main__':
31+
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Test')
32+
parser.add_argument('--cuda', action='store_true', default=False, help =' use GPU?')
33+
parser.add_argument('--batch-size', default=64, type=int, help = "Batch Size for Test")
34+
parser.add_argument('--num-workers', default=2, type=int, help = 'num-workers')
35+
parser.add_argument('--net', type = str, choices=['LeNet5', 'AlexNet', 'VGG16','VGG19','ResNet18','ResNet34',
36+
'DenseNet','MobileNetv1','MobileNetv2','ResNeXt','ConvNeXt-T','ConvNeXt-S','ConvNeXt-B'], default='MobileNetv1', help='net type')
37+
parser.add_argument('--resize',type=int,default=32)
38+
args = parser.parse_args()
39+
testloader = get_test_dataloader(batch_size = args.batch_size, num_workers = args.num_workers, shuffle=False, resize = args.resize)
40+
41+
# Model
42+
print('==> Building model..')
43+
if args.net == 'VGG16':
44+
from nets.VGG import VGG
45+
net = VGG('VGG16')
46+
elif args.net == 'VGG19':
47+
from nets.VGG import VGG
48+
net = VGG('VGG19')
49+
elif args.net == 'ResNet18':
50+
from nets.ResNet import ResNet18
51+
net = ResNet18()
52+
elif args.net == 'ResNet34':
53+
from nets.ResNet import ResNet34
54+
net = ResNet34()
55+
elif args.net == 'LeNet5':
56+
from nets.LeNet5 import LeNet5
57+
net = LeNet5()
58+
elif args.net == 'AlexNet':
59+
from nets.AlexNet import AlexNet
60+
net = AlexNet()
61+
elif args.net == 'DenseNet':
62+
from nets.DenseNet import densenet_cifar
63+
net = densenet_cifar()
64+
elif args.net == 'MobileNetv1':
65+
from nets.MobileNetv1 import MobileNet
66+
net = MobileNet()
67+
elif args.net == 'MobileNetv2':
68+
from nets.MobileNetv2 import MobileNetV2
69+
net = MobileNetV2()
70+
elif args.net == 'ResNeXt':
71+
from nets.ResNeXt import ResNeXt50
72+
net = ResNeXt50(10)
73+
elif args.net == 'ConvNeXt-T':
74+
from nets.ConvNeXt import convnext_tiny
75+
net = convnext_tiny(10)
76+
elif args.net == 'ConvNeXt-S':
77+
from nets.ConvNeXt import convnext_small
78+
net = convnext_small(10)
79+
elif args.net == 'ConvNeXt-B':
80+
from nets.ConvNeXt import convnext_base
81+
net = convnext_base(10)
82+
83+
from torchinfo import summary
84+
summary(net,(2,3,224,224))
85+
86+
criterion = nn.CrossEntropyLoss()
87+
88+
# Load checkpoint.
89+
print('==> Resuming from checkpoint..')
90+
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
91+
checkpoint = torch.load('./checkpoint/{}_ckpt.pth'.format(args.net))
92+
93+
weights_dict = {}
94+
for k, v in checkpoint['net'].items():
95+
new_k = k.replace('module.', '') if 'module' in k else k
96+
weights_dict[new_k] = v
97+
98+
net.load_state_dict(weights_dict)
99+
100+
if args.cuda and torch.cuda.is_available():
101+
device = 'cuda'
102+
net = torch.nn.DataParallel(net)
103+
cudnn.benchmark = True
104+
else:
105+
device = 'cpu'
106+
107+
epoch_step_test = len(testloader)
108+
if epoch_step_test == 0:
109+
raise ValueError("数据集过小,无法进行训练,请扩充数据集,或者减小batchsize")
110+
111+
net.eval()
112+
test_acc_top1 = 0
113+
test_acc_top5 = 0
114+
print('Start Test')
115+
#--------------------------------
116+
# 相同方法,同train
117+
#--------------------------------
118+
with tqdm(total=epoch_step_test,desc=f'Test Acc',postfix=dict,mininterval=0.3) as pbar2:
119+
for step,(im,label) in enumerate(testloader,start=0):
120+
im = im.to(device)
121+
label = label.to(device)
122+
with torch.no_grad():
123+
if step >= epoch_step_test:
124+
break
125+
126+
# 释放内存
127+
if hasattr(torch.cuda, 'empty_cache'):
128+
torch.cuda.empty_cache()
129+
#----------------------#
130+
# 前向传播
131+
#----------------------#
132+
outputs = net(im)
133+
loss = criterion(outputs,label)
134+
test_acc_top1 += eval_top1(outputs,label)
135+
test_acc_top5 += eval_top5(outputs,label)
136+
pbar2.set_postfix(**{'Test Acc Top1': test_acc_top1.item()/(step+1),
137+
'Test Acc Top5': test_acc_top5 / (step + 1)})
138+
pbar2.update(1)
139+
140+
top1 = test_acc_top1.item()/ len(testloader)
141+
top5 = test_acc_top5 / len(testloader)
142+
print("top-1 accuracy = %.2f%%" % (top1*100))
143+
print("top-5 accuracy = %.2f%%" % (top5*100))

0 commit comments

Comments
 (0)