8 changes: 7 additions & 1 deletion build/Dockerfile
@@ -23,6 +23,7 @@ ARG WHEEL_VERSION=""
ARG ENABLE_AIM=false
ARG ENABLE_MLFLOW=false
ARG ENABLE_FMS_ACCELERATION=true
ARG ENABLE_SCANNER=false

## Base Layer ##################################################################
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
@@ -111,6 +112,7 @@ ARG USER
ARG USER_UID
ARG ENABLE_FMS_ACCELERATION
ARG ENABLE_AIM
ARG ENABLE_SCANNER

RUN dnf install -y git && \
    # perl-Net-SSLeay.x86_64 and server_key.pem are installed with git as dependencies
@@ -154,7 +156,11 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \

RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[mlflow]"; \
fi
fi

RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[scanner-dev]"; \
fi

# Clean up the wheel module. It's only needed by flash-attn install
RUN python -m pip uninstall wheel build -y && \
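Note: with ENABLE_SCANNER defaulting to false, existing image builds are unchanged. Enabling the scanner at build time should just be a matter of passing the new build argument, e.g. docker build --build-arg ENABLE_SCANNER=true -f build/Dockerfile . (the exact build invocation may differ per setup).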
1 change: 1 addition & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ aim = ["aim>=3.19.0,<4.0"]
mlflow = ["mlflow"]
fms-accel = ["fms-acceleration>=0.6"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
scanner-dev = ["HFResourceScanner>=0.1.0"]


[tool.setuptools.packages.find]
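Outside the container build, the same optional dependency group can presumably be installed directly from a checkout with pip install ".[scanner-dev]", which is what the Dockerfile's scanner-dev extra resolves to.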
39 changes: 38 additions & 1 deletion tests/build/test_launch_script.py
@@ -16,12 +16,14 @@
"""

# Standard
import json
import os
import tempfile
import glob

# Third Party
import pytest
from transformers.utils.import_utils import _is_package_available

# First Party
from build.accelerate_launch import main
@@ -31,7 +33,10 @@
    USER_ERROR_EXIT_CODE,
    INTERNAL_ERROR_EXIT_CODE,
)
from tuning.config.tracker_configs import FileLoggingTrackerConfig
from tuning.config.tracker_configs import (
    FileLoggingTrackerConfig,
    HFResourceScannerConfig,
)

SCRIPT = "tuning/sft_trainer.py"
MODEL_NAME = "Maykeye/TinyLLama-v0"
@@ -246,6 +251,38 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
    assert os.path.exists(new_embeddings_file_path)


@pytest.mark.skipif(
    not _is_package_available("HFResourceScanner"),
    reason="Only runs if HFResourceScanner is installed",
)
def test_launch_with_HFResourceScanner_enabled():
    with tempfile.TemporaryDirectory() as tempdir:
        setup_env(tempdir)
        scanner_outfile = os.path.join(
            tempdir, HFResourceScannerConfig.scanner_output_filename
        )
        TRAIN_KWARGS = {
            **BASE_LORA_KWARGS,
            **{
                "output_dir": tempdir,
                "save_model_dir": tempdir,
                "lora_post_process_for_vllm": True,
                "gradient_accumulation_steps": 1,
                "trackers": ["hf_resource_scanner"],
                "scanner_output_filename": scanner_outfile,
            },
        }
        serialized_args = serialize_args(TRAIN_KWARGS)
        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

        assert main() == 0
        assert os.path.exists(scanner_outfile) is True
        with open(scanner_outfile, "r", encoding="utf-8") as f:
            scanner_res = json.load(f)
        assert scanner_res["time_data"] is not None
        assert scanner_res["mem_data"] is not None


def test_bad_script_path():
"""Check for appropriate error for an invalid training script location"""
with tempfile.TemporaryDirectory() as tempdir:
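For orientation, a minimal sketch of the configuration path the new launch test exercises: the launch wrapper reads serialized training arguments from the SFT_TRAINER_CONFIG_JSON_ENV_VAR environment variable. Plain json.dumps is an assumption here; the test itself goes through its serialize_args helper.

# Minimal sketch, assuming plain JSON is accepted in the env var
# (the test serializes via its serialize_args helper instead).
import json
import os

train_kwargs = {
    # ...model/data/training arguments, as in the test's BASE_LORA_KWARGS...
    "trackers": ["hf_resource_scanner"],
    "scanner_output_filename": "/tmp/scanner_output.json",
}
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = json.dumps(train_kwargs)
# build.accelerate_launch.main() then picks these up, runs training, and the
# scanner writes its "time_data" and "mem_data" sections to the configured file.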
26 changes: 23 additions & 3 deletions tests/test_sft_trainer.py
@@ -363,6 +363,7 @@ def test_parse_arguments(job_config):
        _,
        _,
        _,
        _,
    ) = sft_trainer.parse_arguments(parser, job_config_copy)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert data_args.dataset_text_field == "output"
@@ -390,6 +391,7 @@ def test_parse_arguments_defaults(job_config):
        _,
        _,
        _,
        _,
    ) = sft_trainer.parse_arguments(parser, job_config_defaults)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert model_args.use_flash_attn is False
@@ -400,14 +402,14 @@ def test_parse_arguments_peft_method(job_config):
    parser = sft_trainer.get_parser()
    job_config_pt = copy.deepcopy(job_config)
    job_config_pt["peft_method"] = "pt"
    _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
    _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
        parser, job_config_pt
    )
    assert isinstance(tune_config, peft_config.PromptTuningConfig)

    job_config_lora = copy.deepcopy(job_config)
    job_config_lora["peft_method"] = "lora"
    _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
    _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
        parser, job_config_lora
    )
    assert isinstance(tune_config, peft_config.LoraConfig)
@@ -1053,12 +1055,18 @@ def _test_run_inference(checkpoint_path):


def _validate_training(
    tempdir, check_eval=False, train_logs_file="training_logs.jsonl"
    tempdir,
    check_eval=False,
    train_logs_file="training_logs.jsonl",
    check_scanner_file=False,
):
    assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
    train_logs_file_path = "{}/{}".format(tempdir, train_logs_file)
    _validate_logfile(train_logs_file_path, check_eval)

    if check_scanner_file:
        _validate_hf_resource_scanner_file(tempdir)


def _validate_logfile(log_file_path, check_eval=False):
    train_log_contents = ""
@@ -1073,6 +1081,18 @@ def _validate_logfile(log_file_path, check_eval=False):
        assert "validation_loss" in train_log_contents


def _validate_hf_resource_scanner_file(tempdir):
    scanner_file_path = os.path.join(tempdir, "scanner_output.json")
    assert os.path.exists(scanner_file_path) is True
    assert os.path.getsize(scanner_file_path) > 0

    with open(scanner_file_path, "r", encoding="utf-8") as f:
        scanner_contents = json.load(f)

    assert scanner_contents["time_data"] is not None
    assert scanner_contents["mem_data"] is not None


def _get_checkpoint_path(dir_path):
    return os.path.join(dir_path, "checkpoint-5")

88 changes: 88 additions & 0 deletions tests/trackers/test_hf_resource_scanner_tracker.py
@@ -0,0 +1,88 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


# Standard
import copy
import os
import tempfile

# Third Party
from transformers.utils.import_utils import _is_package_available
import pytest

# First Party
from tests.test_sft_trainer import (
    DATA_ARGS,
    MODEL_ARGS,
    TRAIN_ARGS,
    _get_checkpoint_path,
    _test_run_causallm_ft,
    _test_run_inference,
    _validate_training,
)

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

## HF Resource Scanner Tracker Tests


@pytest.mark.skipif(
    not _is_package_available("HFResourceScanner"),
    reason="Only runs if HFResourceScanner is installed",
)
def test_run_with_hf_resource_scanner_tracker():
    """Ensure that training succeeds with a good tracker name"""
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.trackers = ["hf_resource_scanner"]

        _test_run_causallm_ft(train_args, MODEL_ARGS, DATA_ARGS, tempdir)
        _test_run_inference(_get_checkpoint_path(tempdir))


@pytest.mark.skipif(
    not _is_package_available("HFResourceScanner"),
    reason="Only runs if HFResourceScanner is installed",
)
def test_sample_run_with_hf_resource_scanner_updated_filename():
    """Ensure that hf_resource_scanner output filename can be updated"""

    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.gradient_accumulation_steps = 1
        train_args.output_dir = tempdir

        # add hf_resource_scanner to the list of requested trackers
        train_args.trackers = ["hf_resource_scanner"]

        scanner_output_file = "scanner_output.json"

        tracker_configs = TrackerConfigFactory(
            hf_resource_scanner_config=HFResourceScannerConfig(
                scanner_output_filename=os.path.join(tempdir, scanner_output_file)
            )
        )

        sft_trainer.train(
            MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
        )

        # validate ft tuning configs
        _validate_training(tempdir, check_scanner_file=True)
6 changes: 6 additions & 0 deletions tuning/config/tracker_configs.py
@@ -16,6 +16,11 @@
from dataclasses import dataclass


@dataclass
class HFResourceScannerConfig:
    scanner_output_filename: str = "scanner_output.json"


@dataclass
class FileLoggingTrackerConfig:
    training_logs_filename: str = "training_logs.jsonl"
@@ -80,3 +85,4 @@ class TrackerConfigFactory:
    file_logger_config: FileLoggingTrackerConfig = None
    aim_config: AimConfig = None
    mlflow_config: MLflowConfig = None
    hf_resource_scanner_config: HFResourceScannerConfig = None
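As the new tests demonstrate, the scanner is switched on by naming it in trackers and can be pointed at a custom report location through this config. A minimal sketch of the programmatic wiring (the output path is illustrative):

# Minimal sketch: route HFResourceScanner output through TrackerConfigFactory.
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

tracker_configs = TrackerConfigFactory(
    hf_resource_scanner_config=HFResourceScannerConfig(
        # Defaults to "scanner_output.json"; override to steer the report.
        scanner_output_filename="/tmp/scanner_output.json"
    )
)
# Pass tracker_configs to sft_trainer.train(...) together with
# trackers=["hf_resource_scanner"] in the training arguments, as in the tests above.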
12 changes: 11 additions & 1 deletion tuning/sft_trainer.py
@@ -52,6 +52,7 @@
from tuning.config.tracker_configs import (
    AimConfig,
    FileLoggingTrackerConfig,
    HFResourceScannerConfig,
    MLflowConfig,
    TrackerConfigFactory,
)
@@ -458,6 +459,7 @@ def get_parser():
            peft_config.PromptTuningConfig,
            FileLoggingTrackerConfig,
            AimConfig,
            HFResourceScannerConfig,
            QuantizedLoraConfig,
            FusedOpsAndKernelsConfig,
            AttentionAndDistributedPackingConfig,
@@ -506,6 +508,8 @@ def parse_arguments(parser, json_config=None):
            Configuration for training log file.
        AimConfig
            Configuration for AIM stack.
        HFResourceScannerConfig
            Configuration for HFResourceScanner.
        QuantizedLoraConfig
            Configuration for quantized LoRA (a form of PEFT).
        FusedOpsAndKernelsConfig
@@ -529,6 +533,7 @@
            prompt_tuning_config,
            file_logger_config,
            aim_config,
            hf_resource_scanner_config,
            quantized_lora_config,
            fusedops_kernels_config,
            attention_and_distributed_packing_config,
@@ -547,6 +552,7 @@
            prompt_tuning_config,
            file_logger_config,
            aim_config,
            hf_resource_scanner_config,
            quantized_lora_config,
            fusedops_kernels_config,
            attention_and_distributed_packing_config,
@@ -574,6 +580,7 @@
        tune_config,
        file_logger_config,
        aim_config,
        hf_resource_scanner_config,
        quantized_lora_config,
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
@@ -597,6 +604,7 @@ def main():
        tune_config,
        file_logger_config,
        aim_config,
        hf_resource_scanner_config,
        quantized_lora_config,
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
@@ -611,7 +619,7 @@
    logger.debug(
        "Input args parsed: \
        model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
        tune_config %s, file_logger_config, %s aim_config %s, \
        tune_config %s, file_logger_config %s, aim_config %s, hf_resource_scanner_config %s, \
        quantized_lora_config %s, fusedops_kernels_config %s, \
        attention_and_distributed_packing_config, %s,\
        mlflow_config %s, fast_moe_config %s, \
@@ -623,6 +631,7 @@
        tune_config,
        file_logger_config,
        aim_config,
        hf_resource_scanner_config,
        quantized_lora_config,
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
@@ -656,6 +665,7 @@
        file_logger_config=file_logger_config,
        aim_config=aim_config,
        mlflow_config=mlflow_config,
        hf_resource_scanner_config=hf_resource_scanner_config,
    )

    if training_args.output_dir:
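Note that parse_arguments now returns a 14-tuple rather than 13: the new hf_resource_scanner_config is slotted in immediately after aim_config, which is why every unpacking site in the tests above gains one more underscore.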