@@ -363,6 +363,7 @@ def test_parse_arguments(job_config):
363363 _ ,
364364 _ ,
365365 _ ,
366+ _ ,
366367 ) = sft_trainer .parse_arguments (parser , job_config_copy )
367368 assert str (model_args .torch_dtype ) == "torch.bfloat16"
368369 assert data_args .dataset_text_field == "output"
@@ -390,6 +391,7 @@ def test_parse_arguments_defaults(job_config):
390391 _ ,
391392 _ ,
392393 _ ,
394+ _ ,
393395 ) = sft_trainer .parse_arguments (parser , job_config_defaults )
394396 assert str (model_args .torch_dtype ) == "torch.bfloat16"
395397 assert model_args .use_flash_attn is False
@@ -400,14 +402,14 @@ def test_parse_arguments_peft_method(job_config):
400402 parser = sft_trainer .get_parser ()
401403 job_config_pt = copy .deepcopy (job_config )
402404 job_config_pt ["peft_method" ] = "pt"
403- _ , _ , _ , _ , tune_config , _ , _ , _ , _ , _ , _ , _ , _ = sft_trainer .parse_arguments (
405+ _ , _ , _ , _ , tune_config , _ , _ , _ , _ , _ , _ , _ , _ , _ = sft_trainer .parse_arguments (
404406 parser , job_config_pt
405407 )
406408 assert isinstance (tune_config , peft_config .PromptTuningConfig )
407409
408410 job_config_lora = copy .deepcopy (job_config )
409411 job_config_lora ["peft_method" ] = "lora"
410- _ , _ , _ , _ , tune_config , _ , _ , _ , _ , _ , _ , _ , _ = sft_trainer .parse_arguments (
412+ _ , _ , _ , _ , tune_config , _ , _ , _ , _ , _ , _ , _ , _ , _ = sft_trainer .parse_arguments (
411413 parser , job_config_lora
412414 )
413415 assert isinstance (tune_config , peft_config .LoraConfig )
@@ -1053,12 +1055,18 @@ def _test_run_inference(checkpoint_path):
10531055
10541056
def _validate_training(
    tempdir,
    check_eval=False,
    train_logs_file="training_logs.jsonl",
    check_scanner_file=False,
):
    """Assert that a training run left the expected artifacts in ``tempdir``.

    Args:
        tempdir: Directory the training run wrote its outputs to.
        check_eval: Forwarded unchanged to ``_validate_logfile``.
        train_logs_file: Name of the training log file inside ``tempdir``.
        check_scanner_file: When true, additionally run
            ``_validate_hf_resource_scanner_file`` against ``tempdir``.

    Raises:
        AssertionError: If no entry of ``tempdir`` starts with
            ``checkpoint-``, or if one of the delegated validators fails.
    """
    assert any(
        entry.startswith("checkpoint-") for entry in os.listdir(tempdir)
    ), "no 'checkpoint-*' entry found in {}".format(tempdir)

    # Build the path with os.path.join, like the other path helpers in this
    # file, instead of hand-formatting a "/" separator.
    train_logs_file_path = os.path.join(tempdir, train_logs_file)
    _validate_logfile(train_logs_file_path, check_eval)

    if check_scanner_file:
        _validate_hf_resource_scanner_file(tempdir)
1069+
10621070
10631071def _validate_logfile (log_file_path , check_eval = False ):
10641072 train_log_contents = ""
@@ -1073,6 +1081,18 @@ def _validate_logfile(log_file_path, check_eval=False):
10731081 assert "validation_loss" in train_log_contents
10741082
10751083
def _validate_hf_resource_scanner_file(tempdir):
    """Assert that ``tempdir`` holds a usable ``scanner_output.json``.

    The file must exist, be non-empty, parse as JSON, and carry non-null
    ``time_data`` and ``mem_data`` entries.

    Args:
        tempdir: Directory expected to contain ``scanner_output.json``.

    Raises:
        AssertionError: If the file is missing or empty, or if either
            expected key is absent or null.
    """
    scanner_file_path = os.path.join(tempdir, "scanner_output.json")
    assert os.path.isfile(scanner_file_path), "scanner file not found: {}".format(
        scanner_file_path
    )
    assert os.path.getsize(scanner_file_path) > 0, "scanner file is empty: {}".format(
        scanner_file_path
    )

    with open(scanner_file_path, "r", encoding="utf-8") as f:
        scanner_contents = json.load(f)

    # .get() turns a missing key into a readable assertion failure rather
    # than a bare KeyError.
    for key in ("time_data", "mem_data"):
        assert (
            scanner_contents.get(key) is not None
        ), "scanner output has no '{}' entry".format(key)
1094+
1095+
def _get_checkpoint_path(dir_path, step=5):
    """Return the path of the ``checkpoint-<step>`` entry under ``dir_path``.

    Args:
        dir_path: Directory that contains the checkpoint entries.
        step: Step number used in the checkpoint name. Defaults to 5,
            which was previously hard-coded.

    Returns:
        ``dir_path`` joined with ``checkpoint-<step>``. The path is only
        constructed; its existence is not checked.
    """
    return os.path.join(dir_path, "checkpoint-{}".format(step))
10781098
0 commit comments