1+
2+ import os
3+ import torch
4+ import torch .nn as nn
5+ import argparse
6+ import numpy as np
7+ from tqdm import tqdm
8+ from dataloader import get_test_dataloader
9+ import torch .backends .cudnn as cudnn
10+
11+ def eval_top1 (outputs , label ):
12+ total = outputs .shape [0 ]
13+ outputs = torch .softmax (outputs , dim = - 1 )
14+ _ , pred_y = outputs .data .max (dim = 1 ) # 得到概率
15+ correct = (pred_y == label ).sum ().data
16+ return correct / total
17+
18+ def eval_top5 (outputs , label ):
19+ total = outputs .shape [0 ]
20+ outputs = torch .softmax (outputs , dim = - 1 )
21+ pred_y = np .argsort (- outputs .cpu ().numpy ())
22+ pred_y_top5 = pred_y [:,:5 ]
23+ correct = 0
24+ for i in range (total ):
25+ if label [i ].cpu ().numpy () in pred_y_top5 [i ]:
26+ correct += 1
27+ return correct / total
28+
29+
30+ if __name__ == '__main__' :
31+ parser = argparse .ArgumentParser (description = 'PyTorch CIFAR10 Test' )
32+ parser .add_argument ('--cuda' , action = 'store_true' , default = False , help = ' use GPU?' )
33+ parser .add_argument ('--batch-size' , default = 64 , type = int , help = "Batch Size for Test" )
34+ parser .add_argument ('--num-workers' , default = 2 , type = int , help = 'num-workers' )
35+ parser .add_argument ('--net' , type = str , choices = ['LeNet5' , 'AlexNet' , 'VGG16' ,'VGG19' ,'ResNet18' ,'ResNet34' ,
36+ 'DenseNet' ,'MobileNetv1' ,'MobileNetv2' ,'ResNeXt' ,'ConvNeXt-T' ,'ConvNeXt-S' ,'ConvNeXt-B' ], default = 'MobileNetv1' , help = 'net type' )
37+ parser .add_argument ('--resize' ,type = int ,default = 32 )
38+ args = parser .parse_args ()
39+ testloader = get_test_dataloader (batch_size = args .batch_size , num_workers = args .num_workers , shuffle = False , resize = args .resize )
40+
41+ # Model
42+ print ('==> Building model..' )
43+ if args .net == 'VGG16' :
44+ from nets .VGG import VGG
45+ net = VGG ('VGG16' )
46+ elif args .net == 'VGG19' :
47+ from nets .VGG import VGG
48+ net = VGG ('VGG19' )
49+ elif args .net == 'ResNet18' :
50+ from nets .ResNet import ResNet18
51+ net = ResNet18 ()
52+ elif args .net == 'ResNet34' :
53+ from nets .ResNet import ResNet34
54+ net = ResNet34 ()
55+ elif args .net == 'LeNet5' :
56+ from nets .LeNet5 import LeNet5
57+ net = LeNet5 ()
58+ elif args .net == 'AlexNet' :
59+ from nets .AlexNet import AlexNet
60+ net = AlexNet ()
61+ elif args .net == 'DenseNet' :
62+ from nets .DenseNet import densenet_cifar
63+ net = densenet_cifar ()
64+ elif args .net == 'MobileNetv1' :
65+ from nets .MobileNetv1 import MobileNet
66+ net = MobileNet ()
67+ elif args .net == 'MobileNetv2' :
68+ from nets .MobileNetv2 import MobileNetV2
69+ net = MobileNetV2 ()
70+ elif args .net == 'ResNeXt' :
71+ from nets .ResNeXt import ResNeXt50
72+ net = ResNeXt50 (10 )
73+ elif args .net == 'ConvNeXt-T' :
74+ from nets .ConvNeXt import convnext_tiny
75+ net = convnext_tiny (10 )
76+ elif args .net == 'ConvNeXt-S' :
77+ from nets .ConvNeXt import convnext_small
78+ net = convnext_small (10 )
79+ elif args .net == 'ConvNeXt-B' :
80+ from nets .ConvNeXt import convnext_base
81+ net = convnext_base (10 )
82+
83+ from torchinfo import summary
84+ summary (net ,(2 ,3 ,224 ,224 ))
85+
86+ criterion = nn .CrossEntropyLoss ()
87+
88+ # Load checkpoint.
89+ print ('==> Resuming from checkpoint..' )
90+ assert os .path .isdir ('checkpoint' ), 'Error: no checkpoint directory found!'
91+ checkpoint = torch .load ('./checkpoint/{}_ckpt.pth' .format (args .net ))
92+
93+ weights_dict = {}
94+ for k , v in checkpoint ['net' ].items ():
95+ new_k = k .replace ('module.' , '' ) if 'module' in k else k
96+ weights_dict [new_k ] = v
97+
98+ net .load_state_dict (weights_dict )
99+
100+ if args .cuda and torch .cuda .is_available ():
101+ device = 'cuda'
102+ net = torch .nn .DataParallel (net )
103+ cudnn .benchmark = True
104+ else :
105+ device = 'cpu'
106+
107+ epoch_step_test = len (testloader )
108+ if epoch_step_test == 0 :
109+ raise ValueError ("数据集过小,无法进行训练,请扩充数据集,或者减小batchsize" )
110+
111+ net .eval ()
112+ test_acc_top1 = 0
113+ test_acc_top5 = 0
114+ print ('Start Test' )
115+ #--------------------------------
116+ # 相同方法,同train
117+ #--------------------------------
118+ with tqdm (total = epoch_step_test ,desc = f'Test Acc' ,postfix = dict ,mininterval = 0.3 ) as pbar2 :
119+ for step ,(im ,label ) in enumerate (testloader ,start = 0 ):
120+ im = im .to (device )
121+ label = label .to (device )
122+ with torch .no_grad ():
123+ if step >= epoch_step_test :
124+ break
125+
126+ # 释放内存
127+ if hasattr (torch .cuda , 'empty_cache' ):
128+ torch .cuda .empty_cache ()
129+ #----------------------#
130+ # 前向传播
131+ #----------------------#
132+ outputs = net (im )
133+ loss = criterion (outputs ,label )
134+ test_acc_top1 += eval_top1 (outputs ,label )
135+ test_acc_top5 += eval_top5 (outputs ,label )
136+ pbar2 .set_postfix (** {'Test Acc Top1' : test_acc_top1 .item ()/ (step + 1 ),
137+ 'Test Acc Top5' : test_acc_top5 / (step + 1 )})
138+ pbar2 .update (1 )
139+
140+ top1 = test_acc_top1 .item ()/ len (testloader )
141+ top5 = test_acc_top5 / len (testloader )
142+ print ("top-1 accuracy = %.2f%%" % (top1 * 100 ))
143+ print ("top-5 accuracy = %.2f%%" % (top5 * 100 ))
0 commit comments