Skip to content

Commit b164f58

Browse files
polinabinder1 authored and camirr-nv committed
fixing the ESM2 checkpointing issue (#842)
### Description

This addresses: #757

### Type of changes

In the original code the optimizer was not saved in the checkpoint, but it is expected in the megatron strategy. Saving the optimizer is added to the checkpoint callback and the test has been updated. Changes were made to have the checkpointing callback in the nemo logger. This is the same as the training path in sub-packages/bionemo-llm/src/bionemo/llm/train.py.

---------

Signed-off-by: Polina Binder <pbinder@nvidia.com>
Signed-off-by: polinabinder1 <pbinder@nvidia.com>
Signed-off-by: Ubuntu <camirr@nvidia.com>
1 parent cab1f65 commit b164f58

File tree

2 files changed

+23
-53
lines changed

2 files changed

+23
-53
lines changed

sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,11 @@ def main(
338338
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
339339
filename="{epoch}-{step}-{consumed_samples}",
340340
# Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
341+
# Save both the weights and the optimizer state.
342+
save_weights_only=False,
343+
save_optim_on_train_end=True,
341344
)
345+
342346
callbacks.append(checkpoint_callback)
343347

344348
auto_resume = resume.AutoResume(

sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py

Lines changed: 19 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def get_command_line_args(arg_name, arg_value) -> str:
184184
arg_str = f"--{arg_name.replace('_', '-')}={arg_value}"
185185
return arg_str
186186

187-
cmd = "train_esm2 " + " ".join(get_command_line_args(arg_name, arg_value) for arg_name, arg_value in args.items())
187+
cmd = f"train_esm2 {' '.join(get_command_line_args(arg_name, arg_value) for arg_name, arg_value in args.items())}"
188188
return cmd
189189

190190

@@ -332,10 +332,6 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu
332332

333333

334334
@pytest.mark.slow
335-
@pytest.mark.xfail(
336-
reason="ESM2 training fails to resume from checkpoints. "
337-
"Issue: https://github.com/NVIDIA/bionemo-framework/issues/757"
338-
)
339335
def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inputs):
340336
max_steps_first_run = 4
341337
max_steps_second_run = 6
@@ -361,11 +357,9 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
361357
)
362358

363359
# The first training command to finish at max_steps_first_run
364-
stdout_first_run, stderr_first_run, returncode_first_run = run_command_in_subprocess(
365-
command=command_first_run, path=str(tmp_path)
366-
)
367-
368-
assert returncode_first_run == 0, "Command failed."
360+
# stdout_first_run, stderr_first_run, returncode_first_run
361+
#
362+
stdout_first_run = run_command_in_subprocess(command=command_first_run, path=str(tmp_path))
369363

370364
assert f"Training epoch 0, iteration 0/{max_steps_first_run - 1}" in stdout_first_run
371365
# Extract and validate global steps
@@ -403,53 +397,25 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
403397
wandb_project=None,
404398
experiment_name=experiment_name,
405399
)
406-
stdout_second_run, stderr_second_run, returncode_second_run = run_command_in_subprocess(
407-
command=command_second_run, path=str(tmp_path)
408-
)
400+
stdout_second_run = run_command_in_subprocess(command=command_second_run, path=str(tmp_path))
409401

410-
# Verify that the command failed with a non-zero exit code
411-
# This assertion will fail if the resume functionality gets fixed, prompting us to update the test
412-
assert returncode_second_run != 0, (
413-
"Resuming training passed. "
414-
"The resume functionality works and this test needs to be updated. "
415-
"Update issue https://github.com/NVIDIA/bionemo-framework/issues/757"
416-
)
402+
# Verify the model can continue training from the checkpoint without errors
417403

418-
# Verify the error message contains the expected exception type and error message
419-
# The test fails if we get a different error than expected, which would require investigation
420-
assert (
421-
"megatron.core.dist_checkpointing.core.CheckpointingException" in stderr_second_run
422-
and "Cannot find global shape metadata for N-D flattened tensor ShardedTensor" in stderr_second_run
423-
), f"ESM2 training resuming failed due to an unexpected error.\nActual stderr: {stderr_second_run}..."
424-
425-
# Output the error for logging purposes
426-
pytest.fail(
427-
"Detected expected failure with megatron.core.dist_checkpointing.core.CheckpointingException as anticipated. "
428-
"This is a known issue tracked in: https://github.com/NVIDIA/bionemo-framework/issues/757"
429-
)
404+
global_steps_second_run = extract_global_steps_from_log(stdout_second_run)
430405

431-
### TODO: The following section should be enabled when issue #757 is resolved ###
432-
# Once the issue is fixed, we'll need to:
433-
# 1. Remove the assertion expecting a non-zero return code
434-
# 2. Replace with below assertions that verify successful resumption
435-
# 3. Check for specific markers in the output that indicate successful state restoration
436-
# 4. Verify the model can continue training from the checkpoint without errors
406+
assert global_steps_second_run[0] == max_steps_first_run
407+
assert global_steps_second_run[-1] == max_steps_second_run - 1
408+
assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run
437409

438-
# global_steps_second_run = extract_global_steps_from_log(stdout_second_run)
439-
#
440-
# assert global_steps_second_run[0] == max_steps_first_run
441-
# assert global_steps_second_run[-1] == max_steps_second_run - 1
442-
# assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run
443-
#
444-
# expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
445-
# matching_subfolders = [
446-
# p
447-
# for p in checkpoints_dir.iterdir()
448-
# if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
449-
# ]
450-
# assert matching_subfolders, (
451-
# f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
452-
# )
410+
expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
411+
matching_subfolders = [
412+
p
413+
for p in checkpoints_dir.iterdir()
414+
if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
415+
]
416+
assert matching_subfolders, (
417+
f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
418+
)
453419

454420

455421
@pytest.mark.parametrize("limit_val_batches", [0.0, 1.0, 4, None])

0 commit comments

Comments (0)