Commit d82ca75

fix: fix checkpointing when val_period does not divide save_period (#1229)
Signed-off-by: ashors1 <[email protected]>
1 parent f7645f3

File tree

  nemo_rl/algorithms/dpo.py
  nemo_rl/algorithms/grpo.py
  nemo_rl/algorithms/sft.py
  nemo_rl/utils/checkpoint.py
  tests/unit/utils/test_checkpoint.py

5 files changed: +93 -24 lines changed

nemo_rl/algorithms/dpo.py

Lines changed: 1 addition & 2 deletions

@@ -633,9 +633,8 @@ def dpo_train(
         ):
             warnings.warn(
                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                "Saving most recent k checkpoints instead."
+                "This checkpoint will not be saved as top-k."
             )
-            master_config["checkpointing"]["metric_name"] = None

         with timer.time("checkpointing"):
             print(f"Saving checkpoint for step {total_steps + 1}...")

nemo_rl/algorithms/grpo.py

Lines changed: 1 addition & 2 deletions

@@ -882,9 +882,8 @@ def grpo_train(
         ):
             warnings.warn(
                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                "Saving most recent k checkpoints instead."
+                "This checkpoint will not be saved as top-k."
             )
-            master_config["checkpointing"]["metric_name"] = None

         with timer.time("checkpointing"):
             print(

nemo_rl/algorithms/sft.py

Lines changed: 1 addition & 2 deletions

@@ -506,9 +506,8 @@ def sft_train(
         ):
             warnings.warn(
                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                "Saving most recent k checkpoints instead."
+                "This checkpoint will not be saved as top-k."
             )
-            master_config["checkpointing"]["metric_name"] = None

         with timer.time("checkpointing"):
             print(f"Saving checkpoint for step {total_steps + 1}...")

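The three training-loop diffs above (dpo.py, grpo.py, sft.py) make the same change: when a checkpoint is saved at a step with no fresh validation metric, the code now only warns that this particular checkpoint will not be ranked as top-k, instead of permanently disabling metric-based retention by setting metric_name to None. Below is a minimal sketch of the scheduling mismatch named in the commit title; the period values are hypothetical and not taken from any config in this repository.

save_period = 10  # hypothetical: save a checkpoint every 10 steps
val_period = 3    # hypothetical: run validation every 3 steps

for step in range(1, 31):
    if step % save_period == 0:  # a checkpoint is written at this step
        has_fresh_val_metric = step % val_period == 0
        # Before this commit, the first save step lacking the metric switched
        # retention to "most recent k" for the rest of the run; after it, only
        # this checkpoint is ranked as if it had the worst possible value.
        print(f"step {step}: fresh validation metric at save time? {has_fresh_val_metric}")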
nemo_rl/utils/checkpoint.py

Lines changed: 11 additions & 18 deletions

@@ -202,25 +202,18 @@ def remove_old_checkpoints(self, exclude_latest: bool = True) -> None:
         if self.metric_name is None:
             checkpoint_history.sort(key=lambda x: x[0], reverse=True)
         else:
-            try:
-                # sort by metric value first, then by step number (for equal metrics, prefer more recent)
-                if self.higher_is_better:
-                    # For higher_is_better=True: higher metric values first, then higher step numbers
-                    checkpoint_history.sort(
-                        key=lambda x: (x[2][self.metric_name], x[0]), reverse=True
-                    )
-                else:
-                    # For higher_is_better=False: lower metric values first, then higher step numbers for equal values
-                    checkpoint_history.sort(
-                        key=lambda x: (x[2][self.metric_name], -x[0])
-                    )
-            except KeyError:
-                warnings.warn(
-                    f"Metric {self.metric_name} not found in checkpoint history. Keeping most recent k checkpoints."
+            # sort by metric value first, then by step number (for equal metrics, prefer more recent)
+            if self.higher_is_better:
+                # For higher_is_better=True: higher metric values first, then higher step numbers
+                checkpoint_history.sort(
+                    key=lambda x: (x[2].get(self.metric_name, -float("inf")), x[0]),
+                    reverse=True,
+                )
+            else:
+                # For higher_is_better=False: lower metric values first, then higher step numbers for equal values
+                checkpoint_history.sort(
+                    key=lambda x: (x[2].get(self.metric_name, float("inf")), -x[0])
                 )
-                checkpoint_history.sort(key=lambda x: x[0], reverse=True)
-
-                self.metric_name = None

         # remove checkpoints that are not in the top-k
         for checkpoint in checkpoint_history[self.keep_top_k :]:

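The checkpoint.py change replaces the try/except KeyError path, which fell back to recency ordering and cleared self.metric_name for good, with dict.get and an infinity sentinel, so a checkpoint whose metrics dict lacks the metric simply sorts last. A standalone sketch of that ranking rule follows; it assumes history entries shaped like (step, path, metrics), which the x[0] and x[2] lambdas imply but the diff does not spell out.

# Assumed entry shape: (step, checkpoint_path, metrics_dict).
history = [
    (1, "step_1", {"loss": 0.5}),
    (2, "step_2", {"loss": 0.3}),
    (3, "step_3", {}),  # saved at a step with no fresh validation metric
    (4, "step_4", {"loss": 0.2}),
]

metric_name = "loss"
higher_is_better = False

if higher_is_better:
    # Higher metric values first; a missing metric becomes -inf and ranks last.
    history.sort(key=lambda x: (x[2].get(metric_name, -float("inf")), x[0]), reverse=True)
else:
    # Lower metric values first; a missing metric becomes +inf and ranks last,
    # while -step breaks ties in favour of more recent checkpoints.
    history.sort(key=lambda x: (x[2].get(metric_name, float("inf")), -x[0]))

print([step for step, _, _ in history])  # [4, 2, 1, 3]: the metric-less step 3 ranks last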
tests/unit/utils/test_checkpoint.py

Lines changed: 79 additions & 0 deletions

@@ -141,6 +141,85 @@ def test_remove_old_checkpoints_topk_bias_recent_if_equal(
     assert sorted(remaining_steps) == sorted(expected_steps)


+def test_remove_old_checkpoints_topk_some_missing_val_metric(
+    checkpoint_manager, checkpoint_dir
+):
+    # Create checkpoints where some have validation metrics and others don't
+    steps = [1, 2, 3, 4, 10, 11, 12]
+    # Some checkpoints have loss metrics, others don't have any validation metrics
+    training_infos = [
+        {"loss": 0.5},  # step 1 - has loss
+        {"loss": 0.3},  # step 2 - has loss
+        {"other_metric": 0.8},  # step 3 - missing loss metric
+        {"loss": 0.2},  # step 4 - has loss
+        {},  # step 10 - missing loss metric
+        {"loss": 1.0},  # step 11 - has loss but not in top-k
+        {},  # step 12 - missing loss (latest)
+    ]
+
+    for step, training_info in zip(steps, training_infos):
+        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
+        checkpoint_manager.finalize_checkpoint(tmp_dir)
+
+    # Check if only top-k checkpoints are kept
+    remaining_dirs = list(checkpoint_dir.glob("step_*"))
+    assert (
+        len(remaining_dirs) == checkpoint_manager.keep_top_k + 1
+    )  # +1 because we exclude the latest
+
+    # Checkpoints with missing validation metrics should be treated as having the worst possible value
+    # Since higher_is_better=False, missing metrics get float("inf") which is worst
+    # So checkpoints with actual loss values should be preferred over those without
+    remaining_steps = []
+    for dir_path in remaining_dirs:
+        step_num = int(dir_path.name.split("_")[1])
+        remaining_steps.append(step_num)
+
+    # Should keep checkpoints with actual loss values (steps 1, 2, 4, 12)
+    # and exclude those without loss metrics (steps 3, 10)
+    # The latest checkpoint (step 12) is always kept
+    expected_steps = [1, 2, 4, 12]  # Steps with loss metrics, plus latest
+    assert sorted(remaining_steps) == sorted(expected_steps)
+
+
+def test_remove_old_checkpoints_topk_most_missing_val_metric(
+    checkpoint_manager, checkpoint_dir
+):
+    # Create checkpoints where some have validation metrics and others don't
+    steps = [1, 2, 3, 4, 10, 12]
+    # Some checkpoints have loss metrics, others don't have any validation metrics
+    training_infos = [
+        {"loss": 0.2},  # step 1 - has loss
+        {},  # step 2 - missing loss metric
+        {"other_metric": 0.8},  # step 3 - missing loss metric
+        {},  # step 4 - missing loss metric
+        {},  # step 10 - missing loss metric
+        {},  # step 12 - missing loss (latest)
+    ]
+
+    for step, training_info in zip(steps, training_infos):
+        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
+        checkpoint_manager.finalize_checkpoint(tmp_dir)
+
+    # Check if only top-k checkpoints are kept
+    remaining_dirs = list(checkpoint_dir.glob("step_*"))
+    assert len(remaining_dirs) == checkpoint_manager.keep_top_k
+
+    # Checkpoints with missing validation metrics should be treated as having the worst possible value
+    # Since higher_is_better=False, missing metrics get float("inf") which is worst
+    # So checkpoints with actual loss values should be preferred over those without
+    remaining_steps = []
+    for dir_path in remaining_dirs:
+        step_num = int(dir_path.name.split("_")[1])
+        remaining_steps.append(step_num)
+
+    # Should keep the checkpoint with an actual loss value (step 1)
+    # followed by the most recent steps
+    # The latest checkpoint (step 12) is always kept
+    expected_steps = [1, 10, 12]  # Step with a loss metric, plus the most recent steps
+    assert sorted(remaining_steps) == sorted(expected_steps)
+
+
 def test_get_best_checkpoint_path(checkpoint_manager, checkpoint_dir):
     # Create multiple checkpoints with different loss values
     steps = [1, 2, 3]

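To exercise just the two new tests, one possible pytest invocation is sketched below; it assumes pytest is installed and is run from the repository root.

# Hypothetical invocation; the -k expression matches both new
# *_missing_val_metric tests added in this commit.
import pytest

exit_code = pytest.main(
    ["tests/unit/utils/test_checkpoint.py", "-k", "missing_val_metric", "-q"]
)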
0 commit comments