Skip to content

Commit 52615ae

Browse files
committed
Update utils2.py
1 parent 798cba0 commit 52615ae

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

CIFAR10_code/utils2.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,55 @@ def evaluation(epoch, epochs, model, dataloader, criterion):
144144
pbar.update(1)
145145

146146
test_loss, test_accuracy = test_loss/eval_step, test_accuracy/eval_step
147-
return test_loss, test_accuracy
147+
return test_loss, test_accuracy
148+
149+
def test(model, dataloader):
150+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
151+
correct = 0 # 定义预测正确的图片数,初始化为0
152+
total = 0 # 总共参与测试的图片数,也初始化为0
153+
model.eval()
154+
with torch.no_grad():
155+
for data in dataloader: # 循环每一个batch
156+
images, labels = data
157+
images = images.to(device)
158+
labels = labels.to(device)
159+
model.eval() # 把模型转为test模式
160+
if hasattr(torch.cuda, 'empty_cache'):
161+
torch.cuda.empty_cache()
162+
outputs = model(images) # 输入网络进行测试
163+
164+
# outputs.data是一个4x10张量,将每一行的最大的那一列的值和序号各自组成一个一维张量返回,第一个是值的张量,第二个是序号的张量。
165+
_, predicted = torch.max(outputs.data, 1)
166+
total += labels.size(0) # 更新测试图片的数量
167+
correct += (predicted == labels).sum() # 更新正确分类的图片的数量
168+
169+
print('Accuracy of the network on the %d test images: %.2f %%' % (total, 100 * correct / total))
170+
171+
172+
def test_precls(model, dataloader, classes):
173+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
174+
# 定义2个存储每类中测试正确的个数的 列表,初始化为0
175+
class_correct = list(0. for i in range(10))
176+
class_total = list(0. for i in range(10))
177+
# testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=True, num_workers=2)
178+
model.eval()
179+
with torch.no_grad():
180+
for data in dataloader:
181+
images, labels = data
182+
images = images.to(device)
183+
labels = labels.to(device)
184+
if hasattr(torch.cuda, 'empty_cache'):
185+
torch.cuda.empty_cache()
186+
outputs = model(images)
187+
188+
_, predicted = torch.max(outputs.data, 1)
189+
190+
c = (predicted == labels).squeeze()
191+
for i in range(len(images)): # 因为每个batch都有4张图片,所以还需要一个4的小循环
192+
label = labels[i] # 对各个类的进行各自累加
193+
class_correct[label] += c[i]
194+
class_total[label] += 1
195+
196+
197+
for i in range(10):
198+
print('Accuracy of %5s : %.2f %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

0 commit comments

Comments
 (0)