Skip to content

Commit ac639af

Browse files
EnayatUllahfacebook-github-bot
authored andcommitted
Fixed Opacus's Runtime error with an empty batch (issue 612) (#631)
Summary: Pull Request resolved: #631 In case of an empty batch, in the ```clip_and_accumulate``` function, the ```per_sample_clip_factor``` variable is set to a tensor of size 0. However, the device was not specified, which throws a runtime error. Added it. Reviewed By: HuanyuZhang Differential Revision: D53733081 fbshipit-source-id: 9435d4dc1a7f37852bd52d2507c37b7ca1ef11a9
1 parent a7c2853 commit ac639af

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

opacus/optimizers/optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,9 @@ def clip_and_accumulate(self):
396396

397397
if len(self.grad_samples[0]) == 0:
398398
# Empty batch
399-
per_sample_clip_factor = torch.zeros((0,))
399+
per_sample_clip_factor = torch.zeros(
400+
(0,), device=self.grad_samples[0].device
401+
)
400402
else:
401403
per_param_norms = [
402404
g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples

0 commit comments

Comments
 (0)