Skip to content

Commit 798cba0

Browse files
committed
Update utils2.py
1 parent 79bbb1a commit 798cba0

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

CIFAR10_code/utils2.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ def train(epoch, epochs, model, dataloader, criterion, optimizer, scheduler = No
8989

9090

9191
acc = get_acc(output,target)
92-
running_accuracy += acc
93-
running_loss += loss.data
92+
running_accuracy += acc.item()
93+
running_loss += loss.item()
9494
loss.backward()
9595
optimizer.step()
9696

97-
pbar.set_postfix(**{'Train Acc' : running_accuracy.item()/(step+1),
98-
'Train Loss' :running_loss.item()/(step+1)})
97+
pbar.set_postfix(**{'Train Acc' : running_accuracy/(step+1),
98+
'Train Loss' :running_loss/(step+1)})
9999
pbar.update(1)
100100
if scheduler:
101101
scheduler.step(running_loss)
@@ -136,10 +136,10 @@ def evaluation(epoch, epochs, model, dataloader, criterion):
136136
loss = criterion(output, target)
137137
acc = get_acc(output,target)
138138

139-
test_accuracy += acc
139+
test_accuracy += acc.item()
140140
test_loss += loss.item()
141141

142-
pbar.set_postfix(**{'Eval Acc' : test_accuracy.item()/(step+1),
142+
pbar.set_postfix(**{'Eval Acc' : test_accuracy/(step+1),
143143
'Eval Loss' :test_loss/(step+1)})
144144
pbar.update(1)
145145

0 commit comments

Comments
 (0)