Skip to content

Commit b7c6162

Browse files
authored
Merge pull request #6 from dsbuddy/accelerated_train
Slight error in accelerator
2 parents 85538c7 + e58f12c commit b7c6162

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def main():
128128
with accelerator.accumulate(model):
129129
loss = diffusion(mask, img)
130130
accelerator.log({'loss': loss}) # Log loss to wandb
131-
loss.backward()
131+
accelerator.backward(loss)
132132
optimizer.step()
133133
optimizer.zero_grad()
134134
running_loss += loss.item() * img.size(0)

0 commit comments

Comments
 (0)