
Commit 9e58d8a

awaelchli and Rohit Gupta authored and committed
Update warnings in TrainingTricksConnector (#9595)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 22bd118 · commit 9e58d8a

5 files changed: +55 -10 lines

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 18 additions & 8 deletions
@@ -24,25 +24,35 @@ def __init__(self, trainer):
 
     def on_trainer_init(
         self,
-        gradient_clip_val: float,
+        gradient_clip_val: Union[int, float],
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
         accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
         truncated_bptt_steps: Optional[int],
         terminate_on_nan: bool,
     ):
-
-        self.trainer.terminate_on_nan = terminate_on_nan
+        if not isinstance(terminate_on_nan, bool):
+            raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
-        if gradient_clip_algorithm not in list(GradClipAlgorithmType):
-            raise MisconfigurationException(f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}")
-        self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
+        if not isinstance(gradient_clip_val, (int, float)):
+            raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
+
+        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+            raise MisconfigurationException(
+                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
+                f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
+            )
 
         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != "inf":
-            raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
+            raise MisconfigurationException(
+                f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}."
+            )
+
+        self.trainer.terminate_on_nan = terminate_on_nan
+        self.trainer.gradient_clip_val = gradient_clip_val
+        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
         self.trainer.track_grad_norm = float(track_grad_norm)
 
         # accumulated grads
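
For context, a minimal usage sketch (not part of the diff) of how these checks surface at `Trainer` construction time: valid values pass through unchanged, while invalid ones now fail fast with an explicit error.

from pytorch_lightning import Trainer

# Valid configuration: every argument below is validated by on_trainer_init().
trainer = Trainer(
    gradient_clip_val=0.5,           # int or float accepted after this change
    gradient_clip_algorithm="NORM",  # lowered before the enum lookup, so case does not matter
    track_grad_norm=2,               # int, float, or "inf"
    terminate_on_nan=True,           # must be a bool
)

# Invalid configuration: a tuple is neither an int nor a float.
try:
    Trainer(gradient_clip_val=(1, 2))
except TypeError as err:
    print(err)  # `gradient_clip_val` should be an int or a float. Got (1, 2).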

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def __init__(
         checkpoint_callback: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: float = 0.0,
+        gradient_clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: str = "norm",
         process_position: int = 0,
         num_nodes: int = 1,
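
A note on this signature change: an integer such as `gradient_clip_val=1` would generally have worked at runtime already, since the old code never checked the type; widening the annotation from `float` to `Union[int, float]` makes that explicit and matches the new isinstance check in the connector. A minimal sketch, assuming the public `Trainer` import:

from pytorch_lightning import Trainer

trainer = Trainer(gradient_clip_val=1)    # now consistent with the Union[int, float] annotation
trainer = Trainer(gradient_clip_val=0.5)  # floats keep working exactly as before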

pytorch_lightning/utilities/enums.py

Lines changed: 8 additions & 0 deletions
@@ -118,6 +118,14 @@ class GradClipAlgorithmType(LightningEnum):
     VALUE = "value"
     NORM = "norm"
 
+    @staticmethod
+    def supported_type(val: str) -> bool:
+        return any(x.value == val for x in GradClipAlgorithmType)
+
+    @staticmethod
+    def supported_types() -> List[str]:
+        return [x.value for x in GradClipAlgorithmType]
+
 
 class AutoRestartBatchKeys(LightningEnum):
     """

tests/trainer/test_trainer.py

Lines changed: 20 additions & 0 deletions
@@ -800,6 +800,16 @@ def training_step(self, batch, batch_idx):
         assert torch.isfinite(param).all()
 
 
+def test_invalid_terminate_on_nan(tmpdir):
+    with pytest.raises(TypeError, match="`terminate_on_nan` should be a bool"):
+        Trainer(default_root_dir=tmpdir, terminate_on_nan="False")
+
+
+def test_invalid_track_grad_norm(tmpdir):
+    with pytest.raises(MisconfigurationException, match="`track_grad_norm` should be an int, a float"):
+        Trainer(default_root_dir=tmpdir, track_grad_norm="nan")
+
+
 def test_nan_params_detection(tmpdir):
     class CurrentModel(BoringModel):
         test_batch_nan = 3
@@ -1005,6 +1015,16 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
     trainer.fit(model)
 
 
+def test_invalid_gradient_clip_value(tmpdir):
+    with pytest.raises(TypeError, match="`gradient_clip_val` should be an int or a float"):
+        Trainer(default_root_dir=tmpdir, gradient_clip_val=(1, 2))
+
+
+def test_invalid_gradient_clip_algo(tmpdir):
+    with pytest.raises(MisconfigurationException, match="`gradient_clip_algorithm` norm2 is invalid"):
+        Trainer(default_root_dir=tmpdir, gradient_clip_algorithm="norm2")
+
+
 def test_gpu_choice(tmpdir):
     trainer_options = dict(default_root_dir=tmpdir)
     # Only run if CUDA is available
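
To run just the validation tests added above, a minimal sketch using pytest's Python entry point (the `-k` expression selects the new tests by name substring):

import pytest

# Run only the newly added Trainer-argument validation tests.
pytest.main([
    "tests/trainer/test_trainer.py",
    "-k", "invalid_terminate_on_nan or invalid_track_grad_norm or invalid_gradient_clip",
])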

tests/utilities/test_enums.py

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities.enums import DeviceType, GradClipAlgorithmType
 
 
 def test_consistency():
@@ -9,3 +9,10 @@ def test_consistency():
     # hash cannot be case invariant
     assert DeviceType.TPU not in {"TPU", "CPU"}
     assert DeviceType.TPU in {"tpu", "CPU"}
+
+
+def test_gradient_clip_algorithms():
+    assert GradClipAlgorithmType.supported_types() == ["value", "norm"]
+    assert GradClipAlgorithmType.supported_type("norm")
+    assert GradClipAlgorithmType.supported_type("value")
+    assert not GradClipAlgorithmType.supported_type("norm2")
