# limitations under the License.
import os
from collections import Counter
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, cast, Dict, List, Optional, Union

import torch
-from typing_extensions import Literal
+from typing_extensions import get_args

from lightning_fabric.accelerators import ACCELERATOR_REGISTRY
from lightning_fabric.accelerators.accelerator import Accelerator
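The import swap above replaces `Literal` with `get_args`: instead of defining the allowed precision values as a local `Literal`, the connector now reads them off shared `Literal` aliases at runtime. `get_args` simply unpacks a `Literal`'s members into a plain tuple; a minimal, runnable illustration (the alias name below is ours, not the library's):

    from typing_extensions import Literal, get_args

    _EXAMPLE_PRECISION = Literal["16", "32", "64", "bf16"]
    # get_args() returns the Literal's members as a tuple:
    assert get_args(_EXAMPLE_PRECISION) == ("16", "32", "64", "bf16")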
)
from lightning_fabric.plugins.precision.double import DoublePrecision
from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
+from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning_fabric.strategies import (
    DDPShardedStrategy,
    DDPStrategy,
_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
-_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]


class _Connector:
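The module-level `_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]` removed above is superseded by the aliases imported from `lightning_fabric.plugins.precision.precision`. That file is not part of this diff, so the exact definitions are an assumption; a plausible sketch, inferred from how the names are used below:

    from typing import Union
    from typing_extensions import Literal

    # Assumed shape of the new aliases in lightning_fabric/plugins/precision/precision.py:
    _PRECISION_INPUT_INT = Literal[16, 32, 64]
    _PRECISION_INPUT_STR = Literal["16", "32", "64", "bf16"]
    _PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]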
@@ -113,14 +113,13 @@ def __init__(
        # Get registered strategies, built-in accelerators and precision plugins
        self._registered_strategies = STRATEGY_REGISTRY.available_strategies()
        self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators()
-        self._precision_types = ("16", "32", "64", "bf16")

        # Raise an exception if there are conflicts between flags
        # Set each valid flag to `self._x_flag` after validation
        # For devices: Assign gpus, etc. to the accelerator flag and devices flag
        self._strategy_flag: Optional[Union[Strategy, str]] = None
        self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_input: Optional[_PRECISION_INPUT] = None
+        self._precision_input: _PRECISION_INPUT_STR = "32"
        self._precision_instance: Optional[Precision] = None
        self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
        self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -206,12 +205,10 @@ def _check_config_and_set_final_flags(
        self._accelerator_flag = accelerator

-        if precision is not None:
-            if str(precision) not in self._precision_types:
-                raise ValueError(
-                    f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
-                )
-            self._precision_input = precision
+        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
+        if precision not in supported_precision:
+            raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
+        self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))

        if plugins:
            plugins_flags_types: Dict[str, int] = Counter()
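The rewritten block accepts both integer and string spellings of the flag, then normalizes to the string form once, so every later comparison can assume strings. A self-contained sketch of the same validate-and-normalize flow (the helper name is ours, and the `Literal` aliases are assumed as above):

    from typing import cast, Union
    from typing_extensions import Literal, get_args

    _PRECISION_INPUT_INT = Literal[16, 32, 64]                # assumed definition
    _PRECISION_INPUT_STR = Literal["16", "32", "64", "bf16"]  # assumed definition

    def normalize_precision(precision: Union[int, str]) -> str:
        # Accept either spelling, reject everything else, store as a string.
        supported = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
        if precision not in supported:
            raise ValueError(f"Precision {precision!r} is invalid. Allowed precision values: {supported}")
        return cast(_PRECISION_INPUT_STR, str(precision))

    assert normalize_precision(16) == "16"
    assert normalize_precision("bf16") == "bf16"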
@@ -442,10 +439,10 @@ def _check_and_init_precision(self) -> Precision:
            return self._precision_instance

        if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == 32:
+            if self._precision_input == "32":
                return TPUPrecision()
-            elif self._precision_input in (16, "bf16"):
-                if self._precision_input == 16:
+            elif self._precision_input in ("16", "bf16"):
+                if self._precision_input == "16":
                    rank_zero_warn(
                        "You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
                        " is not supported with TPUs. Using `precision='bf16'` instead."
@@ -454,22 +451,22 @@ def _check_and_init_precision(self) -> Precision:
        if isinstance(self.strategy, DeepSpeedStrategy):
            return DeepSpeedPrecision(self._precision_input)  # type: ignore

-        if self._precision_input == 32:
+        if self._precision_input == "32":
            return Precision()
-        if self._precision_input == 64:
+        if self._precision_input == "64":
            return DoublePrecision()

-        if self._precision_input == 16 and self._accelerator_flag == "cpu":
+        if self._precision_input == "16" and self._accelerator_flag == "cpu":
            rank_zero_warn(
                "You passed `Fabric(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                " Using `precision='bf16'` instead."
            )
            self._precision_input = "bf16"

-        if self._precision_input in (16, "bf16"):
+        if self._precision_input in ("16", "bf16"):
            rank_zero_info(
                "Using 16-bit Automatic Mixed Precision (AMP)"
-                if self._precision_input == 16
+                if self._precision_input == "16"
                else "Using bfloat16 Automatic Mixed Precision (AMP)"
            )
            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -483,7 +480,7 @@ def _check_and_init_precision(self) -> Precision:
    def _validate_precision_choice(self) -> None:
        """Validate the combination of choices for precision, and accelerator."""
        if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == 64:
+            if self._precision_input == "64":
                raise NotImplementedError(
                    "`Fabric(accelerator='tpu', precision=64)` is not implemented."
                    " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
@@ -536,16 +533,12 @@ def _lazy_init_strategy(self) -> None:
    @staticmethod
    def _argument_from_env(name: str, current: Any, default: Any) -> Any:
-        env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper())
+        env_value: Optional[str] = os.environ.get("LT_" + name.upper())

        if env_value is None:
            return current

-        if name == "precision":
-            # TODO: support precision input as string, then this special handling is not needed
-            env_value = int(env_value) if env_value in ("16", "32", "64") else env_value
-
-        if env_value is not None and env_value != current and current != default:
+        if env_value is not None and env_value != str(current) and str(current) != str(default):
            raise ValueError(
                f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
                f"`--{name}={current}` set through the CLI. "