diff --git a/pyproject.toml b/pyproject.toml
index 9cfeecdc4..ba903ff7f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "sci
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 aim = ["aim>=3.19.0,<4.0"]
 fms-accel = ["fms-acceleration>=0.1"]
+scanner-dev = ["HFResourceScanner>=0.1.0"]
 gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py
index e331a5e9b..ed788d8d7 100644
--- a/tests/build/test_launch_script.py
+++ b/tests/build/test_launch_script.py
@@ -22,6 +22,7 @@
 # Third Party
 import pytest
+from transformers.utils.import_utils import _is_package_available
 
 # First Party
 from build.accelerate_launch import main
@@ -246,6 +247,32 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
     assert os.path.exists(new_embeddings_file_path)
 
 
+@pytest.mark.skipif(
+    not _is_package_available("HFResourceScanner"),
+    reason="Only runs if HFResourceScanner is installed",
+)
+def test_launch_with_add_scanner_callback():
+    with tempfile.TemporaryDirectory() as tempdir:
+        setup_env(tempdir)
+        TRAIN_KWARGS = {
+            **BASE_LORA_KWARGS,
+            **{
+                "output_dir": tempdir,
+                "save_model_dir": tempdir,
+                "lora_post_process_for_vllm": True,
+                "gradient_accumulation_steps": 1,
+                "add_scanner_callback": True,
+            },
+        }
+        serialized_args = serialize_args(TRAIN_KWARGS)
+        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
+
+        assert main() == 0
+
+        scanner_outfile = os.path.join(tempdir, "scanner_output.json")
+        assert os.path.exists(scanner_outfile)
+
+
 def test_bad_script_path():
     """Check for appropriate error for an invalid training script location"""
     with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 69ccbf4fa..21a751327 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -335,6 +335,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -360,6 +361,7 @@ def test_parse_arguments_defaults(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.use_flash_attn is False
@@ -370,14 +372,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
     )
     assert isinstance(tune_config, peft_config.LoraConfig)
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index c02d73781..52b50695a 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -37,6 +37,7 @@
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_accelerate_available
+from transformers.utils.import_utils import _is_package_available
 from trl import SFTConfig, SFTTrainer
 import transformers
@@ -66,6 +67,10 @@
 from tuning.utils.logging import set_log_level
 from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize
 
+if _is_package_available("HFResourceScanner"):
+    # Third Party
+    from HFResourceScanner import Scanner  # pylint: disable=import-error
+
 
 def train(
     model_args: configs.ModelArguments,
@@ -446,6 +451,13 @@ def get_parser():
         help='Pass a json string representing K:V pairs to be associated\
               to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
     )
+    parser.add_argument(
+        "--add_scanner_callback",
+        type=bool,
+        required=False,
+        default=False,
+        help="whether to attach the scanner callback to measure memory and time of the training",
+    )
     return parser
@@ -498,6 +510,7 @@ def parse_arguments(parser, json_config=None):
         ) = parser.parse_dict(json_config, allow_extra_keys=True)
         peft_method = json_config.get("peft_method")
         exp_metadata = json_config.get("exp_metadata")
+        add_scanner_callback = json_config.get("add_scanner_callback")
     else:
         (
             model_args,
@@ -517,6 +530,7 @@
 
         peft_method = additional.peft_method
         exp_metadata = additional.exp_metadata
+        add_scanner_callback = additional.add_scanner_callback
 
     if peft_method == "lora":
         tune_config = lora_config
@@ -537,6 +551,7 @@
         fusedops_kernels_config,
         attention_and_distributed_packing_config,
         exp_metadata,
+        add_scanner_callback,
     )
@@ -558,6 +573,7 @@
             fusedops_kernels_config,
             attention_and_distributed_packing_config,
             exp_metadata,
+            add_scanner_callback,
         ) = parse_arguments(parser, job_config)
 
         # Function to set log level for python native logger and transformers training logger
@@ -568,7 +584,7 @@
             model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
             tune_config %s, file_logger_config, %s aim_config %s, \
             quantized_lora_config %s, fusedops_kernels_config %s, \
-            attention_and_distributed_packing_config %s exp_metadata %s",
+            attention_and_distributed_packing_config %s, exp_metadata %s, add_scanner_callback %s",
             model_args,
             data_args,
             training_args,
@@ -580,6 +596,7 @@
             fusedops_kernels_config,
             attention_and_distributed_packing_config,
             exp_metadata,
+            add_scanner_callback,
         )
     except Exception as e:  # pylint: disable=broad-except
         logger.error(traceback.format_exc())
@@ -607,10 +624,27 @@
     combined_tracker_configs.file_logger_config = file_logger_config
     combined_tracker_configs.aim_config = aim_config
 
+    sc_callback = None
     if training_args.output_dir:
         os.makedirs(training_args.output_dir, exist_ok=True)
         logger.info("using the output directory at %s", training_args.output_dir)
 
+    if add_scanner_callback:
+        if _is_package_available("HFResourceScanner"):
+            output_fmt = os.path.join(
+                training_args.output_dir, "scanner_output.json"
+            )
+            sc_callback = [Scanner(output_fmt=output_fmt)]
+            logger.info(
+                "Attaching HFResourceScanner as a callback with output_fmt: %s",
+                output_fmt,
+            )
+        else:
+            raise ValueError(
+                "add_scanner_callback was set to true, but HFResourceScanner is not installed. \
+                Install the package HFResourceScanner, or set add_scanner_callback to False."
+            )
+
     try:
         trainer, additional_train_info = train(
             model_args=model_args,
@@ -619,7 +653,7 @@
             peft_config=tune_config,
             trainer_controller_args=trainer_controller_args,
             tracker_configs=combined_tracker_configs,
-            additional_callbacks=None,
+            additional_callbacks=sc_callback,
             exp_metadata=metadata,
             quantized_lora_config=quantized_lora_config,
             fusedops_kernels_config=fusedops_kernels_config,
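
For reviewers, a minimal usage sketch of the new flag, mirroring the added launch-script test above. The add_scanner_callback key, the SFT_TRAINER_CONFIG_JSON_ENV_VAR variable, the scanner_output.json filename, and build.accelerate_launch.main come from this patch; the serialize_args import path, the placeholder model and output values, and the omission of the remaining training arguments (covered by BASE_LORA_KWARGS in the test) are assumptions for illustration only.

# Sketch only, not part of the patch: enable the scanner callback through the
# JSON job config, the same path the new test exercises.
import os

from build.accelerate_launch import main
from build.utils import serialize_args  # assumed location of the test helper

job_config = {
    "model_name_or_path": "/path/to/base/model",  # placeholder
    "output_dir": "/tmp/tuning-run",  # scanner_output.json is written here
    "add_scanner_callback": True,  # new flag added by this patch
    # ... remaining training arguments as in the test's BASE_LORA_KWARGS ...
}

os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialize_args(job_config)
main()  # attaches HFResourceScanner (if installed) and writes scanner_output.json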