diff --git a/docs/source/framework/callbacks.rst b/docs/source/framework/callbacks.rst
index c3e7f415a5..ace77553d8 100644
--- a/docs/source/framework/callbacks.rst
+++ b/docs/source/framework/callbacks.rst
@@ -22,7 +22,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
     BaseCSVWriter
     EarlyStopping
     GarbageCollector
-    IterationTimeLogger
+    IterationTimeLogger
     Lambda
     LearningRateMonitor
     MemorySnapshot
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 54a9126b49..b246a470f4 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -11,7 +11,7 @@
 import logging
 from abc import ABCMeta, abstractmethod
 from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import (
     Any,
     Callable,
@@ -85,6 +85,29 @@ class SWALRParams:
     swa_lrs: Union[List[float], float] = 0.05
 
 
+@dataclass
+class GradScalerParams:
+    """
+    Dataclass to store parameters for gradient scaling.
+
+    Args:
+        init_scale: Initial scale factor. Default: 2.**16 (65536)
+        growth_factor: Factor by which the scale is multiplied during update
+            if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+            Default: 2.0
+        backoff_factor: Factor by which the scale is multiplied during update
+            if inf/NaN gradients occur in an iteration. Default: 0.5
+        growth_interval: Number of consecutive iterations without inf/NaN gradients
+            that must occur for the scale to be multiplied by ``growth_factor``.
+            Default: 2000
+    """
+
+    init_scale: float = 2.0**16
+    growth_factor: float = 2.0
+    backoff_factor: float = 0.5
+    growth_interval: int = 2000
+
+
 @dataclass
 class SWAParams:
     """
@@ -517,6 +540,7 @@ def __init__(
         zero_grad_at_train_step_start: bool = False,
         global_mesh: Optional[GlobalMeshCoordinator] = None,
         enable_loss_parallel: bool = False,
+        grad_scaler_params: Optional[GradScalerParams] = None,
     ) -> None:
         super().__init__(
             module=module,
@@ -565,10 +589,15 @@ def __init__(
         )
 
         self.grad_scaler: Optional[GradScaler] = None
+        if grad_scaler_params is not None:
+            grad_scaler_kwargs = asdict(grad_scaler_params)
+        else:
+            grad_scaler_kwargs = None
         if self.precision:
             self.grad_scaler = get_grad_scaler_from_precision(
                 self.precision,
                 is_fsdp1_module=_is_fsdp1_module(self.module),
+                grad_scaler_kwargs=grad_scaler_kwargs,
             )
 
         self.step_lr_interval = step_lr_interval
diff --git a/torchtnt/utils/precision.py b/torchtnt/utils/precision.py
index 279c18faf4..3a83513b24 100644
--- a/torchtnt/utils/precision.py
+++ b/torchtnt/utils/precision.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Mapping, Optional
+from typing import Any, Mapping, Optional
 
 import torch
 from torch.amp.grad_scaler import GradScaler
@@ -38,7 +38,10 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:
 
 
 def get_grad_scaler_from_precision(
-    precision: torch.dtype, *, is_fsdp1_module: Optional[bool] = False
+    precision: torch.dtype,
+    *,
+    is_fsdp1_module: Optional[bool] = False,
+    grad_scaler_kwargs: Optional[dict[str, Any]] = None,
 ) -> Optional[GradScaler]:
     """
     Returns the correct grad scaler to use based on the precision and whether
@@ -48,16 +51,27 @@ def get_grad_scaler_from_precision(
     Args:
         precision: the precision being used
         is_fsdp1_module: whether the grad scaler is for an FSDP1 module
+        grad_scaler_kwargs: optional parameters for configuring the grad scaler
+            (init_scale, growth_factor, backoff_factor, growth_interval)
 
     Returns:
         The appropriate grad scaler to use, ``None`` if no grad
         scaler should be used.
     """
     if precision == torch.float16:
+
+        if grad_scaler_kwargs is None:
+            grad_scaler_kwargs = {}
+
         if is_fsdp1_module:
             from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
-            return ShardedGradScaler()
+            return ShardedGradScaler(
+                **grad_scaler_kwargs,
+            )
         else:
-            return GradScaler("cuda")
+            return GradScaler(
+                device="cuda",
+                **grad_scaler_kwargs,
+            )
     return None
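
A minimal usage sketch of the new knobs, assuming an fp16 run on CUDA; the values passed to GradScalerParams below are illustrative, not recommendations:

# Sketch: build a scaler with custom scaling parameters via the new
# grad_scaler_kwargs plumbing. Values are illustrative only.
from dataclasses import asdict

import torch

from torchtnt.framework.auto_unit import GradScalerParams
from torchtnt.utils.precision import get_grad_scaler_from_precision

params = GradScalerParams(init_scale=2.0**14, growth_interval=1000)

# An AutoUnit subclass would take the dataclass directly, e.g.
#   MyAutoUnit(module=module, precision="fp16", grad_scaler_params=params)
# where MyAutoUnit is a hypothetical subclass; the constructor converts it
# to kwargs with asdict(), as in the diff above. The utility call itself:
scaler = get_grad_scaler_from_precision(
    torch.float16,
    is_fsdp1_module=False,
    grad_scaler_kwargs=asdict(params),
)
print(type(scaler).__name__)  # GradScaler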