@@ -144,4 +144,55 @@ def evaluation(epoch, epochs, model, dataloader, criterion):
144144 pbar .update (1 )
145145
146146 test_loss , test_accuracy = test_loss / eval_step , test_accuracy / eval_step
147- return test_loss , test_accuracy
147+ return test_loss , test_accuracy
148+
149+ def test (model , dataloader ):
150+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
151+ correct = 0 # 定义预测正确的图片数,初始化为0
152+ total = 0 # 总共参与测试的图片数,也初始化为0
153+ model .eval ()
154+ with torch .no_grad ():
155+ for data in dataloader : # 循环每一个batch
156+ images , labels = data
157+ images = images .to (device )
158+ labels = labels .to (device )
159+ model .eval () # 把模型转为test模式
160+ if hasattr (torch .cuda , 'empty_cache' ):
161+ torch .cuda .empty_cache ()
162+ outputs = model (images ) # 输入网络进行测试
163+
164+ # outputs.data是一个4x10张量,将每一行的最大的那一列的值和序号各自组成一个一维张量返回,第一个是值的张量,第二个是序号的张量。
165+ _ , predicted = torch .max (outputs .data , 1 )
166+ total += labels .size (0 ) # 更新测试图片的数量
167+ correct += (predicted == labels ).sum () # 更新正确分类的图片的数量
168+
169+ print ('Accuracy of the network on the %d test images: %.2f %%' % (total , 100 * correct / total ))
170+
171+
172+ def test_precls (model , dataloader , classes ):
173+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
174+ # 定义2个存储每类中测试正确的个数的 列表,初始化为0
175+ class_correct = list (0. for i in range (10 ))
176+ class_total = list (0. for i in range (10 ))
177+ # testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=True, num_workers=2)
178+ model .eval ()
179+ with torch .no_grad ():
180+ for data in dataloader :
181+ images , labels = data
182+ images = images .to (device )
183+ labels = labels .to (device )
184+ if hasattr (torch .cuda , 'empty_cache' ):
185+ torch .cuda .empty_cache ()
186+ outputs = model (images )
187+
188+ _ , predicted = torch .max (outputs .data , 1 )
189+
190+ c = (predicted == labels ).squeeze ()
191+ for i in range (len (images )): # 因为每个batch都有4张图片,所以还需要一个4的小循环
192+ label = labels [i ] # 对各个类的进行各自累加
193+ class_correct [label ] += c [i ]
194+ class_total [label ] += 1
195+
196+
197+ for i in range (10 ):
198+ print ('Accuracy of %5s : %.2f %%' % (classes [i ], 100 * class_correct [i ] / class_total [i ]))
0 commit comments