
Commit 4c3ce60

awaelchli and carmocca authored
Update precision input type annotations (#14857)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 6bf6540 commit 4c3ce60

40 files changed: +229, -195 lines changed

src/lightning_fabric/connector.py

Lines changed: 19 additions & 26 deletions
@@ -13,10 +13,10 @@
 # limitations under the License.
 import os
 from collections import Counter
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, cast, Dict, List, Optional, Union

 import torch
-from typing_extensions import Literal
+from typing_extensions import get_args

 from lightning_fabric.accelerators import ACCELERATOR_REGISTRY
 from lightning_fabric.accelerators.accelerator import Accelerator
@@ -41,6 +41,7 @@
 )
 from lightning_fabric.plugins.precision.double import DoublePrecision
 from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
+from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
 from lightning_fabric.strategies import (
     DDPShardedStrategy,
     DDPStrategy,
@@ -59,7 +60,6 @@

 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
-_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]


 class _Connector:
@@ -113,14 +113,13 @@ def __init__(
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = STRATEGY_REGISTRY.available_strategies()
         self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators()
-        self._precision_types = ("16", "32", "64", "bf16")

         # Raise an exception if there are conflicts between flags
         # Set each valid flag to `self._x_flag` after validation
         # For devices: Assign gpus, etc. to the accelerator flag and devices flag
         self._strategy_flag: Optional[Union[Strategy, str]] = None
         self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_input: Optional[_PRECISION_INPUT] = None
+        self._precision_input: _PRECISION_INPUT_STR = "32"
         self._precision_instance: Optional[Precision] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -206,12 +205,10 @@ def _check_config_and_set_final_flags(

         self._accelerator_flag = accelerator

-        if precision is not None:
-            if str(precision) not in self._precision_types:
-                raise ValueError(
-                    f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
-                )
-            self._precision_input = precision
+        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
+        if precision not in supported_precision:
+            raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
+        self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))

         if plugins:
             plugins_flags_types: Dict[str, int] = Counter()
@@ -442,10 +439,10 @@ def _check_and_init_precision(self) -> Precision:
             return self._precision_instance

         if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == 32:
+            if self._precision_input == "32":
                 return TPUPrecision()
-            elif self._precision_input in (16, "bf16"):
-                if self._precision_input == 16:
+            elif self._precision_input in ("16", "bf16"):
+                if self._precision_input == "16":
                     rank_zero_warn(
                         "You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
                         " is not supported with TPUs. Using `precision='bf16'` instead."
@@ -454,22 +451,22 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore

-        if self._precision_input == 32:
+        if self._precision_input == "32":
             return Precision()
-        if self._precision_input == 64:
+        if self._precision_input == "64":
             return DoublePrecision()

-        if self._precision_input == 16 and self._accelerator_flag == "cpu":
+        if self._precision_input == "16" and self._accelerator_flag == "cpu":
             rank_zero_warn(
                 "You passed `Fabric(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                 " Using `precision='bf16'` instead."
             )
             self._precision_input = "bf16"

-        if self._precision_input in (16, "bf16"):
+        if self._precision_input in ("16", "bf16"):
             rank_zero_info(
                 "Using 16-bit Automatic Mixed Precision (AMP)"
-                if self._precision_input == 16
+                if self._precision_input == "16"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -483,7 +480,7 @@ def _check_and_init_precision(self) -> Precision:
     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, and accelerator."""
         if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == 64:
+            if self._precision_input == "64":
                 raise NotImplementedError(
                     "`Fabric(accelerator='tpu', precision=64)` is not implemented."
                     " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
@@ -536,16 +533,12 @@ def _lazy_init_strategy(self) -> None:

     @staticmethod
     def _argument_from_env(name: str, current: Any, default: Any) -> Any:
-        env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper())
+        env_value: Optional[str] = os.environ.get("LT_" + name.upper())

         if env_value is None:
             return current

-        if name == "precision":
-            # TODO: support precision input as string, then this special handling is not needed
-            env_value = int(env_value) if env_value in ("16", "32", "64") else env_value
-
-        if env_value is not None and env_value != current and current != default:
+        if env_value is not None and env_value != str(current) and str(current) != str(default):
             raise ValueError(
                 f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
                 f"`--{name}={current}` set through the CLI. "
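Note: the connector change boils down to one pattern: enumerate the allowed literals with `typing_extensions.get_args`, validate against that tuple, and store the user's choice in its string form. A minimal, self-contained sketch of that pattern (the `_normalize_precision` helper below is hypothetical; the aliases mirror the ones this commit adds in `precision.py`):

from typing import Union, cast

from typing_extensions import Literal, get_args

# Aliases as introduced in precision.py by this commit.
_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]


def _normalize_precision(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR:
    # get_args() returns the literal values as a plain tuple, so the check
    # accepts both the integer and the string spellings at runtime.
    supported = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
    if precision not in supported:
        raise ValueError(f"Precision {precision!r} is invalid. Allowed precision values: {supported}")
    # Internally everything is kept as a string, matching _PRECISION_INPUT_STR.
    return cast(_PRECISION_INPUT_STR, str(precision))


assert _normalize_precision(16) == "16"
assert _normalize_precision("bf16") == "bf16"

The string-based comparison for environment overrides (`env_value != str(current)`) follows from the same normalization, since `LT_PRECISION` always arrives as a string.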

src/lightning_fabric/plugins/precision/deepspeed.py

Lines changed: 11 additions & 10 deletions
@@ -11,22 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, TYPE_CHECKING
+from typing import Any, cast, TYPE_CHECKING, Union

 import torch
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
-from typing_extensions import Literal
+from typing_extensions import get_args, Literal

 from lightning_fabric.plugins.precision.precision import Precision
 from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
-from lightning_fabric.utilities.enums import PrecisionType
 from lightning_fabric.utilities.types import Steppable

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
     import deepspeed

+_PRECISION_INPUT_INT = Literal[32, 16]
+_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
+_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
+

 class DeepSpeedPrecision(Precision):
     """Precision plugin for DeepSpeed integration.
@@ -39,19 +42,17 @@ class DeepSpeedPrecision(Precision):
         If unsupported ``precision`` is provided.
     """

-    def __init__(self, precision: Literal[16, 32, "bf16"]) -> None:
-        supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
+    def __init__(self, precision: _PRECISION_INPUT) -> None:
+        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in DeepSpeed."
-                f" `precision` must be one of: {(x.value for x in supported_precision)}."
+                f" `precision` must be one of: {supported_precision}."
             )
-
-        super().__init__()
-        self.precision = precision
+        self.precision = cast(_PRECISION_INPUT_STR, str(precision))

     def convert_input(self, data: Tensor) -> Tensor:
-        precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
+        precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32}
         dst_type = precision_to_type[self.precision]
         return _convert_fp_tensor(data, dst_type)
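Since `self.precision` is now always one of the string literals, the dtype lookup is keyed by strings as well. A rough standalone sketch of the conversion step (`_convert_fp_tensor` is approximated here by a plain `.to()` on floating-point tensors):

import torch

# String-keyed mapping, as in the updated convert_input().
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32}


def convert_input(data: torch.Tensor, precision: str) -> torch.Tensor:
    # Only floating-point tensors are converted; integer/bool tensors pass through.
    dst_type = precision_to_type[precision]
    return data.to(dst_type) if torch.is_floating_point(data) else data


assert convert_input(torch.randn(2), "16").dtype == torch.float16
assert convert_input(torch.arange(2), "bf16").dtype == torch.int64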

src/lightning_fabric/plugins/precision/double.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 import torch
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import Literal

 from lightning_fabric.plugins.precision.precision import Precision
 from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
@@ -25,7 +26,7 @@
 class DoublePrecision(Precision):
     """Plugin for training with double (``torch.float64``) precision."""

-    precision: int = 64
+    precision: Literal["64"] = "64"

     def convert_module(self, module: Module) -> Module:
         return module.double()

src/lightning_fabric/plugins/precision/fsdp.py

Lines changed: 4 additions & 5 deletions
@@ -17,7 +17,6 @@
 from typing_extensions import Literal

 from lightning_fabric.plugins.precision.native_amp import MixedPrecision
-from lightning_fabric.utilities.enums import PrecisionType
 from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12

 if TYPE_CHECKING:
@@ -29,7 +28,7 @@ class FSDPPrecision(MixedPrecision):
     """AMP for Fully Sharded Data Parallel training."""

     def __init__(
-        self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
+        self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
@@ -39,16 +38,16 @@ def __init__(
         super().__init__(
             precision=precision,
             device=device,
-            scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None),
+            scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None),
         )

     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == PrecisionType.HALF:
+        if self.precision == "16":
             dtype = torch.float16
-        elif self.precision == PrecisionType.BFLOAT:
+        elif self.precision == "bf16":
             dtype = torch.bfloat16
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
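`FSDPPrecision` keeps accepting the legacy integer 16, so the scaler decision goes through `str(precision)` to treat 16 and "16" identically. A small illustrative sketch (the boolean helper is hypothetical; the real code instantiates fairscale's `ShardedGradScaler` instead of returning a flag):

from typing import Optional

from typing_extensions import Literal


def _wants_sharded_grad_scaler(precision: Literal["16", 16, "bf16"], scaler: Optional[object] = None) -> bool:
    # str() bridges both spellings of fp16; bf16 never uses a gradient scaler.
    return scaler is None and str(precision) == "16"


assert _wants_sharded_grad_scaler(16)
assert _wants_sharded_grad_scaler("16")
assert not _wants_sharded_grad_scaler("bf16")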

src/lightning_fabric/plugins/precision/native_amp.py

Lines changed: 6 additions & 7 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Optional
+from typing import Any, cast, Dict, Generator, Optional

 import torch
 from torch import Tensor
@@ -36,16 +36,15 @@ class MixedPrecision(Precision):
     """

     def __init__(
-        self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+        self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
     ) -> None:
-        super().__init__()
-        if scaler is None and precision == 16:
+        self.precision = cast(Literal["16", "bf16"], str(precision))
+        if scaler is None and self.precision == "16":
             with _patch_cuda_is_available():
                 # if possible, we defer CUDA initialization to support strategies that will attempt forks
                 scaler = torch.cuda.amp.GradScaler()
-        if scaler is not None and precision == "bf16":
+        if scaler is not None and self.precision == "bf16":
             raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
-        self.precision = precision
         self.device = device
         self.scaler = scaler

@@ -55,7 +54,7 @@ def forward_context(self) -> Generator[None, None, None]:
         yield

     def convert_input(self, data: Tensor) -> Tensor:
-        precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16}
+        precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16}
         dst_type = precision_to_type[self.precision]
         return _convert_fp_tensor(data, dst_type)
6160

src/lightning_fabric/plugins/precision/precision.py

Lines changed: 6 additions & 1 deletion
@@ -18,18 +18,23 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from typing_extensions import Literal

 from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
 from lightning_fabric.utilities.types import _PARAMETERS, Optimizable

+_PRECISION_INPUT_INT = Literal[64, 32, 16]
+_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
+_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
+

 class Precision:
     """Base class for all plugins handling the precision-specific parts of the training.

     The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
     """

-    precision: Union[str, int] = 32
+    precision: _PRECISION_INPUT_STR = "32"

     def convert_module(self, module: Module) -> Module:
         """Convert the module parameters to the precision type this plugin handles.

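With the base attribute typed as `_PRECISION_INPUT_STR`, subclasses can narrow `precision` to a single string literal, which is exactly what `DoublePrecision` ("64") and `TPUBf16Precision` ("bf16") do in this commit. A toy sketch of the narrowing pattern (class names here are illustrative, not the library's):

from typing_extensions import Literal, get_args

_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]


class BasePrecision:
    # Default reflects fp32 training; child classes overwrite it.
    precision: _PRECISION_INPUT_STR = "32"


class Float64Precision(BasePrecision):
    # A narrower Literal both documents and type-checks the invariant.
    precision: Literal["64"] = "64"


assert Float64Precision().precision in get_args(_PRECISION_INPUT_STR)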
src/lightning_fabric/plugins/precision/tpu_bf16.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@

 import torch
 from torch import Tensor
+from typing_extensions import Literal

 from lightning_fabric.plugins.precision import TPUPrecision
 from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
@@ -23,7 +24,7 @@
 class TPUBf16Precision(TPUPrecision):
     """Plugin that enables bfloats on TPUs."""

-    precision: str = "bf16"
+    precision: Literal["bf16"] = "bf16"

     def __init__(self) -> None:
         super().__init__()

src/lightning_fabric/strategies/deepspeed.py

Lines changed: 4 additions & 5 deletions
@@ -31,7 +31,6 @@
 from lightning_fabric.strategies.ddp import DDPStrategy
 from lightning_fabric.strategies.strategy import _Sharded
 from lightning_fabric.utilities.distributed import log
-from lightning_fabric.utilities.enums import PrecisionType
 from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
 from lightning_fabric.utilities.seed import reset_seed
 from lightning_fabric.utilities.types import _PATH
@@ -349,9 +348,9 @@ def module_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized

-            if self.precision.precision == PrecisionType.HALF:
+            if self.precision.precision == "16":
                 dtype = torch.float16
-            elif self.precision.precision == PrecisionType.BFLOAT:
+            elif self.precision.precision == "bf16":
                 dtype = torch.bfloat16
             else:
                 dtype = torch.float32
@@ -499,7 +498,7 @@ def _format_config(self) -> None:

     def _format_precision_config(self) -> None:
         assert isinstance(self.config, dict)
-        if self.precision.precision == PrecisionType.HALF:
+        if self.precision.precision == "16":
             if "fp16" not in self.config:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
@@ -511,7 +510,7 @@ def _format_precision_config(self) -> None:
                     "hysteresis": self.hysteresis,
                     "min_loss_scale": self.min_loss_scale,
                 }
-        elif "bf16" not in self.config and self.precision.precision == PrecisionType.BFLOAT:
+        elif "bf16" not in self.config and self.precision.precision == "bf16":
             rank_zero_info("Enabling DeepSpeed BF16.")
             self.config["bf16"] = {"enabled": True}
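On the strategy side the same string comparisons decide which DeepSpeed config section to populate. A rough sketch of the branching in `_format_precision_config` (only the `enabled` flag is shown; the real fp16 block also carries the loss-scale settings visible in the hunk above):

from typing import Any, Dict


def format_precision_config(config: Dict[str, Any], precision: str) -> Dict[str, Any]:
    # "16" maps to DeepSpeed's standalone fp16 AMP, "bf16" to its bfloat16 mode;
    # any other value (e.g. "32") leaves the config untouched.
    if precision == "16" and "fp16" not in config:
        config["fp16"] = {"enabled": True}
    elif precision == "bf16" and "bf16" not in config:
        config["bf16"] = {"enabled": True}
    return config


assert format_precision_config({}, "16") == {"fp16": {"enabled": True}}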

src/lightning_fabric/strategies/fairscale.py

Lines changed: 1 addition & 2 deletions
@@ -26,7 +26,6 @@
 from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning_fabric.strategies.ddp import DDPStrategy
 from lightning_fabric.strategies.strategy import _BackwardSyncControl
-from lightning_fabric.utilities.enums import PrecisionType
 from lightning_fabric.utilities.imports import _IS_WINDOWS

 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")
@@ -116,7 +115,7 @@ def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precisio
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
-            is_fp16 = precision.precision in (PrecisionType.MIXED, PrecisionType.HALF)
+            is_fp16 = precision.precision == "16"
             # For multi-node training, compressing the model shards in fp16 before broadcasting
             # improves performance. When using PyTorch AMP, it will not degrade
             # the model performance.
