@@ -304,16 +304,82 @@ def get_metrics(self, trainer, model):
304304 return items
305305
306306
307- def get_strategy (strategy ):
308- if strategy ['name' ] == 'auto' :
309- return 'auto'
310-
311- from lightning .pytorch .strategies import StrategyRegistry
312- if strategy ['name' ] not in StrategyRegistry :
313- available_names = ", " .join (sorted (StrategyRegistry .keys ())) or "none"
314- raise ValueError (f"Invalid strategy name { strategy ['name' ]} . Available names: { available_names } " )
315-
316- data = StrategyRegistry [strategy ['name' ]]
317- params = data ['init_params' ]
318- params .update ({k : v for k , v in strategy .items () if k != 'name' })
319- return data ['strategy' ](** utils .filter_kwargs (params , data ['strategy' ]))
307+ def get_strategy (
308+ devices = "auto" ,
309+ num_nodes = 1 ,
310+ accelerator = "auto" ,
311+ strategy = {"name" : "auto" },
312+ precision = None ,
313+ ):
314+ from lightning .fabric .utilities .device_parser import _determine_root_gpu_device
315+ from lightning .pytorch .accelerators import AcceleratorRegistry
316+ from lightning .pytorch .accelerators .cuda import CUDAAccelerator
317+ from lightning .pytorch .accelerators .mps import MPSAccelerator
318+ from lightning .pytorch .strategies import Strategy , SingleDeviceStrategy , StrategyRegistry
319+ from lightning .pytorch .trainer .connectors import accelerator_connector
320+ from lightning .pytorch .utilities .rank_zero import rank_zero_warn
321+ class _DsAcceleratorConnector (accelerator_connector ._AcceleratorConnector ):
322+ def __init__ (self ) -> None :
323+ accelerator_connector ._register_external_accelerators_and_strategies ()
324+ self ._registered_strategies = StrategyRegistry .available_strategies ()
325+ self ._accelerator_types = AcceleratorRegistry .available_accelerators ()
326+ self ._parallel_devices = []
327+ self ._check_config_and_set_final_flags (
328+ strategy = strategy ["name" ],
329+ accelerator = accelerator ,
330+ precision = precision ,
331+ plugins = [],
332+ sync_batchnorm = False ,
333+ )
334+ if self ._accelerator_flag == "auto" :
335+ self ._accelerator_flag = self ._choose_auto_accelerator ()
336+ elif self ._accelerator_flag == "gpu" :
337+ self ._accelerator_flag = self ._choose_gpu_accelerator_backend ()
338+ self ._check_device_config_and_set_final_flags (devices = devices , num_nodes = num_nodes )
339+ self ._set_parallel_devices_and_init_accelerator ()
340+ if self ._strategy_flag == "auto" :
341+ self ._strategy_flag = self ._choose_strategy ()
342+ self ._check_strategy_and_fallback ()
343+ self ._init_strategy ()
344+ for k in ["colossalai" , "bagua" , "hpu" , "hpu_parallel" , "hpu_single" , "ipu" , "ipu_strategy" ]:
345+ if k in StrategyRegistry :
346+ StrategyRegistry .remove (k )
347+
348+ def _init_strategy (self ) -> None :
349+ assert isinstance (self ._strategy_flag , (str , Strategy ))
350+ if isinstance (self ._strategy_flag , str ):
351+ if self ._strategy_flag not in StrategyRegistry :
352+ available_names = ", " .join (sorted (StrategyRegistry .available_strategies ())) or "none"
353+ raise KeyError (f"Invalid strategy name { strategy ['name' ]} . Available names: { available_names } " )
354+ data = StrategyRegistry [self ._strategy_flag ]
355+ params = {}
356+ # Replicate additional logic for _choose_strategy when dealing with single device strategies
357+ if issubclass (data ["strategy" ], SingleDeviceStrategy ):
358+ if self ._accelerator_flag == "hpu" :
359+ params = {"device" : torch .device ("hpu" )}
360+ elif self ._accelerator_flag == "tpu" :
361+ params = {"device" : self ._parallel_devices [0 ]}
362+ elif data ["strategy" ] is SingleDeviceStrategy :
363+ if isinstance (self ._accelerator_flag , (CUDAAccelerator , MPSAccelerator )) or (
364+ isinstance (self ._accelerator_flag , str ) and self ._accelerator_flag in ("cuda" , "gpu" , "mps" )
365+ ):
366+ params = {"device" : _determine_root_gpu_device (self ._parallel_devices )}
367+ else :
368+ params = {"device" : "cpu" }
369+ else :
370+ raise NotImplementedError
371+ params .update (data ["init_params" ])
372+ params .update ({k : v for k , v in strategy .items () if k != "name" })
373+ self .strategy = data ["strategy" ](** utils .filter_kwargs (params , data ["strategy" ]))
374+ elif isinstance (self ._strategy_flag , SingleDeviceStrategy ):
375+ params = {"device" : self ._strategy_flag .root_device }
376+ params .update ({k : v for k , v in strategy .items () if k != "name" })
377+ self .strategy = self ._strategy_flag .__class__ (** utils .filter_kwargs (params , self ._strategy_flag .__class__ ))
378+ else :
379+ rank_zero_warn (
380+ f"Inferred strategy { self ._strategy_flag .__class__ .__name__ } cannot take custom configurations."
381+ f"To use custom configurations, please specify the strategy name explicitly."
382+ )
383+ self .strategy = self ._strategy_flag
384+
385+ return _DsAcceleratorConnector ().strategy
0 commit comments