1+
2+ import os
3+ import torch
4+ import torch .nn as nn
5+ import argparse
6+ import numpy as np
7+ from tqdm import tqdm
8+ from dataloader import get_test_dataloader
9+ import torch .backends .cudnn as cudnn
10+
11+ def eval_top1 (outputs , label ):
12+ total = outputs .shape [0 ]
13+ outputs = torch .softmax (outputs , dim = - 1 )
14+ _ , pred_y = outputs .data .max (dim = 1 ) # 得到概率
15+ correct = (pred_y == label ).sum ().data
16+ return correct / total
17+
18+ def eval_top5 (outputs , label ):
19+ total = outputs .shape [0 ]
20+ outputs = torch .softmax (outputs , dim = - 1 )
21+ pred_y = np .argsort (- outputs .cpu ().numpy ())
22+ pred_y_top5 = pred_y [:,:5 ]
23+ correct = 0
24+ for i in range (total ):
25+ if label [i ].cpu ().numpy () in pred_y_top5 [i ]:
26+ correct += 1
27+ return correct / total
28+
29+
30+ if __name__ == '__main__' :
31+ parser = argparse .ArgumentParser (description = 'PyTorch CIFAR10 Test' )
32+ parser .add_argument ('--cuda' , action = 'store_true' , default = False , help = ' use GPU?' )
33+ parser .add_argument ('--batch-size' , default = 64 , type = int , help = "Batch Size for Test" )
34+ parser .add_argument ('--num-workers' , default = 2 , type = int , help = 'num-workers' )
35+ parser .add_argument ('--net' , type = str , default = 'MobileNetv1' , help = 'net type' )
36+ args = parser .parse_args ()
37+ testloader = get_test_dataloader ()
38+
39+ # Model
40+ print ('==> Building model..' )
41+ if args .net == 'VGG16' :
42+ from nets .VGG import VGG
43+ net = VGG ('VGG16' )
44+ elif args .net == 'VGG19' :
45+ from nets .VGG import VGG
46+ net = VGG ('VGG19' )
47+ elif args .net == 'ResNet18' :
48+ from nets .ResNet import ResNet18
49+ net = ResNet18 ()
50+ elif args .net == 'ResNet34' :
51+ from nets .ResNet import ResNet34
52+ net = ResNet34 ()
53+ elif args .net == 'LeNet' :
54+ from nets .LeNet5 import LeNet5
55+ net = LeNet5 ()
56+ elif args .net == 'AlexNet' :
57+ from nets .AlexNet import AlexNet
58+ net = AlexNet ()
59+ elif args .net == 'DenseNet' :
60+ from nets .DenseNet import densenet_cifar
61+ net = densenet_cifar ()
62+ elif args .net == 'MobileNetv1' :
63+ from nets .MobileNetv1 import MobileNet
64+ net = MobileNet ()
65+ elif args .net == 'MobileNetv2' :
66+ from nets .MobileNetv2 import MobileNetV2
67+ net = MobileNetV2 ()
68+
69+ if args .cuda and torch .cuda .is_available ():
70+ device = 'cuda'
71+ net = torch .nn .DataParallel (net )
72+ cudnn .benchmark = True
73+ else :
74+ device = 'cpu'
75+
76+ criterion = nn .CrossEntropyLoss ()
77+
78+ # Load checkpoint.
79+ print ('==> Resuming from checkpoint..' )
80+ assert os .path .isdir ('checkpoint' ), 'Error: no checkpoint directory found!'
81+ checkpoint = torch .load ('./checkpoint/{}_ckpt.pth' .format (args .net ))
82+ net .load_state_dict (checkpoint ['net' ])
83+
84+
85+ epoch_step_test = len (testloader )
86+ if epoch_step_test == 0 :
87+ raise ValueError ("数据集过小,无法进行训练,请扩充数据集,或者减小batchsize" )
88+
89+ net .eval ()
90+ test_acc_top1 = 0
91+ test_acc_top5 = 0
92+ print ('Start Test' )
93+ #--------------------------------
94+ # 相同方法,同train
95+ #--------------------------------
96+ with tqdm (total = epoch_step_test ,desc = f'Test Acc' ,postfix = dict ,mininterval = 0.3 ) as pbar2 :
97+ for step ,(im ,label ) in enumerate (testloader ,start = 0 ):
98+ im = im .to (device )
99+ label = label .to (device )
100+ with torch .no_grad ():
101+ if step >= epoch_step_test :
102+ break
103+
104+ # 释放内存
105+ if hasattr (torch .cuda , 'empty_cache' ):
106+ torch .cuda .empty_cache ()
107+ #----------------------#
108+ # 前向传播
109+ #----------------------#
110+ outputs = net (im )
111+ loss = criterion (outputs ,label )
112+ test_acc_top1 += eval_top1 (outputs ,label )
113+ test_acc_top5 += eval_top5 (outputs ,label )
114+ pbar2 .set_postfix (** {'Test Acc Top1' : test_acc_top1 .item ()/ (step + 1 ),
115+ 'Test Acc Top5' : test_acc_top5 / (step + 1 )})
116+ pbar2 .update (1 )
117+
118+ top1 = test_acc_top1 .item ()/ len (testloader )
119+ top5 = test_acc_top5 / len (testloader )
120+ print ("top-1 accuracy = %.2f%%" % (top1 * 100 ))
121+ print ("top-5 accuracy = %.2f%%" % (top5 * 100 ))
0 commit comments