Commit 8dd9b80

Fix gradient clipping (#1438)
* Fix gradient clipping
* Relax accuracy constraint
1 parent b2707c9 commit 8dd9b80

File tree

2 files changed: +28 -1 lines changed


pytorch_lightning/trainer/training_tricks.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def clip_gradients(self):
         total_norm = torch.zeros([], device=device if parameters else None)
         for p in parameters:
             param_norm = p.grad.data.norm(norm_type) ** norm_type
-        total_norm.add_(param_norm)
+            total_norm.add_(param_norm)
         total_norm = (total_norm ** (1. / norm_type))
         eps = EPSILON_FP16 if self.precision == 16 else EPSILON
         clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
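
The clipping scheme this hunk touches computes the global 2-norm of all gradients and rescales them by `max_norm / (total_norm + eps)`. Below is a minimal stand-alone sketch of that scheme in plain PyTorch; the function name, the `EPSILON` value, and the final `mul_` rescaling step are illustrative assumptions, not the library's exact code.

```python
import torch

EPSILON = 1e-6  # illustrative; the library defines EPSILON / EPSILON_FP16 elsewhere

def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0, eps=EPSILON):
    """Scale gradients in place so their total norm does not exceed max_norm (sketch)."""
    parameters = [p for p in parameters if p.grad is not None]
    device = parameters[0].grad.device if parameters else None
    total_norm = torch.zeros([], device=device)
    for p in parameters:
        param_norm = p.grad.detach().norm(norm_type) ** norm_type
        total_norm.add_(param_norm)          # accumulate every parameter's contribution
    total_norm = total_norm ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:                        # only shrink, never enlarge gradients
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm
```

Accumulating `param_norm` inside the loop is what makes `total_norm` cover every parameter; accumulating only once after the loop would count just the last parameter's gradient and under-clip the rest.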

tests/trainer/test_trainer.py

Lines changed: 27 additions & 0 deletions
@@ -658,3 +658,30 @@ def on_batch_start(self, trainer, pl_module):
     assert not trainer.interrupted
     trainer.fit(model)
     assert trainer.interrupted
+
+
+def test_gradient_clipping(tmpdir):
+    """
+    Test gradient clipping
+    """
+    tutils.reset_seed()
+
+    hparams = tutils.get_default_hparams()
+    model = LightningTestModel(hparams)
+
+    # test that gradient is clipped correctly
+    def _optimizer_step(*args, **kwargs):
+        parameters = model.parameters()
+        grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+        assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
+
+    trainer = Trainer(max_steps=1,
+                      max_epochs=1,
+                      gradient_clip_val=1.0,
+                      default_save_path=tmpdir)
+
+    # for the test
+    model.optimizer_step = _optimizer_step
+    model.prev_called_batch_idx = 0
+
+    trainer.fit(model)
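
For reference, the assertion made inside `_optimizer_step` can be reproduced outside the Trainer with plain PyTorch's `torch.nn.utils.clip_grad_norm_`; the toy model, data, and loss scaling below are illustrative assumptions, not part of the test suite.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
x, y = torch.randn(64, 10), torch.randn(64, 1)
loss = torch.nn.functional.mse_loss(model(x) * 100, y)  # inflate the loss so gradients exceed the clip value
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# same check as the test: total 2-norm of all gradients should land at ~1.0 after clipping
grad_norm = torch.norm(
    torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
)
assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
```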
