Commit 6e344b8

awaelchli authored and lantiga committed
Handle edge case in Fabric.setup() when model has no parameters (#17441)
(cherry picked from commit 0631fa0)
1 parent 8c9cf00 commit 6e344b8
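
The edge case is easy to reproduce outside of Fabric. Below is a minimal sketch in plain PyTorch (no Fabric involved) of why a parameterless module broke the previous device lookup:

```python
import torch
from torch import nn

model = nn.Sequential()  # a module with no parameters

# The pre-fix code path effectively did `next(model.parameters()).device`,
# which raises StopIteration because the parameter iterator is empty.
try:
    device = next(model.parameters()).device
except StopIteration:
    device = torch.device("cpu")  # nothing to inspect, fall back to CPU

print(device)  # cpu
```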

3 files changed: 11 additions, 3 deletions

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `LightningModule.*_step` methods bypassing the DDP/FSDP wrapper ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))


+- Fixed device handling in `Fabric.setup()` when the model has no parameters ([#17441](https://github.com/Lightning-AI/lightning/pull/17441))
+
+
 ## [2.0.1] - 2023-03-30

 ### Changed

src/lightning/fabric/fabric.py

Lines changed: 3 additions & 3 deletions

@@ -204,7 +204,7 @@ def setup(
         module = _FabricModule(module, self._precision, original_module=original_module)

         # Update the _DeviceDtypeModuleMixin's device parameter
-        module.to(self.device if move_to_device else next(module.parameters()).device)
+        module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

         optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]

@@ -248,7 +248,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri

         if not isinstance(self._strategy, FSDPStrategy):
             # Update the _DeviceDtypeModuleMixin's device parameter
-            module.to(self.device if move_to_device else next(module.parameters()).device)
+            module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
             original_module._fabric = self  # type: ignore[assignment]

@@ -741,7 +741,7 @@ def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) ->
         return run_function(*args, **kwargs)

     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
-        initial_device = next(model.parameters()).device
+        initial_device = next(model.parameters(), torch.tensor(0)).device
         if any(param.device != initial_device for param in model.parameters()):
             rank_zero_warn(
                 "The model passed to `Fabric.setup()` has parameters on different devices. Since `move_to_device=True`,"

tests/tests_fabric/test_fabric.py

Lines changed: 5 additions & 0 deletions

@@ -137,6 +137,11 @@ def test_setup_module_move_to_device(setup_method, move_to_device, accelerator,
     assert fabric_model.device == expected_device
     assert fabric.device == target_device

+    # edge case: model has no parameters
+    model = nn.Sequential()
+    fabric_model = setup_method(model, move_to_device=move_to_device)
+    assert fabric_model.device == target_device if move_to_device else torch.device("cpu")
+

 @RunIf(min_cuda_gpus=1)
 @pytest.mark.parametrize("move_to_device", [True, False])
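
Assuming the public Fabric API of the same release, the fixed behaviour can also be exercised end to end roughly like this (a usage sketch, not part of the commit):

```python
import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

model = nn.Sequential()             # no parameters at all
fabric_model = fabric.setup(model)  # previously raised StopIteration

assert fabric_model.device == torch.device("cpu")
```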
