Commit 39dc6cf

Merge branch 'main' into upgrade-python-version

2 parents: d650ce9 + f22e243

10 files changed, +265 / -7 lines changed

README.md

Lines changed: 5 additions & 0 deletions

@@ -80,6 +80,11 @@ ARROW | ✅
 As iterated above, we also support passing a HF dataset ID directly via `--training_data_path` argument.
 
+**NOTE**: Due to the variety of supported data formats and file types, `--training_data_path` is handled as follows:
+- If `--training_data_path` ends in a valid file extension (e.g., .json, .csv), it is treated as a file.
+- If `--training_data_path` points to a valid folder, it is treated as a folder.
+- If neither of these is true, the data preprocessor tries to load `--training_data_path` as a Hugging Face (HF) dataset ID.
+
 ## Use cases supported with `training_data_path` argument
 
 ### 1. Data formats with a single sequence and a specified response_template to use for masking on completion.
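The NOTE above describes a three-step resolution order. Below is a minimal sketch of that logic; the function name and extension list are illustrative assumptions, not the repo's actual helper:

```python
import os

# Hypothetical extension list; the repo's actual set of supported
# file types may differ.
KNOWN_EXTENSIONS = {".json", ".jsonl", ".csv", ".parquet", ".arrow"}


def resolve_training_data_path(training_data_path: str) -> str:
    """Sketch of the --training_data_path resolution order from the README note."""
    # 1. A recognized file extension means the path is treated as a file.
    _, ext = os.path.splitext(training_data_path)
    if ext.lower() in KNOWN_EXTENSIONS:
        return "file"
    # 2. An existing directory is treated as a folder of data files.
    if os.path.isdir(training_data_path):
        return "folder"
    # 3. Otherwise, fall back to interpreting it as a HF dataset ID.
    return "hf_dataset_id"
```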

build/Dockerfile

Lines changed: 7 additions & 1 deletion

@@ -23,6 +23,7 @@ ARG WHEEL_VERSION=""
 ARG ENABLE_AIM=false
 ARG ENABLE_MLFLOW=false
 ARG ENABLE_FMS_ACCELERATION=true
+ARG ENABLE_SCANNER=false
 
 ## Base Layer ##################################################################
 FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base

@@ -111,6 +112,7 @@ ARG USER
 ARG USER_UID
 ARG ENABLE_FMS_ACCELERATION
 ARG ENABLE_AIM
+ARG ENABLE_SCANNER
 
 RUN dnf install -y git && \
     # perl-Net-SSLeay.x86_64 and server_key.pem are installed with git as dependencies

@@ -154,7 +156,11 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
 
 RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
     python -m pip install --user "$(head bdist_name)[mlflow]"; \
-    fi
+    fi
+
+RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
+    python -m pip install --user "$(head bdist_name)[scanner-dev]"; \
+    fi
 
 # Clean up the wheel module. It's only needed by flash-attn install
 RUN python -m pip uninstall wheel build -y && \
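Like the existing `ENABLE_AIM` and `ENABLE_MLFLOW` switches, `ENABLE_SCANNER` defaults to `false`, so the scanner dependency is only baked into the image when explicitly requested at build time (for example via `docker build --build-arg ENABLE_SCANNER=true ...`). The `ARG` is redeclared in the later build stage because Docker build args do not carry across stage boundaries.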

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ aim = ["aim>=3.19.0,<4.0"]
 mlflow = ["mlflow"]
 fms-accel = ["fms-acceleration>=0.6"]
 gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
+scanner-dev = ["HFResourceScanner>=0.1.0"]
 
 
 [tool.setuptools.packages.find]
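The new `scanner-dev` extra follows the same optional-dependency pattern as `aim` and `mlflow`: `HFResourceScanner` is pulled in only when the extra is requested, e.g. `pip install ".[scanner-dev]"` from a checkout, which is what the Dockerfile's `ENABLE_SCANNER` branch does against the built wheel.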

tests/build/test_launch_script.py

Lines changed: 38 additions & 1 deletion

@@ -16,12 +16,14 @@
 """
 
 # Standard
+import json
 import os
 import tempfile
 import glob
 
 # Third Party
 import pytest
+from transformers.utils.import_utils import _is_package_available
 
 # First Party
 from build.accelerate_launch import main

@@ -31,7 +33,10 @@
     USER_ERROR_EXIT_CODE,
     INTERNAL_ERROR_EXIT_CODE,
 )
-from tuning.config.tracker_configs import FileLoggingTrackerConfig
+from tuning.config.tracker_configs import (
+    FileLoggingTrackerConfig,
+    HFResourceScannerConfig,
+)
 
 SCRIPT = "tuning/sft_trainer.py"
 MODEL_NAME = "Maykeye/TinyLLama-v0"

@@ -246,6 +251,38 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
     assert os.path.exists(new_embeddings_file_path)
 
 
+@pytest.mark.skipif(
+    not _is_package_available("HFResourceScanner"),
+    reason="Only runs if HFResourceScanner is installed",
+)
+def test_launch_with_HFResourceScanner_enabled():
+    with tempfile.TemporaryDirectory() as tempdir:
+        setup_env(tempdir)
+        scanner_outfile = os.path.join(
+            tempdir, HFResourceScannerConfig.scanner_output_filename
+        )
+        TRAIN_KWARGS = {
+            **BASE_LORA_KWARGS,
+            **{
+                "output_dir": tempdir,
+                "save_model_dir": tempdir,
+                "lora_post_process_for_vllm": True,
+                "gradient_accumulation_steps": 1,
+                "trackers": ["hf_resource_scanner"],
+                "scanner_output_filename": scanner_outfile,
+            },
+        }
+        serialized_args = serialize_args(TRAIN_KWARGS)
+        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
+
+        assert main() == 0
+        assert os.path.exists(scanner_outfile) is True
+        with open(scanner_outfile, "r", encoding="utf-8") as f:
+            scanner_res = json.load(f)
+        assert scanner_res["time_data"] is not None
+        assert scanner_res["mem_data"] is not None
+
+
 def test_bad_script_path():
     """Check for appropriate error for an invalid training script location"""
     with tempfile.TemporaryDirectory() as tempdir:

tests/test_sft_trainer.py

Lines changed: 23 additions & 3 deletions

@@ -363,6 +363,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"

@@ -390,6 +391,7 @@ def test_parse_arguments_defaults(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.use_flash_attn is False

@@ -400,14 +402,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
     )
     assert isinstance(tune_config, peft_config.LoraConfig)

@@ -1053,12 +1055,18 @@ def _test_run_inference(checkpoint_path):
 
 
 def _validate_training(
-    tempdir, check_eval=False, train_logs_file="training_logs.jsonl"
+    tempdir,
+    check_eval=False,
+    train_logs_file="training_logs.jsonl",
+    check_scanner_file=False,
 ):
     assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
     train_logs_file_path = "{}/{}".format(tempdir, train_logs_file)
     _validate_logfile(train_logs_file_path, check_eval)
 
+    if check_scanner_file:
+        _validate_hf_resource_scanner_file(tempdir)
+
 
 def _validate_logfile(log_file_path, check_eval=False):
     train_log_contents = ""

@@ -1073,6 +1081,18 @@ def _validate_logfile(log_file_path, check_eval=False):
     assert "validation_loss" in train_log_contents
 
 
+def _validate_hf_resource_scanner_file(tempdir):
+    scanner_file_path = os.path.join(tempdir, "scanner_output.json")
+    assert os.path.exists(scanner_file_path) is True
+    assert os.path.getsize(scanner_file_path) > 0
+
+    with open(scanner_file_path, "r", encoding="utf-8") as f:
+        scanner_contents = json.load(f)
+
+    assert scanner_contents["time_data"] is not None
+    assert scanner_contents["mem_data"] is not None
+
+
 def _get_checkpoint_path(dir_path):
     return os.path.join(dir_path, "checkpoint-5")

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-License-Identifier: Apache-2.0
+# https://spdx.dev/learn/handling-license-info/
+
+
+# Standard
+import copy
+import os
+import tempfile
+
+# Third Party
+from transformers.utils.import_utils import _is_package_available
+import pytest
+
+# First Party
+from tests.test_sft_trainer import (
+    DATA_ARGS,
+    MODEL_ARGS,
+    TRAIN_ARGS,
+    _get_checkpoint_path,
+    _test_run_causallm_ft,
+    _test_run_inference,
+    _validate_training,
+)
+
+# Local
+from tuning import sft_trainer
+from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory
+
+## HF Resource Scanner Tracker Tests
+
+
+@pytest.mark.skipif(
+    not _is_package_available("HFResourceScanner"),
+    reason="Only runs if HFResourceScanner is installed",
+)
+def test_run_with_hf_resource_scanner_tracker():
+    """Ensure that training succeeds with a good tracker name"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.trackers = ["hf_resource_scanner"]
+
+        _test_run_causallm_ft(train_args, MODEL_ARGS, DATA_ARGS, tempdir)
+        _test_run_inference(_get_checkpoint_path(tempdir))
+
+
+@pytest.mark.skipif(
+    not _is_package_available("HFResourceScanner"),
+    reason="Only runs if HFResourceScanner is installed",
+)
+def test_sample_run_with_hf_resource_scanner_updated_filename():
+    """Ensure that hf_resource_scanner output filename can be updated"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.gradient_accumulation_steps = 1
+        train_args.output_dir = tempdir
+
+        # add hf_resource_scanner to the list of requested trackers
+        train_args.trackers = ["hf_resource_scanner"]
+
+        scanner_output_file = "scanner_output.json"
+
+        tracker_configs = TrackerConfigFactory(
+            hf_resource_scanner_config=HFResourceScannerConfig(
+                scanner_output_filename=os.path.join(tempdir, scanner_output_file)
+            )
+        )
+
+        sft_trainer.train(
+            MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+        )
+
+        # validate ft tuning configs
+        _validate_training(tempdir, check_scanner_file=True)

tuning/config/tracker_configs.py

Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,11 @@
 from dataclasses import dataclass
 
 
+@dataclass
+class HFResourceScannerConfig:
+    scanner_output_filename: str = "scanner_output.json"
+
+
 @dataclass
 class FileLoggingTrackerConfig:
     training_logs_filename: str = "training_logs.jsonl"

@@ -80,3 +85,4 @@ class TrackerConfigFactory:
     file_logger_config: FileLoggingTrackerConfig = None
     aim_config: AimConfig = None
     mlflow_config: MLflowConfig = None
+    hf_resource_scanner_config: HFResourceScannerConfig = None
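Downstream, the new dataclass plugs into `TrackerConfigFactory` like the other tracker configs. A minimal sketch of programmatic use, mirroring `test_sample_run_with_hf_resource_scanner_updated_filename` above (the output path and the `model_args`/`data_args`/`train_args` objects are placeholders assumed to be built elsewhere):

```python
from tuning import sft_trainer
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

# Point the scanner at a custom output file instead of the default
# "scanner_output.json"; the path below is a placeholder.
tracker_configs = TrackerConfigFactory(
    hf_resource_scanner_config=HFResourceScannerConfig(
        scanner_output_filename="/tmp/my_run/scanner_output.json"
    )
)

# model_args, data_args, and train_args are assumed to be prepared elsewhere;
# train_args.trackers must include "hf_resource_scanner" for the config to
# take effect.
sft_trainer.train(model_args, data_args, train_args, tracker_configs=tracker_configs)
```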

tuning/sft_trainer.py

Lines changed: 11 additions & 1 deletion

@@ -52,6 +52,7 @@
 from tuning.config.tracker_configs import (
     AimConfig,
     FileLoggingTrackerConfig,
+    HFResourceScannerConfig,
     MLflowConfig,
     TrackerConfigFactory,
 )

@@ -458,6 +459,7 @@ def get_parser():
         peft_config.PromptTuningConfig,
         FileLoggingTrackerConfig,
         AimConfig,
+        HFResourceScannerConfig,
         QuantizedLoraConfig,
         FusedOpsAndKernelsConfig,
         AttentionAndDistributedPackingConfig,

@@ -506,6 +508,8 @@ def parse_arguments(parser, json_config=None):
             Configuration for training log file.
         AimConfig
             Configuration for AIM stack.
+        HFResourceScannerConfig
+            Configuration for HFResourceScanner.
         QuantizedLoraConfig
             Configuration for quantized LoRA (a form of PEFT).
         FusedOpsAndKernelsConfig

@@ -529,6 +533,7 @@ def parse_arguments(parser, json_config=None):
             prompt_tuning_config,
             file_logger_config,
             aim_config,
+            hf_resource_scanner_config,
             quantized_lora_config,
             fusedops_kernels_config,
             attention_and_distributed_packing_config,

@@ -547,6 +552,7 @@ def parse_arguments(parser, json_config=None):
             prompt_tuning_config,
             file_logger_config,
             aim_config,
+            hf_resource_scanner_config,
             quantized_lora_config,
             fusedops_kernels_config,
             attention_and_distributed_packing_config,

@@ -574,6 +580,7 @@ def parse_arguments(parser, json_config=None):
         tune_config,
         file_logger_config,
         aim_config,
+        hf_resource_scanner_config,
         quantized_lora_config,
         fusedops_kernels_config,
         attention_and_distributed_packing_config,

@@ -597,6 +604,7 @@ def main():
         tune_config,
         file_logger_config,
         aim_config,
+        hf_resource_scanner_config,
         quantized_lora_config,
         fusedops_kernels_config,
         attention_and_distributed_packing_config,

@@ -611,7 +619,7 @@ def main():
     logger.debug(
         "Input args parsed: \
         model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
-        tune_config %s, file_logger_config, %s aim_config %s, \
+        tune_config %s, file_logger_config %s, aim_config %s, hf_resource_scanner_config %s, \
         quantized_lora_config %s, fusedops_kernels_config %s, \
         attention_and_distributed_packing_config, %s,\
         mlflow_config %s, fast_moe_config %s, \

@@ -623,6 +631,7 @@ def main():
         tune_config,
         file_logger_config,
         aim_config,
+        hf_resource_scanner_config,
         quantized_lora_config,
         fusedops_kernels_config,
         attention_and_distributed_packing_config,

@@ -656,6 +665,7 @@ def main():
         file_logger_config=file_logger_config,
         aim_config=aim_config,
         mlflow_config=mlflow_config,
+        hf_resource_scanner_config=hf_resource_scanner_config,
     )
 
     if training_args.output_dir:
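For the launcher path exercised in `tests/build/test_launch_script.py`, the same settings travel as serialized JSON in an environment variable rather than as dataclasses. A rough sketch of that flow, modeled on `test_launch_with_HFResourceScanner_enabled` above; the kwargs shown are a minimal illustrative subset, and base64-encoded JSON is an assumption about what the test's `serialize_args` helper produces:

```python
import base64
import json
import os

# Minimal, illustrative subset of the kwargs the launch test serializes;
# a real run needs the full set of model/data/training arguments.
train_kwargs = {
    "output_dir": "/tmp/scanner_demo",
    "trackers": ["hf_resource_scanner"],
    "scanner_output_filename": "/tmp/scanner_demo/scanner_output.json",
}

# The tests call a serialize_args helper at this step; base64-encoded JSON
# is an assumption about its encoding, not a confirmed detail.
encoded = base64.b64encode(json.dumps(train_kwargs).encode("utf-8")).decode("utf-8")
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded
```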
