
Commit 799ac57

manuelknott authored and facebook-github-bot committed
Configurable GradScaler Parameters (#1038)
Summary: TorchTNT's `AutoUnit` initializes a GradScaler when `precision == torch.float16`, using the default parameters of `torch.amp.grad_scaler.GradScaler` or `torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler`, respectively. Some projects require custom arguments for their GradScaler, so I propose to expose these parameters and make them configurable, implemented as a dataclass analogous to `SWALRParams` and `SWAParams`. Also added integration tests in `test_precision.py`. Differential Revision: D86090032
1 parent b83990d commit 799ac57
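The sketch below illustrates the intended usage described in the summary; it is not part of the commit. It assumes a user-defined `MyAutoUnit` subclass of `AutoUnit` (implementing the usual abstract methods) and a `my_model` module, and that passing `precision="fp16"` selects float16 mixed precision so a GradScaler is created. Only `GradScalerParams` and the `grad_scaler_params` argument come from this diff.

from torchtnt.framework.auto_unit import GradScalerParams

# Customize gradient scaling instead of relying on GradScaler defaults.
scaler_params = GradScalerParams(
    init_scale=2.0**14,   # start from a smaller scale than the 2**16 default
    growth_factor=1.5,    # grow the scale more conservatively
    backoff_factor=0.5,
    growth_interval=1000,
)

# MyAutoUnit / my_model are hypothetical; grad_scaler_params is forwarded to the
# GradScaler (or ShardedGradScaler for FSDP1 modules) created for fp16 precision.
unit = MyAutoUnit(
    module=my_model,
    precision="fp16",
    grad_scaler_params=scaler_params,
)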

File tree

2 files changed: +48 -5 lines changed


torchtnt/framework/auto_unit.py

Lines changed: 30 additions & 1 deletion
@@ -11,7 +11,7 @@
 import logging
 from abc import ABCMeta, abstractmethod
 from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import (
     Any,
     Callable,
@@ -85,6 +85,29 @@ class SWALRParams:
     swa_lrs: Union[List[float], float] = 0.05
 
 
+@dataclass
+class GradScalerParams:
+    """
+    Dataclass to store parameters for gradient scaling.
+
+    Args:
+        init_scale: Initial scale factor. Default: 2.**16 (65536)
+        growth_factor: Factor by which the scale is multiplied during update
+            if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+            Default: 2.0
+        backoff_factor: Factor by which the scale is multiplied during update
+            if inf/NaN gradients occur in an iteration. Default: 0.5
+        growth_interval: Number of consecutive iterations without inf/NaN gradients
+            that must occur for the scale to be multiplied by ``growth_factor``.
+            Default: 2000
+    """
+
+    init_scale: float = 2.0**16
+    growth_factor: float = 2.0
+    backoff_factor: float = 0.5
+    growth_interval: int = 2000
+
+
 @dataclass
 class SWAParams:
     """
@@ -517,6 +540,7 @@ def __init__(
         zero_grad_at_train_step_start: bool = False,
         global_mesh: Optional[GlobalMeshCoordinator] = None,
         enable_loss_parallel: bool = False,
+        grad_scaler_params: Optional[GradScalerParams] = None,
     ) -> None:
         super().__init__(
             module=module,
@@ -565,10 +589,15 @@ def __init__(
         )
 
         self.grad_scaler: Optional[GradScaler] = None
+        if grad_scaler_params is not None:
+            grad_scaler_kwargs = asdict(grad_scaler_params)
+        else:
+            grad_scaler_kwargs = None
         if self.precision:
             self.grad_scaler = get_grad_scaler_from_precision(
                 self.precision,
                 is_fsdp1_module=_is_fsdp1_module(self.module),
+                grad_scaler_kwargs=grad_scaler_kwargs,
             )
 
         self.step_lr_interval = step_lr_interval
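As the last hunk above shows, the dataclass is flattened into keyword arguments with `asdict` before being forwarded to `get_grad_scaler_from_precision`. A small standalone illustration of that conversion (nothing here beyond `GradScalerParams` and the standard library is introduced by the commit):

from dataclasses import asdict

from torchtnt.framework.auto_unit import GradScalerParams

# Fields not set explicitly keep their defaults.
params = GradScalerParams(growth_interval=500)
kwargs = asdict(params)
print(kwargs)
# {'init_scale': 65536.0, 'growth_factor': 2.0, 'backoff_factor': 0.5, 'growth_interval': 500}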

torchtnt/utils/precision.py

Lines changed: 18 additions & 4 deletions
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Mapping, Optional
+from typing import Any, Mapping, Optional
 
 import torch
 from torch.amp.grad_scaler import GradScaler
@@ -38,7 +38,10 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:
 
 
 def get_grad_scaler_from_precision(
-    precision: torch.dtype, *, is_fsdp1_module: Optional[bool] = False
+    precision: torch.dtype,
+    *,
+    is_fsdp1_module: Optional[bool] = False,
+    grad_scaler_kwargs: Optional[dict[str, Any]] = None,
 ) -> Optional[GradScaler]:
     """
     Returns the correct grad scaler to use based on the precision and whether
@@ -48,16 +51,27 @@ def get_grad_scaler_from_precision(
     Args:
         precision: the precision being used
         is_fsdp1_module: whether the grad scaler is for an FSDP1 module
+        grad_scaler_kwargs: optional parameters for configuring the grad scaler
+            (init_scale, growth_factor, backoff_factor, growth_interval)
 
     Returns:
         The appropriate grad scaler to use, ``None`` if no grad scaler should be used.
     """
 
     if precision == torch.float16:
+
+        if grad_scaler_kwargs is None:
+            grad_scaler_kwargs = {}
+
         if is_fsdp1_module:
             from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
-            return ShardedGradScaler()
+            return ShardedGradScaler(
+                **grad_scaler_kwargs,
+            )
         else:
-            return GradScaler("cuda")
+            return GradScaler(
+                device="cuda",
+                **grad_scaler_kwargs,
+            )
     return None
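For reference, a hedged sketch of calling the updated helper directly; the keyword dict mirrors the fields of `GradScalerParams`, and omitting it keeps the previous default-constructed scalers:

import torch

from torchtnt.utils.precision import get_grad_scaler_from_precision

# fp16 with custom kwargs: returns a GradScaler (or ShardedGradScaler for FSDP1)
# configured with the given init_scale and growth_interval.
scaler = get_grad_scaler_from_precision(
    torch.float16,
    is_fsdp1_module=False,
    grad_scaler_kwargs={"init_scale": 2.0**12, "growth_interval": 500},
)

# Any non-fp16 precision still returns None, with or without kwargs.
assert get_grad_scaler_from_precision(torch.bfloat16, is_fsdp1_module=False) is None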
