We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 70c0eb9 + d948d37 commit 6f1015dCopy full SHA for 6f1015d
driver.py
@@ -148,11 +148,11 @@ def main():
148
for (img, mask) in tqdm(data_loader):
149
with accelerator.accumulate(model):
150
loss = diffusion(mask, img)
151
+ running_loss += loss.item() * img.size(0)
152
accelerator.log({'loss': loss}) # Log loss to wandb
153
accelerator.backward(loss)
154
optimizer.step()
155
optimizer.zero_grad()
- running_loss += loss.item() * img.size(0)
156
counter += 1
157
epoch_loss = running_loss / len(data_loader)
158
print('Training Loss : {:.4f}'.format(epoch_loss))
0 commit comments