Skip to content

Commit d948d37

Browse files
authored
Fix loss logging
1 parent 70c0eb9 commit d948d37

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ def main():
148148
for (img, mask) in tqdm(data_loader):
149149
with accelerator.accumulate(model):
150150
loss = diffusion(mask, img)
151+
running_loss += loss.item() * img.size(0)
151152
accelerator.log({'loss': loss}) # Log loss to wandb
152153
accelerator.backward(loss)
153154
optimizer.step()
154155
optimizer.zero_grad()
155-
running_loss += loss.item() * img.size(0)
156156
counter += 1
157157
epoch_loss = running_loss / len(data_loader)
158158
print('Training Loss : {:.4f}'.format(epoch_loss))

0 commit comments

Comments
 (0)