@@ -13,10 +13,10 @@
 # limitations under the License.
 import os
 from collections import Counter
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, cast, Dict, List, Optional, Union

 import torch
-from typing_extensions import Literal
+from typing_extensions import get_args

 from lightning_fabric.accelerators import ACCELERATOR_REGISTRY
 from lightning_fabric.accelerators.accelerator import Accelerator
@@ -41,6 +41,7 @@
 )
 from lightning_fabric.plugins.precision.double import DoublePrecision
 from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
+from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
 from lightning_fabric.strategies import (
     DDPShardedStrategy,
     DDPStrategy,
@@ -59,7 +60,6 @@

 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
-_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]


 class _Connector:
@@ -113,14 +113,13 @@ def __init__(
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = STRATEGY_REGISTRY.available_strategies()
         self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators()
-        self._precision_types = ("16", "32", "64", "bf16")

         # Raise an exception if there are conflicts between flags
         # Set each valid flag to `self._x_flag` after validation
         # For devices: Assign gpus, etc. to the accelerator flag and devices flag
         self._strategy_flag: Optional[Union[Strategy, str]] = None
         self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_input: Optional[_PRECISION_INPUT] = None
+        self._precision_input: _PRECISION_INPUT_STR = "32"
         self._precision_instance: Optional[Precision] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -206,12 +205,10 @@ def _check_config_and_set_final_flags(

         self._accelerator_flag = accelerator

-        if precision is not None:
-            if str(precision) not in self._precision_types:
-                raise ValueError(
-                    f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
-                )
-            self._precision_input = precision
+        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
+        if precision not in supported_precision:
+            raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
+        self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))

         if plugins:
             plugins_flags_types: Dict[str, int] = Counter()
@@ -442,10 +439,10 @@ def _check_and_init_precision(self) -> Precision:
             return self._precision_instance

         if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == 32:
+            if self._precision_input == "32":
                 return TPUPrecision()
-            elif self._precision_input in (16, "bf16"):
-                if self._precision_input == 16:
+            elif self._precision_input in ("16", "bf16"):
+                if self._precision_input == "16":
                     rank_zero_warn(
                         "You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
                         " is not supported with TPUs. Using `precision='bf16'` instead."
@@ -454,22 +451,22 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore

-        if self._precision_input == 32:
+        if self._precision_input == "32":
             return Precision()
-        if self._precision_input == 64:
+        if self._precision_input == "64":
             return DoublePrecision()

-        if self._precision_input == 16 and self._accelerator_flag == "cpu":
+        if self._precision_input == "16" and self._accelerator_flag == "cpu":
             rank_zero_warn(
                 "You passed `Fabric(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                 " Using `precision='bf16'` instead."
             )
             self._precision_input = "bf16"

-        if self._precision_input in (16, "bf16"):
+        if self._precision_input in ("16", "bf16"):
             rank_zero_info(
                 "Using 16-bit Automatic Mixed Precision (AMP)"
-                if self._precision_input == 16
+                if self._precision_input == "16"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -483,7 +480,7 @@ def _check_and_init_precision(self) -> Precision:
     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, and accelerator."""
         if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == 64:
+            if self._precision_input == "64":
                 raise NotImplementedError(
                     "`Fabric(accelerator='tpu', precision=64)` is not implemented."
                     " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
@@ -536,16 +533,12 @@ def _lazy_init_strategy(self) -> None:

     @staticmethod
     def _argument_from_env(name: str, current: Any, default: Any) -> Any:
-        env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper())
+        env_value: Optional[str] = os.environ.get("LT_" + name.upper())

         if env_value is None:
             return current

-        if name == "precision":
-            # TODO: support precision input as string, then this special handling is not needed
-            env_value = int(env_value) if env_value in ("16", "32", "64") else env_value
-
-        if env_value is not None and env_value != current and current != default:
+        if env_value is not None and env_value != str(current) and str(current) != str(default):
             raise ValueError(
                 f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
                 f"`--{name}={current}` set through the CLI. "
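
The diff above replaces the old integer-based precision check with `Literal` aliases validated via `get_args` and normalized to a canonical string. Below is a minimal, standalone sketch of that pattern, not the library code itself: the `Literal` definitions for `_PRECISION_INPUT_STR` / `_PRECISION_INPUT_INT` and the `normalize_precision` helper are assumptions for illustration, inferred from the new import and the connector logic.

```python
from typing import cast

from typing_extensions import Literal, get_args

# Assumed shape of the aliases imported from lightning_fabric.plugins.precision.precision.
_PRECISION_INPUT_INT = Literal[16, 32, 64]
_PRECISION_INPUT_STR = Literal["16", "32", "64", "bf16"]


def normalize_precision(precision: object) -> str:
    # get_args() expands a Literal alias into its allowed values, so the membership
    # check and the error message stay in sync with the type definitions.
    supported = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
    if precision not in supported:
        raise ValueError(f"Precision {precision!r} is invalid. Allowed precision values: {supported}")
    # Both 16 and "16" collapse to the same canonical string, which is why the
    # env-var override comparison can now work on plain strings via str().
    return cast(_PRECISION_INPUT_STR, str(precision))


print(normalize_precision(16))      # "16"
print(normalize_precision("bf16"))  # "bf16"
```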