@@ -184,7 +184,7 @@ def get_command_line_args(arg_name, arg_value) -> str:
184184 arg_str = f"--{arg_name.replace('_', '-')}={arg_value}"
185185 return arg_str
186186
187- cmd = "train_esm2 " + " ".join(get_command_line_args(arg_name, arg_value) for arg_name, arg_value in args.items())
187+ cmd = f"train_esm2 {' '.join(get_command_line_args(arg_name, arg_value) for arg_name, arg_value in args.items())}"
188188 return cmd
189189
190190
@@ -332,10 +332,6 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu
332332
333333
334334@pytest.mark.slow
335- @pytest.mark.xfail(
336- reason="ESM2 training fails to resume from checkpoints. "
337- "Issue: https://github.com/NVIDIA/bionemo-framework/issues/757"
338- )
339335def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inputs):
340336 max_steps_first_run = 4
341337 max_steps_second_run = 6
@@ -361,11 +357,9 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
361357 )
362358
363359 # The first training command to finish at max_steps_first_run
364- stdout_first_run, stderr_first_run, returncode_first_run = run_command_in_subprocess(
365- command=command_first_run, path=str(tmp_path)
366- )
367-
368- assert returncode_first_run == 0, "Command failed."
360+ # stdout_first_run, stderr_first_run, returncode_first_run
361+ #
362+ stdout_first_run = run_command_in_subprocess(command=command_first_run, path=str(tmp_path))
369363
370364 assert f"Training epoch 0, iteration 0/{max_steps_first_run - 1}" in stdout_first_run
371365 # Extract and validate global steps
@@ -403,53 +397,25 @@ def test_main_stop_at_num_steps_and_continue(tmp_path, dummy_protein_dataset, du
403397 wandb_project = None ,
404398 experiment_name = experiment_name ,
405399 )
406- stdout_second_run, stderr_second_run, returncode_second_run = run_command_in_subprocess(
407- command=command_second_run, path=str(tmp_path)
408- )
400+ stdout_second_run = run_command_in_subprocess(command=command_second_run, path=str(tmp_path))
409401
410- # Verify that the command failed with a non-zero exit code
411- # This assertion will fail if the resume functionality gets fixed, prompting us to update the test
412- assert returncode_second_run != 0, (
413- "Resuming training passed. "
414- "The resume functionality works and this test needs to be updated. "
415- "Update issue https://github.com/NVIDIA/bionemo-framework/issues/757"
416- )
402+ # Verify the model can continue training from the checkpoint without errors
417403
418- # Verify the error message contains the expected exception type and error message
419- # The test fails if we get a different error than expected, which would require investigation
420- assert (
421- "megatron.core.dist_checkpointing.core.CheckpointingException" in stderr_second_run
422- and "Cannot find global shape metadata for N-D flattened tensor ShardedTensor" in stderr_second_run
423- ), f"ESM2 training resuming failed due to an unexpected error.\n Actual stderr: {stderr_second_run}..."
424-
425- # Output the error for logging purposes
426- pytest.fail(
427- "Detected expected failure with megatron.core.dist_checkpointing.core.CheckpointingException as anticipated. "
428- "This is a known issue tracked in: https://github.com/NVIDIA/bionemo-framework/issues/757"
429- )
404+ global_steps_second_run = extract_global_steps_from_log(stdout_second_run)
430405
431- ### TODO: The following section should be enabled when issue #757 is resolved ###
432- # Once the issue is fixed, we'll need to:
433- # 1. Remove the assertion expecting a non-zero return code
434- # 2. Replace with below assertions that verify successful resumption
435- # 3. Check for specific markers in the output that indicate successful state restoration
436- # 4. Verify the model can continue training from the checkpoint without errors
406+ assert global_steps_second_run[0] == max_steps_first_run
407+ assert global_steps_second_run[-1] == max_steps_second_run - 1
408+ assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run
437409
438- # global_steps_second_run = extract_global_steps_from_log(stdout_second_run)
439- #
440- # assert global_steps_second_run[0] == max_steps_first_run
441- # assert global_steps_second_run[-1] == max_steps_second_run - 1
442- # assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run
443- #
444- # expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
445- # matching_subfolders = [
446- # p
447- # for p in checkpoints_dir.iterdir()
448- # if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
449- # ]
450- # assert matching_subfolders, (
451- # f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
452- # )
410+ expected_checkpoint_second_run_suffix = f"step={max_steps_second_run - 1}"
411+ matching_subfolders = [
412+ p
413+ for p in checkpoints_dir.iterdir()
414+ if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name and "last" in p.name)
415+ ]
416+ assert matching_subfolders, (
417+ f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}."
418+ )
453419
454420
455421@pytest.mark.parametrize("limit_val_batches", [0.0, 1.0, 4, None])
0 commit comments