@@ -204,7 +204,7 @@ def setup(
         module = _FabricModule(module, self._precision, original_module=original_module)

         # Update the _DeviceDtypeModuleMixin's device parameter
-        module.to(self.device if move_to_device else next(module.parameters()).device)
+        module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

         optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
@@ -248,7 +248,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri

         if not isinstance(self._strategy, FSDPStrategy):
             # Update the _DeviceDtypeModuleMixin's device parameter
-            module.to(self.device if move_to_device else next(module.parameters()).device)
+            module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
             original_module._fabric = self  # type: ignore[assignment]
@@ -741,7 +741,7 @@ def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) ->
             return run_function(*args, **kwargs)

     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
-        initial_device = next(model.parameters()).device
+        initial_device = next(model.parameters(), torch.tensor(0)).device
         if any(param.device != initial_device for param in model.parameters()):
             rank_zero_warn(
                 "The model passed to `Fabric.setup()` has parameters on different devices. Since `move_to_device=True`,"
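
For context on the change: the Python built-in `next(iterator, default)` returns the default instead of raising `StopIteration` when the iterator is empty, so passing `torch.tensor(0)` (a CPU tensor) as the fallback lets the `.device` lookup succeed for modules that have no parameters. A minimal sketch of that behavior, not part of Fabric itself; the `ParamlessHead` module name is a hypothetical example:

import torch
from torch import nn

class ParamlessHead(nn.Module):  # hypothetical module with no parameters
    def forward(self, x):
        return torch.relu(x)

# Before this commit: next() raises StopIteration on a parameterless module
# next(ParamlessHead().parameters()).device  # StopIteration

# After this commit: the fallback tensor lives on the CPU, so .device resolves to "cpu"
print(next(ParamlessHead().parameters(), torch.tensor(0)).device)  # cpu

# A module that does have parameters is unaffected; its first parameter's device wins
print(next(nn.Linear(2, 2).parameters(), torch.tensor(0)).device)  # device of the Linear weights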