Skip to content

Commit daa8667

Browse files
committed
update
1 parent cc6de82 commit daa8667

File tree

1 file changed

+2
-1
lines changed
  • src/lightning/fabric/strategies

1 file changed

+2
-1
lines changed

src/lightning/fabric/strategies/fsdp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from lightning_utilities.core.imports import RequirementCache
3232
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
3333
from torch import Tensor
34-
from torch.distributed.tensor import DTensor
3534
from torch.nn import Module
3635
from torch.optim import Optimizer
3736
from typing_extensions import TypeGuard, override
@@ -797,6 +796,8 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
797796

798797

799798
def _optimizer_has_dtensor_params(optimizer: Optimizer) -> bool:
799+
from torch.distributed.tensor import DTensor
800+
800801
return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"])
801802

802803

0 commit comments

Comments
 (0)