We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 70c0eb9 commit d948d37Copy full SHA for d948d37
driver.py
@@ -148,11 +148,11 @@ def main():
148
for (img, mask) in tqdm(data_loader):
149
with accelerator.accumulate(model):
150
loss = diffusion(mask, img)
151
+ running_loss += loss.item() * img.size(0)
152
accelerator.log({'loss': loss}) # Log loss to wandb
153
accelerator.backward(loss)
154
optimizer.step()
155
optimizer.zero_grad()
- running_loss += loss.item() * img.size(0)
156
counter += 1
157
epoch_loss = running_loss / len(data_loader)
158
print('Training Loss : {:.4f}'.format(epoch_loss))
0 commit comments