49 changes: 25 additions & 24 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -304,32 +304,9 @@ def main(
start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
)
)

if create_tflops_callback:
# Add callback that logs the tera-FLOPS per second per GPU during training.
data_module.global_batch_size = (
global_batch_size # TODO(dorotat): remove this change after FLOPsMeasurementCallback is refactored
)
flop_meas_callback = FLOPsMeasurementCallback(
esm2_config,
data_module,
"bert",
)
callbacks.append(flop_meas_callback)

# Setup the logger and train the model
nemo_logger = setup_nemo_lightning_logger(
root_dir=result_dir,
name=experiment_name,
initialize_tensorboard_logger=create_tensorboard_logger,
wandb_config=wandb_config,
)

# Configure our custom ModelCheckpoint callback and AutoResume to save at nemo_logger.save_dir/checkpoints
if create_checkpoint_callback:
checkpoint_path = str(Path(nemo_logger.save_dir) / "checkpoints")
checkpoint_callback = nl_callbacks.ModelCheckpoint(
dirpath=checkpoint_path,
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
save_top_k=save_top_k,
@@ -339,8 +316,20 @@ def main(
filename="{epoch}-{step}-{consumed_samples}",
# Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
)
callbacks.append(checkpoint_callback)

else:
checkpoint_callback = None
# Setup the logger and train the model
nemo_logger = setup_nemo_lightning_logger(
root_dir=result_dir,
name=experiment_name,
initialize_tensorboard_logger=create_tensorboard_logger,
wandb_config=wandb_config,
ckpt_callback=checkpoint_callback,
)

if create_checkpoint_callback:
checkpoint_path = str(Path(nemo_logger.save_dir) / "checkpoints")
auto_resume = resume.AutoResume(
resume_from_directory=checkpoint_path,
resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training.
@@ -350,6 +339,18 @@ def main(
else:
auto_resume = None

if create_tflops_callback:
# Add callback that logs the tera-FLOPS per second per GPU during training.
data_module.global_batch_size = (
global_batch_size # TODO(dorotat): remove this change after FLOPsMeasurementCallback is refactored
)
flop_meas_callback = FLOPsMeasurementCallback(
esm2_config,
data_module,
"bert",
)
callbacks.append(flop_meas_callback)

trainer = nl.Trainer(
devices=devices,
max_steps=num_steps if early_stop_on_step is None else early_stop_on_step,
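Taken together, the train_esm2.py hunks above reorder the setup inside main(): the ModelCheckpoint callback is now built (and appended to callbacks) before the logger, the logger receives it through the ckpt_callback argument of setup_nemo_lightning_logger, AutoResume is configured afterwards from nemo_logger.save_dir, and the FLOPs callback is appended last. The snippet below is only a condensed sketch of that new ordering, not a drop-in excerpt: argument lists are trimmed to what is visible in the hunks, the collapsed context may pass additional arguments, and every name used (nl_callbacks, resume, setup_nemo_lightning_logger, FLOPsMeasurementCallback, the CLI parameters) comes from the surrounding file.

```python
# Condensed sketch of the reordered setup in main() after this change.
# Only arguments visible in the hunks are shown; collapsed context may add more.

# 1. Build the checkpoint callback first (or None when checkpointing is disabled).
if create_checkpoint_callback:
    checkpoint_callback = nl_callbacks.ModelCheckpoint(
        save_last=save_last_checkpoint,
        monitor=metric_to_monitor_for_checkpoints,  # e.g. "val_loss"
        save_top_k=save_top_k,
        filename="{epoch}-{step}-{consumed_samples}",  # step/consumed_samples avoid duplicate names
    )
    callbacks.append(checkpoint_callback)
else:
    checkpoint_callback = None

# 2. The logger now receives the checkpoint callback explicitly.
nemo_logger = setup_nemo_lightning_logger(
    root_dir=result_dir,
    name=experiment_name,
    initialize_tensorboard_logger=create_tensorboard_logger,
    wandb_config=wandb_config,
    ckpt_callback=checkpoint_callback,
)

# 3. AutoResume still points at <save_dir>/checkpoints, but only when checkpointing is on.
if create_checkpoint_callback:
    checkpoint_path = str(Path(nemo_logger.save_dir) / "checkpoints")
    auto_resume = resume.AutoResume(
        resume_from_directory=checkpoint_path,
        resume_if_exists=resume_if_exists,  # looks for the -last checkpoint to continue training
    )
else:
    auto_resume = None

# 4. The tera-FLOPS/sec/GPU measurement callback is appended after the logger is set up.
if create_tflops_callback:
    data_module.global_batch_size = global_batch_size  # see the TODO in the hunk above
    callbacks.append(FLOPsMeasurementCallback(esm2_config, data_module, "bert"))
```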
@@ -185,6 +185,7 @@ def get_command_line_args(arg_name, arg_value) -> str:
return arg_str

cmd = "train_esm2 " + " ".join(get_command_line_args(arg_name, arg_value) for arg_name, arg_value in args.items())
print("CMD", cmd)
return cmd
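The hunk above assembles the train_esm2 CLI string by mapping each test argument through get_command_line_args and, with this change, echoes the assembled command for easier debugging. The helper itself is mostly collapsed here, so the version below is only a guess at its shape; the flag naming, boolean handling, and None skipping are assumptions rather than taken from the source.

```python
def get_command_line_args(arg_name: str, arg_value) -> str:
    """Hypothetical reconstruction: turn one (name, value) pair into a CLI fragment."""
    flag = "--" + arg_name.replace("_", "-")
    if arg_value is None:
        return ""  # omitted arguments fall back to the script's defaults
    if isinstance(arg_value, bool):
        return flag if arg_value else ""  # booleans are assumed to be plain flags
    return f"{flag} {arg_value}"


# Example usage mirroring the pattern in the hunk above (argument names are illustrative).
args = {"num_steps": 4, "precision": "bf16-mixed", "create_tensorboard_logger": True}
cmd = "train_esm2 " + " ".join(get_command_line_args(k, v) for k, v in args.items())
print("CMD", cmd)  # matches the debug echo added in the hunk above
```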


@@ -332,10 +333,6 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu


@pytest.mark.slow
@pytest.mark.xfail(
reason="ESM2 training fails to resume from checkpoints. "
"Issue: https://github.com/NVIDIA/bionemo-framework/issues/757"
)
def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inputs):
max_steps_first_run = 4
max_steps_second_run = 6
@@ -361,11 +358,9 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
)

# The first training command to finish at max_steps_first_run
stdout_first_run, stderr_first_run, returncode_first_run = run_command_in_subprocess(
command=command_first_run, path=str(tmp_path)
)

assert returncode_first_run == 0, "Command failed."
stdout_first_run = run_command_in_subprocess(command=command_first_run, path=str(tmp_path))

assert f"Training epoch 0, iteration 0/{max_steps_first_run - 1}" in stdout_first_run
# Extract and validate global steps
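In the hunk above the first run now receives only stdout from run_command_in_subprocess, and the explicit return-code assertion disappears. That only makes sense if the helper itself fails loudly on a non-zero exit. The real implementation lives in the test utilities and is not shown in this diff, so the version below is merely a plausible sketch of that contract: the signature matches the call sites, but the internals are assumptions.

```python
import subprocess


def run_command_in_subprocess(command: str, path: str) -> str:
    """Plausible sketch: run a shell command in `path`, fail loudly on a non-zero exit,
    and return stdout so callers can grep training log lines."""
    result = subprocess.run(
        command,
        shell=True,
        cwd=path,
        capture_output=True,
        text=True,
    )
    # Raising here is what lets the test drop its own `returncode == 0` assertion.
    if result.returncode != 0:
        raise RuntimeError(f"Command failed ({result.returncode}):\n{result.stderr}")
    return result.stdout
```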
@@ -403,53 +398,25 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
wandb_project=None,
experiment_name=experiment_name,
)
stdout_second_run, stderr_second_run, returncode_second_run = run_command_in_subprocess(
command=command_second_run, path=str(tmp_path)
)
stdout_second_run = run_command_in_subprocess(command=command_second_run, path=str(tmp_path))

# Verify that the command failed with a non-zero exit code
# This assertion will fail if the resume functionality gets fixed, prompting us to update the test
assert returncode_second_run != 0, (
"Resuming training passed. "
"The resume functionality works and this test needs to be updated. "
"Update issue https://github.com/NVIDIA/bionemo-framework/issues/757"
)
# Verify the model can continue training from the checkpoint without errors

# Verify the error message contains the expected exception type and error message
# The test fails if we get a different error than expected, which would require investigation
assert (
"megatron.core.dist_checkpointing.core.CheckpointingException" in stderr_second_run
and "Cannot find global shape metadata for N-D flattened tensor ShardedTensor" in stderr_second_run
), f"ESM2 training resuming failed due to an unexpected error.\nActual stderr: {stderr_second_run}..."

# Output the error for logging purposes
pytest.fail(
"Detected expected failure with megatron.core.dist_checkpointing.core.CheckpointingException as anticipated. "
"This is a known issue tracked in: https://github.com/NVIDIA/bionemo-framework/issues/757"
)
global_steps_second_run = extract_global_steps_from_log(stdout_second_run)

### TODO: The following section should be enabled when issue #757 is resolved ###
# Once the issue is fixed, we'll need to:
# 1. Remove the assertion expecting a non-zero return code
# 2. Replace with below assertions that verify successful resumption
# 3. Check for specific markers in the output that indicate successful state restoration
# 4. Verify the model can continue training from the checkpoint without errors
assert global_steps_second_run[0] == max_steps_first_run
assert global_steps_second_run[-1] == max_steps_second_run - 1
assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run

# global_steps_second_run = extract_global_steps_from_log(stdout_second_run)
#
# assert global_steps_second_run[0] == max_steps_first_run
# assert global_steps_second_run[-1] == max_steps_second_run - 1
# assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run
#
# expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
# matching_subfolders = [
# p
# for p in checkpoints_dir.iterdir()
# if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
# ]
# assert matching_subfolders, (
# f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
# )
expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
matching_subfolders = [
p
for p in checkpoints_dir.iterdir()
if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
]
assert matching_subfolders, (
f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
)
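With the concrete values used earlier in this test (max_steps_first_run = 4, max_steps_second_run = 6), the assertions above expect the resumed run to log exactly the remaining global steps and to leave a final "last" checkpoint named after the last step. A quick check of that arithmetic, using only values visible in the diff:

```python
max_steps_first_run, max_steps_second_run = 4, 6

# The resumed run should pick up at global step 4 and stop after step 5.
expected_global_steps = list(range(max_steps_first_run, max_steps_second_run))
assert expected_global_steps == [4, 5]
assert len(expected_global_steps) == max_steps_second_run - max_steps_first_run

# The final checkpoint directory name should contain both markers checked above.
expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
assert expected_checkpoint_second_run_suffix == "step=5"
```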


@pytest.mark.parametrize("limit_val_batches", [0.0, 1.0, 4, None])