@@ -100,18 +100,18 @@ class _Connector:
100100
101101 def __init__ (
102102 self ,
103- accelerator : Optional [ Union [str , Accelerator ]] = None ,
104- strategy : Optional [ Union [str , Strategy ]] = None ,
105- devices : Optional [ Union [List [int ], str , int ]] = None ,
103+ accelerator : Union [str , Accelerator ] = "auto" ,
104+ strategy : Union [str , Strategy ] = "auto" ,
105+ devices : Union [List [int ], str , int ] = "auto" ,
106106 num_nodes : int = 1 ,
107107 precision : _PRECISION_INPUT = "32-true" ,
108108 plugins : Optional [Union [_PLUGIN_INPUT , List [_PLUGIN_INPUT ]]] = None ,
109109 ) -> None :
110110
111111 # These arguments can be set through environment variables set by the CLI
112- accelerator = self ._argument_from_env ("accelerator" , accelerator , default = None )
113- strategy = self ._argument_from_env ("strategy" , strategy , default = None )
114- devices = self ._argument_from_env ("devices" , devices , default = None )
112+ accelerator = self ._argument_from_env ("accelerator" , accelerator , default = "auto" )
113+ strategy = self ._argument_from_env ("strategy" , strategy , default = "auto" )
114+ devices = self ._argument_from_env ("devices" , devices , default = "auto" )
115115 num_nodes = self ._argument_from_env ("num_nodes" , num_nodes , default = 1 )
116116 precision = self ._argument_from_env ("precision" , precision , default = "32-true" )
117117
@@ -123,8 +123,8 @@ def __init__(
123123 # Raise an exception if there are conflicts between flags
124124 # Set each valid flag to `self._x_flag` after validation
125125 # For devices: Assign gpus, etc. to the accelerator flag and devices flag
126- self ._strategy_flag : Optional [ Union [Strategy , str ]] = None
127- self ._accelerator_flag : Optional [ Union [Accelerator , str ]] = None
126+ self ._strategy_flag : Union [Strategy , str ] = "auto"
127+ self ._accelerator_flag : Union [Accelerator , str ] = "auto"
128128 self ._precision_input : _PRECISION_INPUT_STR = "32-true"
129129 self ._precision_instance : Optional [Precision ] = None
130130 self ._cluster_environment_flag : Optional [Union [ClusterEnvironment , str ]] = None
@@ -141,7 +141,7 @@ def __init__(
141141
142142 # 2. Instantiate Accelerator
143143 # handle `auto`, `None` and `gpu`
144- if self ._accelerator_flag == "auto" or self . _accelerator_flag is None :
144+ if self ._accelerator_flag == "auto" :
145145 self ._accelerator_flag = self ._choose_auto_accelerator ()
146146 elif self ._accelerator_flag == "gpu" :
147147 self ._accelerator_flag = self ._choose_gpu_accelerator_backend ()
@@ -152,7 +152,7 @@ def __init__(
152152 self .cluster_environment : ClusterEnvironment = self ._choose_and_init_cluster_environment ()
153153
154154 # 4. Instantiate Strategy - Part 1
155- if self ._strategy_flag is None :
155+ if self ._strategy_flag == "auto" :
156156 self ._strategy_flag = self ._choose_strategy ()
157157 # In specific cases, ignore user selection and fall back to a different strategy
158158 self ._check_strategy_and_fallback ()
@@ -166,8 +166,8 @@ def __init__(
166166
167167 def _check_config_and_set_final_flags (
168168 self ,
169- strategy : Optional [ Union [str , Strategy ] ],
170- accelerator : Optional [ Union [str , Accelerator ] ],
169+ strategy : Union [str , Strategy ],
170+ accelerator : Union [str , Accelerator ],
171171 precision : _PRECISION_INPUT ,
172172 plugins : Optional [Union [_PLUGIN_INPUT , List [_PLUGIN_INPUT ]]],
173173 ) -> None :
@@ -188,26 +188,24 @@ def _check_config_and_set_final_flags(
188188 if isinstance (strategy , str ):
189189 strategy = strategy .lower ()
190190
191- if strategy is not None :
192- self ._strategy_flag = strategy
191+ self ._strategy_flag = strategy
193192
194- if strategy is not None and strategy not in self ._registered_strategies and not isinstance (strategy , Strategy ):
193+ if strategy != "auto" and strategy not in self ._registered_strategies and not isinstance (strategy , Strategy ):
195194 raise ValueError (
196195 f"You selected an invalid strategy name: `strategy={ strategy !r} `."
197196 " It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
198- " Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
197+ " Example choices: auto, ddp, ddp_spawn, deepspeed, dp, ..."
199198 " Find a complete list of options in our documentation at https://lightning.ai"
200199 )
201200
202201 if (
203- accelerator is not None
204- and accelerator not in self ._registered_accelerators
202+ accelerator not in self ._registered_accelerators
205203 and accelerator not in ("auto" , "gpu" )
206204 and not isinstance (accelerator , Accelerator )
207205 ):
208206 raise ValueError (
209207 f"You selected an invalid accelerator name: `accelerator={ accelerator !r} `."
210- f" Available names are: { ', ' .join (self ._registered_accelerators )} ."
208+ f" Available names are: auto, { ', ' .join (self ._registered_accelerators )} ."
211209 )
212210
213211 # MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only.
@@ -256,9 +254,9 @@ def _check_config_and_set_final_flags(
256254 # handle the case when the user passes in a strategy instance which has an accelerator, precision,
257255 # checkpoint io or cluster env set up
258256 # TODO: improve the error messages below
259- if self . _strategy_flag and isinstance (self ._strategy_flag , Strategy ):
257+ if isinstance (self ._strategy_flag , Strategy ):
260258 if self ._strategy_flag ._accelerator :
261- if self ._accelerator_flag :
259+ if self ._accelerator_flag != "auto" :
262260 raise ValueError ("accelerator set through both strategy class and accelerator flag, choose one" )
263261 else :
264262 self ._accelerator_flag = self ._strategy_flag ._accelerator
@@ -297,9 +295,7 @@ def _check_config_and_set_final_flags(
297295 self ._accelerator_flag = "cuda"
298296 self ._parallel_devices = self ._strategy_flag .parallel_devices
299297
300- def _check_device_config_and_set_final_flags (
301- self , devices : Optional [Union [List [int ], str , int ]], num_nodes : int
302- ) -> None :
298+ def _check_device_config_and_set_final_flags (self , devices : Union [List [int ], str , int ], num_nodes : int ) -> None :
303299 self ._num_nodes_flag = int (num_nodes ) if num_nodes is not None else 1
304300 self ._devices_flag = devices
305301
@@ -314,21 +310,14 @@ def _check_device_config_and_set_final_flags(
314310 f" using { accelerator_name } accelerator."
315311 )
316312
317- if self ._devices_flag == "auto" and self ._accelerator_flag is None :
318- raise ValueError (
319- f"You passed `devices={ devices } ` but haven't specified"
320- " `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
321- )
322-
323313 def _choose_auto_accelerator (self ) -> str :
324314 """Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
325- if self ._accelerator_flag == "auto" :
326- if TPUAccelerator .is_available ():
327- return "tpu"
328- if MPSAccelerator .is_available ():
329- return "mps"
330- if CUDAAccelerator .is_available ():
331- return "cuda"
315+ if TPUAccelerator .is_available ():
316+ return "tpu"
317+ if MPSAccelerator .is_available ():
318+ return "mps"
319+ if CUDAAccelerator .is_available ():
320+ return "cuda"
332321 return "cpu"
333322
334323 @staticmethod
@@ -337,7 +326,6 @@ def _choose_gpu_accelerator_backend() -> str:
337326 return "mps"
338327 if CUDAAccelerator .is_available ():
339328 return "cuda"
340-
341329 raise RuntimeError ("No supported gpu backend found!" )
342330
343331 def _set_parallel_devices_and_init_accelerator (self ) -> None :
@@ -368,7 +356,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
368356 self ._parallel_devices = accelerator_cls .get_parallel_devices (self ._devices_flag )
369357
370358 def _set_devices_flag_if_auto_passed (self ) -> None :
371- if self ._devices_flag == "auto" or self . _devices_flag is None :
359+ if self ._devices_flag == "auto" :
372360 self ._devices_flag = self .accelerator .auto_device_count ()
373361
374362 def _choose_and_init_cluster_environment (self ) -> ClusterEnvironment :
@@ -527,7 +515,7 @@ def _lazy_init_strategy(self) -> None:
527515 raise RuntimeError (
528516 f"`Fabric(strategy={ self ._strategy_flag !r} )` is not compatible with an interactive"
529517 " environment. Run your code as a script, or choose one of the compatible strategies:"
530- f" `Fabric(strategy=None| 'dp'|'ddp_notebook')`."
518+ f" `Fabric(strategy='dp'|'ddp_notebook')`."
531519 " In case you are spawning processes yourself, make sure to include the Fabric"
532520 " creation inside the worker function."
533521 )
0 commit comments