@@ -51,14 +51,14 @@ def compute_norms(sample_grads):
     batch_size = sample_grads[0].shape[0]
     norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
     norms = torch.stack(norms, dim=0).norm(2, dim=0)
-    return norms
+    return norms, batch_size


 def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
     sample_grads = tuple(param.grad_sample for param in model.parameters())

     # step 0: compute the norms
-    sample_norms = compute_norms(sample_grads)
+    sample_norms, batch_size = compute_norms(sample_grads)

     # step 1: compute clipping factors
     clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
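For reference, a minimal standalone sketch of what compute_norms does, using made-up per-sample gradient tensors (the shapes here are illustrative, not taken from the real model):

import torch

batch_size = 3
sample_grads = (
    torch.randn(batch_size, 4, 5),  # e.g. per-sample grads of a weight matrix
    torch.randn(batch_size, 4),     # e.g. per-sample grads of a bias vector
)

# Per parameter: flatten each sample's gradient and take its L2 norm.
per_param_norms = [g.view(batch_size, -1).norm(2, dim=-1) for g in sample_grads]

# Across parameters: combine the per-parameter norms into one L2 norm per sample.
per_sample_norms = torch.stack(per_param_norms, dim=0).norm(2, dim=0)

print(per_sample_norms.shape)  # torch.Size([3]) -- one norm per sample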
@@ -76,7 +76,7 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise

     # step 4: assign the new grads, delete the sample grads
     for param, param_grad in zip(model.parameters(), grads):
-        param.grad = param_grad
+        param.grad = param_grad / batch_size
         del param.grad_sample

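The divide-by-batch_size in step 4 turns a batch-summed gradient into a per-sample mean, assuming the omitted steps 2-3 accumulate a sum of clipped per-sample gradients (plus noise); a mean is the scale an optimizer expects when the loss is computed with reduction="mean". A toy check of that rescaling, with assumed shapes:

import torch

batch_size = 3
clipped_sample_grads = torch.randn(batch_size, 4, 5)  # stand-in per-sample grads

summed = clipped_sample_grads.sum(dim=0)  # what summing over the batch yields
mean = clipped_sample_grads.mean(dim=0)   # what the optimizer should consume

# `param.grad = param_grad / batch_size` is exactly this sum-to-mean rescaling.
assert torch.allclose(summed / batch_size, mean)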
@@ -492,4 +492,4 @@ def parse_args():


 if __name__ == "__main__":
-    main()
+    main()