Skip to content

Commit 2b16530

Browse files
authored
fix dp-sgd example (#873)
see [this issue](#467)
1 parent 056ff1f commit 2b16530

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/dp_cifar10/cifar10_transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ def compute_norms(sample_grads):
5151
batch_size = sample_grads[0].shape[0]
5252
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
5353
norms = torch.stack(norms, dim=0).norm(2, dim=0)
54-
return norms
54+
return norms, batch_size
5555

5656

5757
def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
5858
sample_grads = tuple(param.grad_sample for param in model.parameters())
5959

6060
# step 0: compute the norms
61-
sample_norms = compute_norms(sample_grads)
61+
sample_norms, batch_size = compute_norms(sample_grads)
6262

6363
# step 1: compute clipping factors
6464
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
@@ -76,7 +76,7 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise
7676

7777
# step 4: assign the new grads, delete the sample grads
7878
for param, param_grad in zip(model.parameters(), grads):
79-
param.grad = param_grad
79+
param.grad = param_grad/batch_size
8080
del param.grad_sample
8181

8282

@@ -492,4 +492,4 @@ def parse_args():
492492

493493

494494
if __name__ == "__main__":
495-
main()
495+
main()

0 commit comments

Comments
 (0)