Commit cdb2d1f

bug: fsdp cannot load optimizer state using dcp (#3904)
1 parent a89f201 commit cdb2d1f

File tree

1 file changed: +7 −1 lines changed

src/accelerate/utils/fsdp_utils.py

Lines changed: 7 additions & 1 deletion
@@ -311,7 +311,13 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
             else input_dir
         )
         logger.info(f"Loading Optimizer from {ckpt_dir}")
-        optim_state = {"optimizer": optimizer.state_dict()}
+        if fsdp_plugin.fsdp_version == 2:
+            from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
+
+            optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options)
+        else:
+            optim_state = FSDP.optim_state_dict(model, optimizer)
+        optim_state = {"optimizer": optim_state}
         dist_cp.load(
             optim_state,
             checkpoint_id=ckpt_dir,
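
Context for the change: per the commit title, passing optimizer.state_dict() directly as the target of dist_cp.load apparently does not match the sharded layout DCP expects for FSDP2-wrapped models, so the optimizer state could not be loaded. The sketch below is not the accelerate implementation; it is a minimal illustration of the DCP loading pattern the new FSDP2 branch relies on. It assumes an already initialized process group, a model wrapped with FSDP2 (fully_shard), and a hypothetical ckpt_dir previously written by dcp.save; the closing set_optimizer_state_dict call stands in for whatever the surrounding function does after the hunk shown above.

# Minimal sketch of loading sharded optimizer state with torch.distributed.checkpoint (DCP).
# Assumptions (not from the diff): `model` is an FSDP2 (fully_shard) module, `optimizer`
# was built from its parameters, and `ckpt_dir` is a hypothetical checkpoint directory
# saved earlier with dcp.save on the same sharding layout.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

def load_sharded_optimizer_state(model, optimizer, ckpt_dir):
    # Request a sharded (non-full) state dict, matching what each rank saved.
    sd_options = StateDictOptions(full_state_dict=False, cpu_offload=False)

    # Build a reference optimizer state dict with the correct sharded layout;
    # dcp.load fills this container in place from the checkpoint files.
    optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer, options=sd_options)}
    dcp.load(optim_state, checkpoint_id=ckpt_dir)

    # Push the loaded values back into the optimizer.
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=optim_state["optimizer"],
        options=sd_options,
    )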
