2 changes: 1 addition & 1 deletion docs/source/framework/callbacks.rst
@@ -22,7 +22,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
BaseCSVWriter
EarlyStopping
GarbageCollector
-    IterationTimeLogger
+    IterationTimeLogger
Lambda
LearningRateMonitor
MemorySnapshot
31 changes: 30 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -11,7 +11,7 @@
import logging
from abc import ABCMeta, abstractmethod
from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
from typing import (
Any,
Callable,
@@ -85,6 +85,29 @@ class SWALRParams:
swa_lrs: Union[List[float], float] = 0.05


+@dataclass
+class GradScalerParams:
+    """
+    Dataclass to store parameters for gradient scaling.
+
+    Args:
+        init_scale: Initial scale factor. Default: 2.**16 (65536)
+        growth_factor: Factor by which the scale is multiplied during update
+            if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+            Default: 2.0
+        backoff_factor: Factor by which the scale is multiplied during update
+            if inf/NaN gradients occur in an iteration. Default: 0.5
+        growth_interval: Number of consecutive iterations without inf/NaN gradients
+            that must occur for the scale to be multiplied by ``growth_factor``.
+            Default: 2000
+    """
+
+    init_scale: float = 2.0**16
+    growth_factor: float = 2.0
+    backoff_factor: float = 0.5
+    growth_interval: int = 2000
+
+
@dataclass
class SWAParams:
"""
@@ -517,6 +540,7 @@ def __init__(
zero_grad_at_train_step_start: bool = False,
global_mesh: Optional[GlobalMeshCoordinator] = None,
enable_loss_parallel: bool = False,
+        grad_scaler_params: Optional[GradScalerParams] = None,
) -> None:
super().__init__(
module=module,
@@ -565,10 +589,15 @@ def __init__(
)

self.grad_scaler: Optional[GradScaler] = None
+        if grad_scaler_params is not None:
+            grad_scaler_kwargs = asdict(grad_scaler_params)
+        else:
+            grad_scaler_kwargs = None
if self.precision:
self.grad_scaler = get_grad_scaler_from_precision(
self.precision,
is_fsdp1_module=_is_fsdp1_module(self.module),
+                grad_scaler_kwargs=grad_scaler_kwargs,
)

self.step_lr_interval = step_lr_interval
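
Taken together, the auto_unit.py changes above let callers tune gradient scaling through the new dataclass instead of accepting the GradScaler defaults. The following is a minimal usage sketch, not part of this PR, assuming the enclosing class is AutoUnit, that "fp16" is an accepted precision string, and that the stub subclass, toy module, and SGD optimizer are purely illustrative:

import torch
from torchtnt.framework.auto_unit import AutoUnit, GradScalerParams

class MyUnit(AutoUnit):
    # Stub implementations of AutoUnit's abstract methods, for illustration only.
    def compute_loss(self, state, data):
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(self, module):
        return torch.optim.SGD(module.parameters(), lr=0.01), None

# Non-default scaler settings are carried by GradScalerParams; AutoUnit converts
# them to kwargs via asdict() and forwards them to get_grad_scaler_from_precision.
unit = MyUnit(
    module=torch.nn.Linear(8, 2),
    precision="fp16",
    grad_scaler_params=GradScalerParams(init_scale=2.0**14, growth_interval=1000),
)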
22 changes: 18 additions & 4 deletions torchtnt/utils/precision.py
@@ -7,7 +7,7 @@

# pyre-strict

-from typing import Mapping, Optional
+from typing import Any, Mapping, Optional

import torch
from torch.amp.grad_scaler import GradScaler
@@ -38,7 +38,10 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:


def get_grad_scaler_from_precision(
-    precision: torch.dtype, *, is_fsdp1_module: Optional[bool] = False
+    precision: torch.dtype,
+    *,
+    is_fsdp1_module: Optional[bool] = False,
+    grad_scaler_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[GradScaler]:
"""
Returns the correct grad scaler to use based on the precision and whether
@@ -48,16 +51,27 @@ def get_grad_scaler_from_precision(
Args:
precision: the precision being used
is_fsdp1_module: whether the grad scaler is for an FSDP1 module
+        grad_scaler_kwargs: optional parameters for configuring the grad scaler
+            (init_scale, growth_factor, backoff_factor, growth_interval)

Returns:
The appropriate grad scaler to use, ``None`` if no grad scaler should be used.
"""

    if precision == torch.float16:
+
+        if grad_scaler_kwargs is None:
+            grad_scaler_kwargs = {}
+
if is_fsdp1_module:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

-            return ShardedGradScaler()
+            return ShardedGradScaler(
+                **grad_scaler_kwargs,
+            )
else:
-            return GradScaler("cuda")
+            return GradScaler(
+                device="cuda",
+                **grad_scaler_kwargs,
+            )
return None
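
On the utils side, this is a short sketch of how the new keyword might flow through get_grad_scaler_from_precision; the kwargs dict mirrors the GradScalerParams fields, and the concrete values are made up for illustration:

import torch
from torchtnt.utils.precision import get_grad_scaler_from_precision

# fp16 without FSDP1 should yield a torch.amp.GradScaler configured for "cuda"
# with the overrides below; any other dtype returns None per the code above.
scaler = get_grad_scaler_from_precision(
    torch.float16,
    is_fsdp1_module=False,
    grad_scaler_kwargs={"init_scale": 2.0**14, "growth_interval": 1000},
)
print(type(scaler))  # torch.amp.grad_scaler.GradScaler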