1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "sci
flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
fms-accel = ["fms-acceleration>=0.1"]
scanner-dev = ["HFResourceScanner>=0.1.0"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]


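For context, a minimal sketch (not part of this PR) of how the new optional dependency can be probed at runtime, mirroring the import guard added in tuning/sft_trainer.py further down. The pip-installable project name for the scanner-dev extra is not shown in this diff, so no install command is assumed.

# Sketch only: check whether the optional HFResourceScanner package declared
# in the new "scanner-dev" extra is importable, using the same helper the PR uses.
from transformers.utils.import_utils import _is_package_available

if _is_package_available("HFResourceScanner"):
    from HFResourceScanner import Scanner  # pylint: disable=import-error
    print("Scanner callback support is available:", Scanner.__name__)
else:
    print("HFResourceScanner is not installed; the scanner callback cannot be used.")
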
27 changes: 27 additions & 0 deletions tests/build/test_launch_script.py
@@ -22,6 +22,7 @@

# Third Party
import pytest
from transformers.utils.import_utils import _is_package_available

# First Party
from build.accelerate_launch import main
@@ -246,6 +247,32 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
assert os.path.exists(new_embeddings_file_path)


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_launch_with_add_scanner_callback():
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
TRAIN_KWARGS = {
**BASE_LORA_KWARGS,
**{
"output_dir": tempdir,
"save_model_dir": tempdir,
"lora_post_process_for_vllm": True,
"gradient_accumulation_steps": 1,
"add_scanner_callback": True,
},
}
serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

assert main() == 0

scanner_outfile = os.path.join(tempdir, "scanner_output.json")
assert os.path.exists(scanner_outfile)


def test_bad_script_path():
"""Check for appropriate error for an invalid training script location"""
with tempfile.TemporaryDirectory() as tempdir:
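As a hedged follow-up to the new test: the assertion only checks that scanner_output.json was written. If the report needs inspecting, it can be loaded as ordinary JSON; its schema is defined by HFResourceScanner and is not specified in this diff, so the helper below assumes nothing about its keys.

# Sketch only: load the scanner report written next to the training outputs.
import json
import os

def read_scanner_output(output_dir):
    """Return the parsed contents of scanner_output.json under output_dir."""
    path = os.path.join(output_dir, "scanner_output.json")
    with open(path, encoding="utf-8") as f:
        return json.load(f)
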
6 changes: 4 additions & 2 deletions tests/test_sft_trainer.py
@@ -335,6 +335,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -360,6 +361,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -370,14 +372,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
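The extra "_" placeholders above exist because parse_arguments now returns one more value. A sketch of the full unpacking, with names and ordering taken from the tuning/sft_trainer.py diff below; the job_config value stands in for the dict supplied by the test fixture.

# Sketch only: unpack all twelve return values of the updated parse_arguments.
from tuning import sft_trainer

parser = sft_trainer.get_parser()
(
    model_args,
    data_args,
    training_args,
    trainer_controller_args,
    tune_config,
    file_logger_config,
    aim_config,
    quantized_lora_config,
    fusedops_kernels_config,
    attention_and_distributed_packing_config,
    exp_metadata,
    add_scanner_callback,  # new: last element of the returned tuple
) = sft_trainer.parse_arguments(parser, job_config)
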
38 changes: 36 additions & 2 deletions tuning/sft_trainer.py
@@ -37,6 +37,7 @@
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_accelerate_available
from transformers.utils.import_utils import _is_package_available
from trl import SFTConfig, SFTTrainer
import transformers

@@ -66,6 +67,10 @@
from tuning.utils.logging import set_log_level
from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize

if _is_package_available("HFResourceScanner"):
# Third Party
from HFResourceScanner import Scanner # pylint: disable=import-error


def train(
model_args: configs.ModelArguments,
@@ -446,6 +451,13 @@ def get_parser():
help='Pass a json string representing K:V pairs to be associated\
to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
)
parser.add_argument(
"--add_scanner_callback",
type=bool,
required=False,
default=False,
help="whether to attach the scanner callback to measure memory and time of the training",
)
return parser


@@ -498,6 +510,7 @@ def parse_arguments(parser, json_config=None):
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
exp_metadata = json_config.get("exp_metadata")
add_scanner_callback = json_config.get("add_scanner_callback")
else:
(
model_args,
@@ -517,6 +530,7 @@

peft_method = additional.peft_method
exp_metadata = additional.exp_metadata
add_scanner_callback = additional.add_scanner_callback

if peft_method == "lora":
tune_config = lora_config
@@ -537,6 +551,7 @@
fusedops_kernels_config,
attention_and_distributed_packing_config,
exp_metadata,
add_scanner_callback,
)


@@ -558,6 +573,7 @@ def main():
fusedops_kernels_config,
attention_and_distributed_packing_config,
exp_metadata,
add_scanner_callback,
) = parse_arguments(parser, job_config)

# Function to set log level for python native logger and transformers training logger
@@ -568,7 +584,7 @@
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
tune_config %s, file_logger_config, %s aim_config %s, \
quantized_lora_config %s, fusedops_kernels_config %s, \
attention_and_distributed_packing_config %s exp_metadata %s",
attention_and_distributed_packing_config %s, exp_metadata %s, add_scanner_callback %s",
model_args,
data_args,
training_args,
@@ -580,6 +596,7 @@
fusedops_kernels_config,
attention_and_distributed_packing_config,
exp_metadata,
add_scanner_callback,
)
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc())
@@ -607,10 +624,27 @@ def main():

combined_tracker_configs.file_logger_config = file_logger_config
combined_tracker_configs.aim_config = aim_config
sc_callback = None

if training_args.output_dir:
os.makedirs(training_args.output_dir, exist_ok=True)
logger.info("using the output directory at %s", training_args.output_dir)
if add_scanner_callback:
if _is_package_available("HFResourceScanner"):
output_fmt = os.path.join(
training_args.output_dir, "scanner_output.json"
)
sc_callback = [Scanner(output_fmt=output_fmt)]
logging.info(
[Inline review thread on this line]
Collaborator: nit: can we use logger instead of logging?
Collaborator (author): ahh, my bad, I didn't see this. Fixed.
Collaborator: Thank you

"Attaching HFResourceScanner as a callback with output_fmt: %s",
output_fmt,
)
else:
raise ValueError(
"add_scanner_callback was set to true, but HFResourceScanner is not installed. \
Install the package HFResourceScanner, or set add_scanner_callback to False."
)

try:
trainer, additional_train_info = train(
model_args=model_args,
Expand All @@ -619,7 +653,7 @@ def main():
peft_config=tune_config,
trainer_controller_args=trainer_controller_args,
tracker_configs=combined_tracker_configs,
additional_callbacks=None,
additional_callbacks=sc_callback,
exp_metadata=metadata,
quantized_lora_config=quantized_lora_config,
fusedops_kernels_config=fusedops_kernels_config,
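
To summarize the wiring this file change adds, here is a standalone sketch. The helper name and the commented call at the end are illustrative; Scanner(output_fmt=...), the output filename, and the error message are taken from the diff above.

# Sketch only: reproduce the callback-selection logic from main() as a helper.
import os

from transformers.utils.import_utils import _is_package_available


def build_scanner_callbacks(output_dir, add_scanner_callback):
    """Return [Scanner(...)] when requested and installed, else None."""
    if not add_scanner_callback:
        return None
    if not _is_package_available("HFResourceScanner"):
        raise ValueError(
            "add_scanner_callback was set to true, but HFResourceScanner is not "
            "installed. Install the package HFResourceScanner, or set "
            "add_scanner_callback to False."
        )
    # Third Party
    from HFResourceScanner import Scanner  # pylint: disable=import-error

    output_fmt = os.path.join(output_dir, "scanner_output.json")
    return [Scanner(output_fmt=output_fmt)]


# main() then forwards the result to train(...) as additional_callbacks:
# sc_callback = build_scanner_callbacks(training_args.output_dir, add_scanner_callback)
# train(..., additional_callbacks=sc_callback, ...)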