
Commit 9279f49

fix mypy issues and install-pkg ci
1 parent daa8667

2 files changed: +25 -12 lines changed

src/lightning/fabric/utilities/registry.py

Lines changed: 1 addition & 3 deletions
@@ -36,9 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]:
         A list of all callbacks collected from external factories.
 
     """
-    factories = (
-        entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {})  # type: ignore[arg-type]
-    )
+    factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {})
 
     external_callbacks: list[Any] = []
     for factory in factories:
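
Note: this change works because importlib.metadata.entry_points grew a selectable API in Python 3.10: it accepts a group= keyword, replacing the dict-style .get(group, ...) lookup of older versions, which is why the branch (and previously the type: ignore) exists. A minimal sketch of the same version branch outside Lightning; the PYTHON_GE_3_10 flag and the "myapp.callbacks" group name are illustrative:

import sys
from importlib.metadata import entry_points

PYTHON_GE_3_10 = sys.version_info >= (3, 10)  # illustrative flag

def load_factories(group: str) -> list:
    """Collect and load all entry points registered under ``group``."""
    if PYTHON_GE_3_10:
        # Python >= 3.10: entry_points() supports selection by keyword
        factories = entry_points(group=group)
    else:
        # older Pythons return a dict mapping group name -> list of entry points
        factories = entry_points().get(group, [])
    return [ep.load() for ep in factories]

callbacks = [factory() for factory in load_factories("myapp.callbacks")]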

src/lightning/pytorch/strategies/fsdp2.py

Lines changed: 24 additions & 9 deletions
@@ -26,8 +26,6 @@
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
-from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
-from torch.distributed.checkpoint.stateful import Stateful
 from torch.nn import Module
 from torch.optim import Optimizer
 from typing_extensions import override
@@ -66,7 +64,15 @@
 
 if TYPE_CHECKING:
     from torch.distributed.device_mesh import DeviceMesh
-    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy
+    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
+
+try:
+    from torch.distributed.checkpoint.stateful import Stateful
+except ImportError:
+    # define a no-op base class for compatibility
+    class Stateful:
+        pass
+
 
 log = logging.getLogger(__name__)
 
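
Note: the try/except block is the standard fallback pattern for optional imports: on PyTorch builds that lack torch.distributed.checkpoint.stateful, the no-op Stateful stand-in keeps class AppState(Stateful) (further down in this file) importable. A minimal sketch of the same pattern with a hypothetical optional dependency:

try:
    from somelib.interfaces import Serializable  # hypothetical optional dependency
except ImportError:
    class Serializable:  # no-op stand-in so subclasses below still import
        pass

class Config(Serializable):
    """Usable whether or not somelib is installed."""

    def __init__(self, value: int) -> None:
        self.value = value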

@@ -113,7 +119,7 @@ class FSDP2Strategy(ParallelStrategy):
 
     def __init__(
         self,
-        device_mesh: Union[tuple[int], "DeviceMesh"] = None,
+        device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
         accelerator: Optional["pl.accelerators.Accelerator"] = None,
         parallel_devices: Optional[list[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
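
Note: the device_mesh fix addresses mypy's implicit-Optional rule (the strict default since roughly mypy 0.990): a parameter whose default is None must include None in its annotation. A minimal illustration with a hypothetical function:

from typing import Optional, Union

def connect(timeout: Union[int, float] = None):  # mypy error: default None not in the declared type
    ...

def connect_fixed(timeout: Optional[Union[int, float]] = None) -> None:  # accepted
    ...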
@@ -270,7 +276,7 @@ def _setup_model(self, model: Module) -> Module:
         model.to_empty(device=self.root_device)
 
         # Run your custom initialization
-        def init_weights(m):
+        def init_weights(m: Module) -> None:
             if isinstance(m, torch.nn.Linear):
                 torch.nn.init.kaiming_uniform_(m.weight)
                 if m.bias is not None:
@@ -480,6 +486,11 @@ def save_checkpoint(
             path.unlink()
         path.mkdir(parents=True, exist_ok=True)
 
+        if self.model is None:
+            raise RuntimeError(
+                "Cannot save checkpoint: FSDP2Strategy model is not initialized."
+                " Please ensure the strategy is set up before saving."
+            )
         state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)}
         _distributed_checkpoint_save(state_dict, path)
 
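
Note: besides failing fast with a clear message, the new guard also serves as type narrowing: self.model is presumably declared Optional[Module] on the strategy, and raising on None lets mypy treat it as Module in the AppState(...) call below. The idiom in isolation (names illustrative):

from typing import Optional

class Saver:
    def __init__(self) -> None:
        self.model: Optional[str] = None  # stand-in for an Optional attribute

    def save(self) -> str:
        if self.model is None:
            raise RuntimeError("model is not initialized")
        # mypy narrows self.model from Optional[str] to str past the raise
        return self.model.upper()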

@@ -502,7 +513,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
         return metadata
 
 
-def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "CPUOffloadPolicy":
+def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "OffloadPolicy":
     from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
 
     if cpu_offload is None or cpu_offload is False:
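
Note: widening the return annotation from CPUOffloadPolicy to its base class OffloadPolicy presumably fixes a mypy error on the disabled branch, which returns a plain OffloadPolicy() rather than the CPU subclass. The shape of the problem, sketched outside the Lightning source:

class OffloadPolicy:  # base policy: no offloading
    ...

class CPUOffloadPolicy(OffloadPolicy):  # subclass: offload parameters to CPU
    ...

def init_offload(enabled: bool) -> OffloadPolicy:  # base-class return type
    if not enabled:
        return OffloadPolicy()  # rejected by mypy under a CPUOffloadPolicy annotation
    return CPUOffloadPolicy()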
@@ -539,17 +550,21 @@ class AppState(Stateful):
 
     """
 
-    def __init__(self, model, optimizers):
+    def __init__(self, model: Module, optimizers: list[Optimizer]) -> None:
         self.model = model
         self.optimizers = optimizers
 
-    def state_dict(self):
+    def state_dict(self) -> dict[str, Any]:
+        from torch.distributed.checkpoint.state_dict import get_state_dict
+
         # this line automatically manages FSDP FQN's,
         # as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
         model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizers)
         return {"model": model_state_dict, "optim": optimizer_state_dict}
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        from torch.distributed.checkpoint.state_dict import set_state_dict
+
         # sets our state dicts on the model and optimizer, now that we've loaded
         set_state_dict(
             self.model, self.optimizers, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
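
Note: two things happen in this last hunk: the get_state_dict/set_state_dict imports move inside the methods (matching the guarded Stateful import at the top of the file), and the Stateful methods gain annotations. AppState follows the pattern from the PyTorch distributed-checkpoint tutorial, where the saver calls state_dict() and the loader calls load_state_dict() on each stateful value. A hedged usage sketch; the model, optimizer, and checkpoint path are illustrative:

import torch.distributed.checkpoint as dcp

# model: torch.nn.Module, optimizer: torch.optim.Optimizer (assumed to exist)

# Saving: dcp invokes AppState.state_dict() on the wrapped value.
dcp.save({"fsdp2_checkpoint_state_dict": AppState(model, [optimizer])}, checkpoint_id="/tmp/ckpt")

# Loading: dcp invokes AppState.load_state_dict() with the restored dicts.
dcp.load({"fsdp2_checkpoint_state_dict": AppState(model, [optimizer])}, checkpoint_id="/tmp/ckpt")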
