Skip to content

Commit 35bf1a2

Browse files
committed
time to test fsdp2
1 parent d8a4d84 commit 35bf1a2

File tree

2 files changed

+99
-195
lines changed

2 files changed

+99
-195
lines changed

src/lightning/fabric/utilities/init.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,21 @@ def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurs
113113
if isinstance(obj, Module):
114114
return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)))
115115
raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
116+
117+
118+
def _has_all_dtensor_params_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool:
119+
from torch.distributed.tensor import DTensor
120+
121+
if isinstance(obj, Optimizer):
122+
return all(
123+
isinstance(t, DTensor)
124+
for param_group in obj.param_groups
125+
for t in param_group["params"]
126+
if isinstance(t, Parameter)
127+
)
128+
if isinstance(obj, Module):
129+
return all(
130+
isinstance(t, DTensor)
131+
for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse))
132+
)
133+
raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")

0 commit comments

Comments
 (0)