Skip to content

Commit caa2dd9

Browse files
committed
update
1 parent 9279f49 commit caa2dd9

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/lightning/fabric/utilities/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]:
3636
A list of all callbacks collected from external factories.
3737
3838
"""
39-
factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {})
39+
factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type]
4040

4141
external_callbacks: list[Any] = []
4242
for factory in factories:

src/lightning/pytorch/strategies/fsdp2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@
6767
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
6868

6969
try:
70-
from torch.distributed.checkpoint.stateful import Stateful
70+
from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful
7171
except ImportError:
72-
# define a no-op base class for compatibility
73-
class Stateful:
72+
73+
class _TorchStateful: # type: ignore[no-redef]
7474
pass
7575

7676

@@ -541,7 +541,7 @@ def _init_fsdp2_mp_policy(mp_policy: Optional["MixedPrecisionPolicy"]) -> Option
541541

542542

543543
# Code taken from: https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#saving
544-
class AppState(Stateful):
544+
class AppState(_TorchStateful):
545545
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the
546546
Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
547547

0 commit comments

Comments
 (0)