diff --git a/build/Dockerfile b/build/Dockerfile index 9a6a5583f..66aa40a6e 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -23,6 +23,7 @@ ARG WHEEL_VERSION="" ARG ENABLE_AIM=false ARG ENABLE_MLFLOW=false ARG ENABLE_FMS_ACCELERATION=true +ARG ENABLE_SCANNER=false ## Base Layer ################################################################## FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base @@ -111,6 +112,7 @@ ARG USER ARG USER_UID ARG ENABLE_FMS_ACCELERATION ARG ENABLE_AIM +ARG ENABLE_SCANNER RUN dnf install -y git && \ # perl-Net-SSLeay.x86_64 and server_key.pem are installed with git as dependencies @@ -154,7 +156,11 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \ RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \ python -m pip install --user "$(head bdist_name)[mlflow]"; \ -fi + fi + +RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \ + python -m pip install --user "$(head bdist_name)[scanner-dev]"; \ + fi # Clean up the wheel module. It's only needed by flash-attn install RUN python -m pip uninstall wheel build -y && \ diff --git a/pyproject.toml b/pyproject.toml index b930f7680..0041804bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ aim = ["aim>=3.19.0,<4.0"] mlflow = ["mlflow"] fms-accel = ["fms-acceleration>=0.6"] gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"] +scanner-dev = ["HFResourceScanner>=0.1.0"] [tool.setuptools.packages.find] diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index e331a5e9b..c699e16da 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -16,12 +16,14 @@ """ # Standard +import json import os import tempfile import glob # Third Party import pytest +from transformers.utils.import_utils import _is_package_available # First Party from build.accelerate_launch import main @@ -31,7 +33,10 @@ USER_ERROR_EXIT_CODE, INTERNAL_ERROR_EXIT_CODE, ) -from tuning.config.tracker_configs import FileLoggingTrackerConfig +from tuning.config.tracker_configs import ( + FileLoggingTrackerConfig, + HFResourceScannerConfig, +) SCRIPT = "tuning/sft_trainer.py" MODEL_NAME = "Maykeye/TinyLLama-v0" @@ -246,6 +251,38 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true(): assert os.path.exists(new_embeddings_file_path) +@pytest.mark.skipif( + not _is_package_available("HFResourceScanner"), + reason="Only runs if HFResourceScanner is installed", +) +def test_launch_with_HFResourceScanner_enabled(): + with tempfile.TemporaryDirectory() as tempdir: + setup_env(tempdir) + scanner_outfile = os.path.join( + tempdir, HFResourceScannerConfig.scanner_output_filename + ) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{ + "output_dir": tempdir, + "save_model_dir": tempdir, + "lora_post_process_for_vllm": True, + "gradient_accumulation_steps": 1, + "trackers": ["hf_resource_scanner"], + "scanner_output_filename": scanner_outfile, + }, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + assert os.path.exists(scanner_outfile) is True + with open(scanner_outfile, "r", encoding="utf-8") as f: + scanner_res = json.load(f) + assert scanner_res["time_data"] is not None + assert scanner_res["mem_data"] is not None + + def test_bad_script_path(): """Check for appropriate error for an invalid training script location""" with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index f2d4a1ee1..8faa3746c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -363,6 +363,7 @@ def test_parse_arguments(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -390,6 +391,7 @@ def test_parse_arguments_defaults(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_defaults) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn is False @@ -400,14 +402,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) @@ -1053,12 +1055,18 @@ def _test_run_inference(checkpoint_path): def _validate_training( - tempdir, check_eval=False, train_logs_file="training_logs.jsonl" + tempdir, + check_eval=False, + train_logs_file="training_logs.jsonl", + check_scanner_file=False, ): assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) train_logs_file_path = "{}/{}".format(tempdir, train_logs_file) _validate_logfile(train_logs_file_path, check_eval) + if check_scanner_file: + _validate_hf_resource_scanner_file(tempdir) + def _validate_logfile(log_file_path, check_eval=False): train_log_contents = "" @@ -1073,6 +1081,18 @@ def _validate_logfile(log_file_path, check_eval=False): assert "validation_loss" in train_log_contents +def _validate_hf_resource_scanner_file(tempdir): + scanner_file_path = os.path.join(tempdir, "scanner_output.json") + assert os.path.exists(scanner_file_path) is True + assert os.path.getsize(scanner_file_path) > 0 + + with open(scanner_file_path, "r", encoding="utf-8") as f: + scanner_contents = json.load(f) + + assert scanner_contents["time_data"] is not None + assert scanner_contents["mem_data"] is not None + + def _get_checkpoint_path(dir_path): return os.path.join(dir_path, "checkpoint-5") diff --git a/tests/trackers/test_hf_resource_scanner_tracker.py b/tests/trackers/test_hf_resource_scanner_tracker.py new file mode 100644 index 000000000..04ce0e20c --- /dev/null +++ b/tests/trackers/test_hf_resource_scanner_tracker.py @@ -0,0 +1,88 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + + +# Standard +import copy +import os +import tempfile + +# Third Party +from transformers.utils.import_utils import _is_package_available +import pytest + +# First Party +from tests.test_sft_trainer import ( + DATA_ARGS, + MODEL_ARGS, + TRAIN_ARGS, + _get_checkpoint_path, + _test_run_causallm_ft, + _test_run_inference, + _validate_training, +) + +# Local +from tuning import sft_trainer +from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory + +## HF Resource Scanner Tracker Tests + + +@pytest.mark.skipif( + not _is_package_available("HFResourceScanner"), + reason="Only runs if HFResourceScanner is installed", +) +def test_run_with_hf_resource_scanner_tracker(): + """Ensure that training succeeds with a good tracker name""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.trackers = ["hf_resource_scanner"] + + _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir) + _test_run_inference(_get_checkpoint_path(tempdir)) + + +@pytest.mark.skipif( + not _is_package_available("HFResourceScanner"), + reason="Only runs if HFResourceScanner is installed", +) +def test_sample_run_with_hf_resource_scanner_updated_filename(): + """Ensure that hf_resource_scanner output filename can be updated""" + + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.gradient_accumulation_steps = 1 + train_args.output_dir = tempdir + + # add hf_resource_scanner to the list of requested tracker + train_args.trackers = ["hf_resource_scanner"] + + scanner_output_file = "scanner_output.json" + + tracker_configs = TrackerConfigFactory( + hf_resource_scanner_config=HFResourceScannerConfig( + scanner_output_filename=os.path.join(tempdir, scanner_output_file) + ) + ) + + sft_trainer.train( + MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs + ) + + # validate ft tuning configs + _validate_training(tempdir, check_scanner_file=True) diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py index 51c44aed1..bcadc7776 100644 --- a/tuning/config/tracker_configs.py +++ b/tuning/config/tracker_configs.py @@ -16,6 +16,11 @@ from dataclasses import dataclass +@dataclass +class HFResourceScannerConfig: + scanner_output_filename: str = "scanner_output.json" + + @dataclass class FileLoggingTrackerConfig: training_logs_filename: str = "training_logs.jsonl" @@ -80,3 +85,4 @@ class TrackerConfigFactory: file_logger_config: FileLoggingTrackerConfig = None aim_config: AimConfig = None mlflow_config: MLflowConfig = None + hf_resource_scanner_config: HFResourceScannerConfig = None diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index b3e28f686..0bf3a3b08 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -52,6 +52,7 @@ from tuning.config.tracker_configs import ( AimConfig, FileLoggingTrackerConfig, + HFResourceScannerConfig, MLflowConfig, TrackerConfigFactory, ) @@ -458,6 +459,7 @@ def get_parser(): peft_config.PromptTuningConfig, FileLoggingTrackerConfig, AimConfig, + HFResourceScannerConfig, QuantizedLoraConfig, FusedOpsAndKernelsConfig, AttentionAndDistributedPackingConfig, @@ -506,6 +508,8 @@ def parse_arguments(parser, json_config=None): Configuration for training log file. AimConfig Configuration for AIM stack. + HFResourceScannerConfig + Configuration for HFResourceScanner. QuantizedLoraConfig Configuration for quantized LoRA (a form of PEFT). FusedOpsAndKernelsConfig @@ -529,6 +533,7 @@ def parse_arguments(parser, json_config=None): prompt_tuning_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -547,6 +552,7 @@ def parse_arguments(parser, json_config=None): prompt_tuning_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -574,6 +580,7 @@ def parse_arguments(parser, json_config=None): tune_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -597,6 +604,7 @@ def main(): tune_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -611,7 +619,7 @@ def main(): logger.debug( "Input args parsed: \ model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ - tune_config %s, file_logger_config, %s aim_config %s, \ + tune_config %s, file_logger_config %s, aim_config %s, hf_resource_scanner_config %s, \ quantized_lora_config %s, fusedops_kernels_config %s, \ attention_and_distributed_packing_config, %s,\ mlflow_config %s, fast_moe_config %s, \ @@ -623,6 +631,7 @@ def main(): tune_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -656,6 +665,7 @@ def main(): file_logger_config=file_logger_config, aim_config=aim_config, mlflow_config=mlflow_config, + hf_resource_scanner_config=hf_resource_scanner_config, ) if training_args.output_dir: diff --git a/tuning/trackers/hf_resource_scanner_tracker.py b/tuning/trackers/hf_resource_scanner_tracker.py new file mode 100644 index 000000000..50f503df4 --- /dev/null +++ b/tuning/trackers/hf_resource_scanner_tracker.py @@ -0,0 +1,47 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import logging + +# Third Party +from HFResourceScanner import Scanner # pylint: disable=import-error + +# Local +from .tracker import Tracker +from tuning.config.tracker_configs import HFResourceScannerConfig + + +class HFResourceScannerTracker(Tracker): + def __init__(self, tracker_config: HFResourceScannerConfig): + """Tracker which encodes callback to scan for resources using HFResourceScanner + + Args: + tracker_config (HFResourceScannerConfig): An instance of HFResourceScanner + tracker config which contains the location of output file. + """ + super().__init__(name="hf_resource_scanner", tracker_config=tracker_config) + # Get logger with root log level + self.logger = logging.getLogger() + + def get_hf_callback(self): + """Returns the HFResourceScanner object associated with this tracker. + + Returns: + HFResourceScanner: The file logging callback which inherits + transformers.TrainerCallback and records the metrics to a file. + """ + output_filename = self.config.scanner_output_filename + self.hf_callback = Scanner(output_fmt=output_filename) + return self.hf_callback diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index a550250a8..be1057fac 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -29,8 +29,14 @@ AIMSTACK_TRACKER = "aim" FILE_LOGGING_TRACKER = "file_logger" MLFLOW_TRACKER = "mlflow" +HF_RESOURCE_SCANNER_TRACKER = "hf_resource_scanner" -AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, MLFLOW_TRACKER] +AVAILABLE_TRACKERS = [ + AIMSTACK_TRACKER, + FILE_LOGGING_TRACKER, + HF_RESOURCE_SCANNER_TRACKER, + MLFLOW_TRACKER, +] # Trackers which can be used @@ -39,6 +45,7 @@ # One time package check for list of external trackers. _is_aim_available = _is_package_available("aim") _is_mlflow_available = _is_package_available("mlflow") +_is_hf_resource_scanner_available = _is_package_available("HFResourceScanner") def _get_tracker_class(T, C): @@ -91,6 +98,35 @@ def _register_mlflow_tracker(): ) +def _register_hf_resource_scanner_tracker(): + # pylint: disable=import-outside-toplevel + if _is_hf_resource_scanner_available: + # Local + from .hf_resource_scanner_tracker import HFResourceScannerTracker + from tuning.config.tracker_configs import HFResourceScannerConfig + + HFResourceScannerTracker = _get_tracker_class( + HFResourceScannerTracker, HFResourceScannerConfig + ) + + REGISTERED_TRACKERS[HF_RESOURCE_SCANNER_TRACKER] = HFResourceScannerTracker + logger.info("Registered HFResourceScanner tracker") + else: + logger.info( + "Not registering HFResourceScanner tracker due to unavailablity of package.\n" + "Please install HFResourceScanner if you intend to use it.\n" + "\t pip install HFResourceScanner" + ) + + +def _is_tracker_installed(name): + if name == AIMSTACK_TRACKER: + return _is_aim_available + if name == HF_RESOURCE_SCANNER_TRACKER: + return _is_hf_resource_scanner_available + return False + + def _register_file_logging_tracker(): FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig) REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker @@ -109,6 +145,8 @@ def _register_trackers(): _register_file_logging_tracker() if MLFLOW_TRACKER not in REGISTERED_TRACKERS: _register_mlflow_tracker() + if HF_RESOURCE_SCANNER_TRACKER not in REGISTERED_TRACKERS: + _register_hf_resource_scanner_tracker() def _get_tracker_config_by_name(name: str, tracker_configs: TrackerConfigFactory):