We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 38e9ef1 commit da69573Copy full SHA for da69573
nemo_rl/models/policy/megatron_policy_worker.py
@@ -1770,6 +1770,8 @@ def save_checkpoint(
1770
if not is_training:
1771
self.model.eval()
1772
1773
+ if self.should_disable_forward_pre_hook:
1774
+ self.disable_forward_pre_hook()
1775
save_checkpoint(
1776
state=self.mcore_state,
1777
model=[self.model],
@@ -1784,6 +1786,8 @@ def save_checkpoint(
1784
1786
blocking=True,
1785
1787
terminate=True,
1788
)
1789
1790
+ self.enable_forward_pre_hook()
1791
1792
if not is_training: # Restore training state if it was changed
1793
self.model.train()
0 commit comments