Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,11 @@ def main(
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{step}-{consumed_samples}",
# Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
# Save both the weights and the optimizer state.
save_weights_only=False,
save_optim_on_train_end=True,
)

callbacks.append(checkpoint_callback)

auto_resume = resume.AutoResume(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def get_command_line_args(arg_name, arg_value) -> str:
return arg_str

cmd = "train_esm2 " + " ".join(get_command_line_args(arg_name, arg_value) for arg_name, arg_value in args.items())
print("CMD", cmd)
return cmd


Expand Down Expand Up @@ -332,10 +333,6 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu


@pytest.mark.slow
@pytest.mark.xfail(
reason="ESM2 training fails to resume from checkpoints. "
"Issue: https://github.com/NVIDIA/bionemo-framework/issues/757"
)
def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inputs):
max_steps_first_run = 4
max_steps_second_run = 6
Expand All @@ -361,11 +358,9 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
)

# The first training command to finish at max_steps_first_run
stdout_first_run, stderr_first_run, returncode_first_run = run_command_in_subprocess(
command=command_first_run, path=str(tmp_path)
)

assert returncode_first_run == 0, "Command failed."
# stdout_first_run, stderr_first_run, returncode_first_run
#
stdout_first_run = run_command_in_subprocess(command=command_first_run, path=str(tmp_path))

assert f"Training epoch 0, iteration 0/{max_steps_first_run - 1}" in stdout_first_run
# Extract and validate global steps
Expand Down Expand Up @@ -403,53 +398,25 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
wandb_project=None,
experiment_name=experiment_name,
)
stdout_second_run, stderr_second_run, returncode_second_run = run_command_in_subprocess(
command=command_second_run, path=str(tmp_path)
)
stdout_second_run = run_command_in_subprocess(command=command_second_run, path=str(tmp_path))

# Verify that the command failed with a non-zero exit code
# This assertion will fail if the resume functionality gets fixed, prompting us to update the test
assert returncode_second_run != 0, (
"Resuming training passed. "
"The resume functionality works and this test needs to be updated. "
"Update issue https://github.com/NVIDIA/bionemo-framework/issues/757"
)
# Verify the model can continue training from the checkpoint without errors

# Verify the error message contains the expected exception type and error message
# The test fails if we get a different error than expected, which would require investigation
assert (
"megatron.core.dist_checkpointing.core.CheckpointingException" in stderr_second_run
and "Cannot find global shape metadata for N-D flattened tensor ShardedTensor" in stderr_second_run
), f"ESM2 training resuming failed due to an unexpected error.\nActual stderr: {stderr_second_run}..."

# Output the error for logging purposes
pytest.fail(
"Detected expected failure with megatron.core.dist_checkpointing.core.CheckpointingException as anticipated. "
"This is a known issue tracked in: https://github.com/NVIDIA/bionemo-framework/issues/757"
)
global_steps_second_run = extract_global_steps_from_log(stdout_second_run)

### TODO: The following section should be enabled when issue #757 is resolved ###
# Once the issue is fixed, we'll need to:
# 1. Remove the assertion expecting a non-zero return code
# 2. Replace with below assertions that verify successful resumption
# 3. Check for specific markers in the output that indicate successful state restoration
# 4. Verify the model can continue training from the checkpoint without errors
assert global_steps_second_run[0] == max_steps_first_run
assert global_steps_second_run[-1] == max_steps_second_run - 1
assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run

# global_steps_second_run = extract_global_steps_from_log(stdout_second_run)
#
# assert global_steps_second_run[0] == max_steps_first_run
# assert global_steps_second_run[-1] == max_steps_second_run - 1
# assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run
#
# expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
# matching_subfolders = [
# p
# for p in checkpoints_dir.iterdir()
# if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
# ]
# assert matching_subfolders, (
# f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
# )
expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
matching_subfolders = [
p
for p in checkpoints_dir.iterdir()
if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
]
assert matching_subfolders, (
f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
)


@pytest.mark.parametrize("limit_val_batches", [0.0, 1.0, 4, None])
Expand Down