
Commit da69573

fix: checkpoint saving with distributed optimizer + overlap param gather (#949)

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>

1 parent: 38e9ef1

1 file changed, 4 insertions(+), 0 deletions(-)

nemo_rl/models/policy/megatron_policy_worker.py

@@ -1770,6 +1770,8 @@ def save_checkpoint(
         if not is_training:
             self.model.eval()
 
+        if self.should_disable_forward_pre_hook:
+            self.disable_forward_pre_hook()
         save_checkpoint(
             state=self.mcore_state,
             model=[self.model],
@@ -1784,6 +1786,8 @@ def save_checkpoint(
             blocking=True,
             terminate=True,
         )
+        if self.should_disable_forward_pre_hook:
+            self.enable_forward_pre_hook()
 
         if not is_training:  # Restore training state if it was changed
             self.model.train()

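The fix brackets the Megatron checkpoint save: when the worker's `should_disable_forward_pre_hook` flag is set (the case when the distributed optimizer's overlapped parameter gather is in use), the forward pre-hook is disabled before `save_checkpoint` runs and re-enabled afterwards, presumably so the parameters are in a consistent state when serialized. Below is a minimal sketch of the same guard expressed as a reusable context manager; `forward_pre_hook_disabled` is a hypothetical helper, not part of the commit, and it assumes `worker` exposes the same flag and methods shown in the diff.

from contextlib import contextmanager

@contextmanager
def forward_pre_hook_disabled(worker):
    # Hypothetical helper (not in the commit): temporarily disable the
    # overlapped param-gather forward pre-hook around a block of work.
    # Assumes `worker` has the same attribute and methods as in the diff.
    if worker.should_disable_forward_pre_hook:
        worker.disable_forward_pre_hook()
    try:
        yield
    finally:
        # Restore the hook even if the guarded block raises.
        if worker.should_disable_forward_pre_hook:
            worker.enable_forward_pre_hook()

With such a helper the save path could read `with forward_pre_hook_disabled(self): save_checkpoint(...)`, which guarantees the hook is restored on error paths as well; the commit applies the equivalent disable/enable pair inline around the call.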