11 | 11 | import logging |
12 | 12 | from abc import ABCMeta, abstractmethod |
13 | 13 | from copy import deepcopy |
14 | | -from dataclasses import dataclass |
| 14 | +from dataclasses import asdict, dataclass |
15 | 15 | from typing import ( |
16 | 16 | Any, |
17 | 17 | Callable, |
@@ -85,6 +85,29 @@ class SWALRParams: |
85 | 85 | swa_lrs: Union[List[float], float] = 0.05 |
86 | 86 |
87 | 87 |
| 88 | +@dataclass |
| 89 | +class GradScalerParams: |
| 90 | + """ |
| 91 | + Dataclass to store parameters for gradient scaling. |
| 92 | +
| 93 | + Args: |
| 94 | + init_scale: Initial scale factor. Default: 2.**16 (65536) |
| 95 | + growth_factor: Factor by which the scale is multiplied during update |
| 96 | + if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. |
| 97 | + Default: 2.0 |
| 98 | + backoff_factor: Factor by which the scale is multiplied during update |
| 99 | + if inf/NaN gradients occur in an iteration. Default: 0.5 |
| 100 | + growth_interval: Number of consecutive iterations without inf/NaN gradients |
| 101 | + that must occur for the scale to be multiplied by ``growth_factor``. |
| 102 | + Default: 2000 |
| 103 | + """ |
| 104 | + |
| 105 | + init_scale: float = 2.0**16 |
| 106 | + growth_factor: float = 2.0 |
| 107 | + backoff_factor: float = 0.5 |
| 108 | + growth_interval: int = 2000 |
| 109 | + |
| 110 | + |
88 | 111 | @dataclass |
89 | 112 | class SWAParams: |
90 | 113 | """ |
@@ -517,6 +540,7 @@ def __init__( |
517 | 540 | zero_grad_at_train_step_start: bool = False, |
518 | 541 | global_mesh: Optional[GlobalMeshCoordinator] = None, |
519 | 542 | enable_loss_parallel: bool = False, |
| 543 | + grad_scaler_params: Optional[GradScalerParams] = None, |
520 | 544 | ) -> None: |
521 | 545 | super().__init__( |
522 | 546 | module=module, |
@@ -565,10 +589,15 @@ def __init__( |
565 | 589 | ) |
566 | 590 |
567 | 591 | self.grad_scaler: Optional[GradScaler] = None |
| 592 | + if grad_scaler_params is not None: |
| 593 | + grad_scaler_kwargs = asdict(grad_scaler_params) |
| 594 | + else: |
| 595 | + grad_scaler_kwargs = None |
568 | 596 | if self.precision: |
569 | 597 | self.grad_scaler = get_grad_scaler_from_precision( |
570 | 598 | self.precision, |
571 | 599 | is_fsdp1_module=_is_fsdp1_module(self.module), |
| 600 | + grad_scaler_kwargs=grad_scaler_kwargs, |
572 | 601 | ) |
573 | 602 |
574 | 603 | self.step_lr_interval = step_lr_interval |
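
For context, a minimal usage sketch of the new knob. The `GradScalerParams` definition is copied from the diff above; the override values and the direct `GradScaler` construction at the end are illustrative assumptions about what the forwarded kwargs roughly amount to for a plain fp16, non-FSDP1 module.

```python
from dataclasses import asdict, dataclass

import torch


@dataclass
class GradScalerParams:
    """Same fields and defaults as the dataclass added in this diff."""

    init_scale: float = 2.0**16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000


# Hypothetical overrides: start from a smaller scale and re-grow it sooner;
# the remaining fields keep their defaults.
params = GradScalerParams(init_scale=2.0**14, growth_interval=1000)

# The constructor change above converts the dataclass with asdict() and
# forwards the kwargs through get_grad_scaler_from_precision. For a plain
# (non-FSDP1) module with fp16 precision, that is roughly equivalent to
# building the scaler directly:
scaler = torch.cuda.amp.GradScaler(**asdict(params))
```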