
Commit 234ded8

Avoid moving the model to device if move_to_device=False (#19152)
1 parent 5a03612 commit 234ded8

6 files changed: 52 additions & 32 deletions
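
To illustrate the fix, here is a minimal usage sketch (not part of the commit, and assuming a machine with at least one CUDA device): with `move_to_device=False`, Fabric now records the device of the module's first parameter instead of calling `.to()` on the model, so manually placed parameters stay where they are.

import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1)

# Parameters are placed manually and should stay put
model = nn.Sequential(
    nn.Linear(1, 2, device=torch.device("cpu")),
    nn.Linear(2, 1, device=torch.device("cuda", 0)),
)
fabric_model = fabric.setup_module(model, move_to_device=False)

print(fabric_model.device)     # cpu -- taken from the first parameter, not from fabric.device
print(model[0].weight.device)  # cpu -- the weights were not moved
print(model[1].weight.device)  # cuda:0

With the default `move_to_device=True`, Fabric still moves the whole model to `fabric.device` as before.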

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))


+- Avoid moving the model to device if `move_to_device=False` is passed ([#19152](https://github.com/Lightning-AI/lightning/pull/19152))
+
+
 - Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))


src/lightning/fabric/fabric.py

Lines changed: 11 additions & 6 deletions
@@ -64,6 +64,7 @@
     _update_dataloader,
     has_iterable_dataset,
 )
+from lightning.fabric.utilities.device_dtype_mixin import _update_properties
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
@@ -243,9 +244,11 @@ def setup(

         module = _FabricModule(module, self._precision, original_module=original_module)

-        if not isinstance(self._strategy, (FSDPStrategy, XLAFSDPStrategy)):
-            # Update the _DeviceDtypeModuleMixin's device parameter
-            module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
+        # Update the _DeviceDtypeModuleMixin's device parameter
+        # NOTE: for sharded strategies or manual device placement, there's no single root device
+        _update_properties(
+            module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
+        )

         optimizers = [
             _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
@@ -295,9 +298,11 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
         module = self._strategy.setup_module(module)
         module = _FabricModule(module, self._precision, original_module=original_module)

-        if not isinstance(self._strategy, (FSDPStrategy, XLAFSDPStrategy)):
-            # Update the _DeviceDtypeModuleMixin's device parameter
-            module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
+        # Update the _DeviceDtypeModuleMixin's device parameter
+        # NOTE: for sharded strategies or manual device placement, there's no single root device
+        _update_properties(
+            module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
+        )

         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
             original_module._fabric = self  # type: ignore[assignment]
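
A standalone sketch (not part of the commit) of the fallback expression used in both `setup()` and `setup_module()` above: when `move_to_device=False`, the recorded device comes from the module's first parameter, and the throwaway `torch.tensor(0)` (a CPU scalar) provides a default for modules that have no parameters.

import torch
from torch import nn

linear = nn.Linear(1, 2)  # freshly created parameters live on the CPU
relu = nn.ReLU()          # has no parameters at all

# Same pattern as `next(module.parameters(), torch.tensor(0)).device` in the diff above
print(next(linear.parameters(), torch.tensor(0)).device)  # cpu -- device of the first parameter
print(next(relu.parameters(), torch.tensor(0)).device)    # cpu -- falls back to the default CPU scalar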

src/lightning/fabric/utilities/device_dtype_mixin.py

Lines changed: 22 additions & 19 deletions
@@ -50,7 +50,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
         """See :meth:`torch.nn.Module.to`."""
         # this converts `str` device to `torch.device`
         device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
-        self.__update_properties(device=device, dtype=dtype)
+        _update_properties(self, device=device, dtype=dtype)
         return super().to(*args, **kwargs)

     def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
@@ -70,43 +70,46 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
             device = torch.device("cuda", torch.cuda.current_device())
         elif isinstance(device, int):
             device = torch.device("cuda", index=device)
-        self.__update_properties(device=device)
+        _update_properties(self, device=device)
         return super().cuda(device=device)

     def cpu(self) -> Self:
         """See :meth:`torch.nn.Module.cpu`."""
-        self.__update_properties(device=torch.device("cpu"))
+        _update_properties(self, device=torch.device("cpu"))
         return super().cpu()

     def type(self, dst_type: Union[str, torch.dtype]) -> Self:
         """See :meth:`torch.nn.Module.type`."""
-        self.__update_properties(dtype=dst_type)
+        _update_properties(self, dtype=dst_type)
         return super().type(dst_type=dst_type)

     def float(self) -> Self:
         """See :meth:`torch.nn.Module.float`."""
-        self.__update_properties(dtype=torch.float)
+        _update_properties(self, dtype=torch.float)
         return super().float()

     def double(self) -> Self:
         """See :meth:`torch.nn.Module.double`."""
-        self.__update_properties(dtype=torch.double)
+        _update_properties(self, dtype=torch.double)
         return super().double()

     def half(self) -> Self:
         """See :meth:`torch.nn.Module.half`."""
-        self.__update_properties(dtype=torch.half)
+        _update_properties(self, dtype=torch.half)
         return super().half()

-    def __update_properties(
-        self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
-    ) -> None:
-        def apply_fn(module: Union[_DeviceDtypeModuleMixin, Module]) -> None:
-            if not isinstance(module, _DeviceDtypeModuleMixin):
-                return
-            if device is not None:
-                module._device = device
-            if dtype is not None:
-                module._dtype = dtype
-
-        self.apply(apply_fn)
+
+def _update_properties(
+    root: torch.nn.Module, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
+) -> None:
+    def apply_fn(module: Union[_DeviceDtypeModuleMixin, Module]) -> None:
+        if not isinstance(module, _DeviceDtypeModuleMixin):
+            return
+        # cannot use `module.to()` because we don't actually want to move the model in case there are multiple
+        # devices types (such as partial meta parameters)
+        if device is not None:
+            module._device = device
+        if dtype is not None:
+            module._dtype = dtype
+
+    root.apply(apply_fn)
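
A minimal sketch (not part of the commit; `TinyModel` is a hypothetical example) of how the now module-level `_update_properties` helper behaves: it walks the module tree with `Module.apply` and rewrites only the `_device`/`_dtype` bookkeeping on `_DeviceDtypeModuleMixin` instances, so no tensors are moved.

import torch
from torch import nn
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin, _update_properties


class TinyModel(_DeviceDtypeModuleMixin):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(2, 2)  # parameters are created on the CPU and never moved


model = TinyModel()
_update_properties(model, device=torch.device("cuda", 0))  # no CUDA required: nothing is allocated there

print(model.device)               # cuda:0 -- only the bookkeeping on the mixin changed
print(model.layer.weight.device)  # cpu    -- `.to()` was never called, so the weights stayed put

Turning the former `__update_properties` method into a free function that takes `root` explicitly is what lets `Fabric.setup()` and `Fabric.setup_module()` update the recorded device without triggering an actual `module.to()` call.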

tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 3 additions & 2 deletions
@@ -304,8 +304,9 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
     else:
         assert isinstance(next(fabric_model.parameters()), FlatParameter)

-    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
-    assert fabric_model.device == torch.device("cpu")
+    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
+    # different devices
+    assert fabric_model.device == torch.device("cuda", fabric.local_rank)
     assert fabric.device == torch.device("cuda", fabric.local_rank)


tests/tests_fabric/strategies/test_xla_fsdp_integration.py

Lines changed: 3 additions & 2 deletions
@@ -190,8 +190,9 @@ def _test_setup_module_move_to_device(fabric, move_to_device):
     fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
     fabric_module_mock.assert_not_called()

-    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
-    assert fabric_model.device == torch.device("cpu")
+    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
+    # different devices
+    assert fabric_model.device.type == "xla"
     assert fabric.device.type == "xla"


tests/tests_fabric/test_fabric.py

Lines changed: 10 additions & 3 deletions
@@ -157,8 +157,8 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi

     fabric = Fabric(accelerator="cuda", devices=1)

-    module0 = nn.Linear(1, 2).to(device0)
-    module1 = nn.Linear(1, 2).to(device1)
+    module0 = nn.Linear(1, 2, device=device0)
+    module1 = nn.Linear(1, 2, device=device1)
     model = nn.Sequential(module0, module1)

     setup_method = getattr(fabric, setup_method)
@@ -174,7 +174,14 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi
         assert module1.weight.device == module1.bias.device == device1
     else:
         with no_warning_call(expected_warning=PossibleUserWarning, match=match):
-            setup_method(model, move_to_device=move_to_device)
+            fabric_model = setup_method(model, move_to_device=move_to_device)
+
+        # the first device is set at the root
+        assert fabric_model.device == device0
+        assert fabric_model._device == device0
+        # the weights were not moved
+        assert module0.weight.device == module0.bias.device == device0
+        assert module1.weight.device == module1.bias.device == device1


 def test_setup_module_and_optimizers():
