Skip to content

Commit 6f1015d

Browse files
authored
Merge pull request #17 from desi-ivanov/patch-1
Fix loss logging
2 parents 70c0eb9 + d948d37 commit 6f1015d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ def main():
148148
for (img, mask) in tqdm(data_loader):
149149
with accelerator.accumulate(model):
150150
loss = diffusion(mask, img)
151+
running_loss += loss.item() * img.size(0)
151152
accelerator.log({'loss': loss}) # Log loss to wandb
152153
accelerator.backward(loss)
153154
optimizer.step()
154155
optimizer.zero_grad()
155-
running_loss += loss.item() * img.size(0)
156156
counter += 1
157157
epoch_loss = running_loss / len(data_loader)
158158
print('Training Loss : {:.4f}'.format(epoch_loss))

0 commit comments

Comments
 (0)