Skip to content

Commit 1551a16

Browse files
authored
Add device property to lazy load functionality (#20183)
1 parent 828fd99 commit 1551a16

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/lightning/fabric/utilities/load.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ def __torch_function__(
140140
loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
141141
return func(*loaded_args, **kwargs)
142142

143+
@property
144+
def device(self) -> torch.device:
145+
return torch.device(self.storageinfo[3])
146+
143147
def __getattr__(self, name: str) -> Any:
144148
# These properties don't require materialization and can be accessed through the meta tensor directly
145149
if name in {
@@ -160,7 +164,7 @@ def __getattr__(self, name: str) -> Any:
160164
return getattr(self.metatensor, name)
161165

162166
# materializing these is needed for quantization (see lit-gpt)
163-
if name in {"contiguous", "cuda", "half", "data"}:
167+
if name in {"contiguous", "cuda", "half", "data", "to"}:
164168
return getattr(self._load_tensor(), name)
165169

166170
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

tests/tests_fabric/utilities/test_load.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def test_lazy_load_module(tmp_path):
3131
model1.load_state_dict(checkpoint)
3232

3333
assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
34+
assert checkpoint["weight"].device == torch.device("cpu")
35+
assert type(checkpoint["weight"].to("cpu")) is torch.Tensor
3436
assert type(model0.weight.data) is torch.Tensor
3537
assert torch.equal(model0.weight, model1.weight)
3638
assert torch.equal(model0.bias, model1.bias)

0 commit comments

Comments
 (0)