import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
-from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
-from torch.distributed.checkpoint.stateful import Stateful
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override

if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh
-    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy
+    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
+
+try:
+    from torch.distributed.checkpoint.stateful import Stateful
+except ImportError:
+    # define a no-op base class for compatibility
+    class Stateful:
+        pass
+

log = logging.getLogger(__name__)
@@ -113,7 +119,7 @@ class FSDP2Strategy(ParallelStrategy):
    def __init__(
        self,
-        device_mesh: Union[tuple[int], "DeviceMesh"] = None,
+        device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
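A hedged usage sketch of the widened signature: device_mesh can now be omitted, given as a shape tuple, or passed as a prebuilt DeviceMesh. The FSDP2Strategy import path and the mesh layout below are illustrative assumptions, not taken from this diff, and the prebuilt-mesh variant assumes an initialized process group (e.g. under torchrun).

from torch.distributed.device_mesh import init_device_mesh
from lightning.pytorch.strategies import FSDP2Strategy  # import path assumed

# Shape tuple: let the strategy construct the mesh itself.
strategy = FSDP2Strategy(device_mesh=(8,))

# Prebuilt DeviceMesh; the dimension names are illustrative.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
strategy = FSDP2Strategy(device_mesh=mesh)

# Or rely on the default, now that the annotation is Optional.
strategy = FSDP2Strategy()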
@@ -270,7 +276,7 @@ def _setup_model(self, model: Module) -> Module:
        model.to_empty(device=self.root_device)

        # Run your custom initialization
-        def init_weights(m):
+        def init_weights(m: Module) -> None:
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
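The docstring example above shows only the re-initialization callback; below is a hedged sketch of the surrounding FSDP2 flow it presumably assumes (meta-device construction, sharding, materialization, then init). MyModel is a hypothetical user module, and the bias branch is an assumption since the hunk is cut off before it.

import torch
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point in recent PyTorch

with torch.device("meta"):
    model = MyModel()  # hypothetical user module, built without allocating real storage

fully_shard(model)  # shard parameters; per-submodule fully_shard calls are also possible
model.to_empty(device="cuda")  # allocate real, uninitialized storage on the device

def init_weights(m: torch.nn.Module) -> None:
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)  # assumed: the diff does not show the bias branch

model.apply(init_weights)  # run the custom initialization over every submodule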
@@ -480,6 +486,11 @@ def save_checkpoint(
            path.unlink()
        path.mkdir(parents=True, exist_ok=True)

+        if self.model is None:
+            raise RuntimeError(
+                "Cannot save checkpoint: FSDP2Strategy model is not initialized."
+                " Please ensure the strategy is set up before saving."
+            )
        state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)}
        _distributed_checkpoint_save(state_dict, path)
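For context, a hedged sketch of the equivalent plain-PyTorch flow: AppState (defined further down in this diff) follows the torch.distributed.checkpoint Stateful pattern, so the same structure could in principle be written and read with dcp.save / dcp.load directly. Variable names and the checkpoint path are illustrative.

import torch.distributed.checkpoint as dcp

# model is a fully_shard-ed module and optimizer its optimizer (illustrative names).
app_state = AppState(model, [optimizer])

# Saving: dcp calls app_state.state_dict() because the object is Stateful.
dcp.save({"fsdp2_checkpoint_state_dict": app_state}, checkpoint_id="checkpoints/step_100")

# Loading: dcp calls app_state.load_state_dict(...) in place, restoring model and optimizer.
dcp.load({"fsdp2_checkpoint_state_dict": AppState(model, [optimizer])}, checkpoint_id="checkpoints/step_100")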
@@ -502,7 +513,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
        return metadata


-def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "CPUOffloadPolicy":
+def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "OffloadPolicy":
    from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy

    if cpu_offload is None or cpu_offload is False:
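The hunk ends after the first condition of the helper; for readers who only see this excerpt, below is a hedged reconstruction of the full mapping it presumably performs. The True and pass-through branches are assumptions, not shown in the diff.

from typing import Optional, Union
from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy

def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, CPUOffloadPolicy]]) -> OffloadPolicy:
    # None/False: base OffloadPolicy, i.e. keep parameters on the accelerator.
    if cpu_offload is None or cpu_offload is False:
        return OffloadPolicy()
    # True: default CPU offloading (assumed branch).
    if cpu_offload is True:
        return CPUOffloadPolicy()
    # Otherwise assume a user-constructed CPUOffloadPolicy was passed through (assumed branch).
    return cpu_offload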
@@ -539,17 +550,21 @@ class AppState(Stateful):

    """

-    def __init__(self, model, optimizers):
+    def __init__(self, model: Module, optimizers: list[Optimizer]) -> None:
        self.model = model
        self.optimizers = optimizers

-    def state_dict(self):
+    def state_dict(self) -> dict[str, Any]:
+        from torch.distributed.checkpoint.state_dict import get_state_dict
+
        # this line automatically manages FSDP FQN's,
        # as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizers)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        from torch.distributed.checkpoint.state_dict import set_state_dict
+
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model, self.optimizers, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]