Skip to content

Commit 227858a

Browse files
committed
计算top1和top5误差
1 parent ae05d9b commit 227858a

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

CIFAR10_code/eval.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
2+
import os
3+
import torch
4+
import torch.nn as nn
5+
import argparse
6+
import numpy as np
7+
from tqdm import tqdm
8+
from dataloader import get_test_dataloader
9+
import torch.backends.cudnn as cudnn
10+
11+
def eval_top1(outputs, label):
12+
total = outputs.shape[0]
13+
outputs = torch.softmax(outputs, dim=-1)
14+
_, pred_y = outputs.data.max(dim=1) # 得到概率
15+
correct = (pred_y == label).sum().data
16+
return correct / total
17+
18+
def eval_top5(outputs, label):
19+
total = outputs.shape[0]
20+
outputs = torch.softmax(outputs, dim=-1)
21+
pred_y = np.argsort(-outputs.cpu().numpy())
22+
pred_y_top5 = pred_y[:,:5]
23+
correct = 0
24+
for i in range(total):
25+
if label[i].cpu().numpy() in pred_y_top5[i]:
26+
correct += 1
27+
return correct / total
28+
29+
30+
if __name__ == '__main__':
31+
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Test')
32+
parser.add_argument('--cuda', action='store_true', default=False, help =' use GPU?')
33+
parser.add_argument('--batch-size', default=64, type=int, help = "Batch Size for Test")
34+
parser.add_argument('--num-workers', default=2, type=int, help = 'num-workers')
35+
parser.add_argument('--net', type = str, default='MobileNetv1', help='net type')
36+
args = parser.parse_args()
37+
testloader = get_test_dataloader()
38+
39+
# Model
40+
print('==> Building model..')
41+
if args.net == 'VGG16':
42+
from nets.VGG import VGG
43+
net = VGG('VGG16')
44+
elif args.net == 'VGG19':
45+
from nets.VGG import VGG
46+
net = VGG('VGG19')
47+
elif args.net == 'ResNet18':
48+
from nets.ResNet import ResNet18
49+
net = ResNet18()
50+
elif args.net == 'ResNet34':
51+
from nets.ResNet import ResNet34
52+
net = ResNet34()
53+
elif args.net == 'LeNet':
54+
from nets.LeNet5 import LeNet5
55+
net = LeNet5()
56+
elif args.net == 'AlexNet':
57+
from nets.AlexNet import AlexNet
58+
net = AlexNet()
59+
elif args.net == 'DenseNet':
60+
from nets.DenseNet import densenet_cifar
61+
net = densenet_cifar()
62+
elif args.net == 'MobileNetv1':
63+
from nets.MobileNetv1 import MobileNet
64+
net = MobileNet()
65+
elif args.net == 'MobileNetv2':
66+
from nets.MobileNetv2 import MobileNetV2
67+
net = MobileNetV2()
68+
69+
if args.cuda and torch.cuda.is_available():
70+
device = 'cuda'
71+
net = torch.nn.DataParallel(net)
72+
cudnn.benchmark = True
73+
else:
74+
device = 'cpu'
75+
76+
criterion = nn.CrossEntropyLoss()
77+
78+
# Load checkpoint.
79+
print('==> Resuming from checkpoint..')
80+
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
81+
checkpoint = torch.load('./checkpoint/{}_ckpt.pth'.format(args.net))
82+
net.load_state_dict(checkpoint['net'])
83+
84+
85+
epoch_step_test = len(testloader)
86+
if epoch_step_test == 0:
87+
raise ValueError("数据集过小,无法进行训练,请扩充数据集,或者减小batchsize")
88+
89+
net.eval()
90+
test_acc_top1 = 0
91+
test_acc_top5 = 0
92+
print('Start Test')
93+
#--------------------------------
94+
# 相同方法,同train
95+
#--------------------------------
96+
with tqdm(total=epoch_step_test,desc=f'Test Acc',postfix=dict,mininterval=0.3) as pbar2:
97+
for step,(im,label) in enumerate(testloader,start=0):
98+
im = im.to(device)
99+
label = label.to(device)
100+
with torch.no_grad():
101+
if step >= epoch_step_test:
102+
break
103+
104+
# 释放内存
105+
if hasattr(torch.cuda, 'empty_cache'):
106+
torch.cuda.empty_cache()
107+
#----------------------#
108+
# 前向传播
109+
#----------------------#
110+
outputs = net(im)
111+
loss = criterion(outputs,label)
112+
test_acc_top1 += eval_top1(outputs,label)
113+
test_acc_top5 += eval_top5(outputs,label)
114+
pbar2.set_postfix(**{'Test Acc Top1': test_acc_top1.item()/(step+1),
115+
'Test Acc Top5': test_acc_top5 / (step + 1)})
116+
pbar2.update(1)
117+
118+
top1 = test_acc_top1.item()/ len(testloader)
119+
top5 = test_acc_top5 / len(testloader)
120+
print("top-1 accuracy = %.2f%%" % (top1*100))
121+
print("top-5 accuracy = %.2f%%" % (top5*100))

0 commit comments

Comments
 (0)