Skip to content

Commit c5d0f74

Browse files
committed
fix bug about batch_count (#60)
1 parent a978169 commit c5d0f74

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

code/d2lzh_pytorch/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,8 @@ def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epo
230230
net = net.to(device)
231231
print("training on ", device)
232232
loss = torch.nn.CrossEntropyLoss()
233-
batch_count = 0
234233
for epoch in range(num_epochs):
235-
train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
234+
train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
236235
for X, y in train_iter:
237236
X = X.to(device)
238237
y = y.to(device)

docs/chapter05_CNN/5.5_lenet.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,8 @@ def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epo
131131
net = net.to(device)
132132
print("training on ", device)
133133
loss = torch.nn.CrossEntropyLoss()
134-
batch_count = 0
135134
for epoch in range(num_epochs):
136-
train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
135+
train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
137136
for X, y in train_iter:
138137
X = X.to(device)
139138
y = y.to(device)

0 commit comments

Comments
 (0)