diff --git a/examples/scripts/sft_neuron.py b/examples/scripts/sft_neuron.py new file mode 100755 index 00000000000..28205b0409c --- /dev/null +++ b/examples/scripts/sft_neuron.py @@ -0,0 +1,282 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +# Full training +``` +python examples/scripts/sft_neuron.py \ +    --model_name_or_path Qwen/Qwen2-0.5B \ +    --dataset_name trl-lib/Capybara \ +    --learning_rate 2.0e-5 \ +    --num_train_epochs 1 \ +    --packing \ +    --per_device_train_batch_size 2 \ +    --gradient_accumulation_steps 8 \ +    --eos_token '<|im_end|>' \ +    --eval_strategy steps \ +    --eval_steps 100 \ +    --output_dir Qwen2-0.5B-SFT \ +    --push_to_hub +``` + +# LoRA +``` +python examples/scripts/sft_neuron.py \ +    --model_name_or_path Qwen/Qwen2-0.5B \ +    --dataset_name trl-lib/Capybara \ +    --learning_rate 2.0e-4 \ +    --num_train_epochs 1 \ +    --packing \ +    --per_device_train_batch_size 2 \ +    --gradient_accumulation_steps 8 \ +    --eos_token '<|im_end|>' \ +    --eval_strategy steps \ +    --eval_steps 100 \ +    --use_peft \ +    --lora_r 32 \ +    --lora_alpha 16 \ +    --output_dir Qwen2-0.5B-SFT \ +    --push_to_hub +``` +""" + +import argparse +import os +from functools import cached_property, lru_cache +from typing import Any + +import torch +from accelerate import Accelerator, ParallelismConfig, logging +from accelerate.state import AcceleratorState, PartialState +from accelerate.utils import DistributedType +from datasets import load_dataset +from 
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.utils import requires_backends

from trl import (
    DatasetMixtureConfig,
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_dataset,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


logger = logging.get_logger(__name__)

# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


@lru_cache
def is_torch_neuron_available() -> bool:
    """Return whether the AWS Neuron PyTorch runtime (`torch_neuronx`) is importable.

    The result is cached so the import is attempted at most once per process.
    """
    try:
        import torch_neuronx  # noqa: F401

        return True
    except ImportError:
        return False


class NeuronPartialState(PartialState):
    """`PartialState` that selects the Neuron distributed backend when available."""

    def _prepare_backend(
        self, cpu: bool = False, sagemaker_dp=False, backend: str | None = None
    ) -> tuple[str, DistributedType]:
        if is_torch_neuron_available():
            # TODO: extend accelerate to add the proper DistributedType for Neuron; MULTI_GPU is
            # the closest type accelerate currently understands.
            return "neuron", DistributedType.MULTI_GPU
        return super()._prepare_backend(cpu=cpu, sagemaker_dp=sagemaker_dp, backend=backend)


class NeuronSFTConfig(SFTConfig):
    """`SFTConfig` whose device setup targets AWS Neuron accelerators when present."""

    @cached_property
    def _setup_devices(self) -> "torch.device":
        """Set up the training device, routing through `NeuronPartialState` on Neuron hosts.

        Falls back to the stock `SFTConfig._setup_devices` when `torch_neuronx` is unavailable.
        """
        if not is_torch_neuron_available():
            return super()._setup_devices

        from transformers.utils import logging as transformers_logging

        transformers_logger = transformers_logging.get_logger(__name__)

        requires_backends(self, ["torch"])
        transformers_logger.info("PyTorch: setting up devices")

        # For Neuron, the `ACCELERATE_TORCH_DEVICE` environment variable must be set to "neuron"
        # before the AcceleratorState is initialized.
        os.environ["ACCELERATE_TORCH_DEVICE"] = "neuron"

        # Build kwargs for PartialState; the actual init happens below.
        accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
        if isinstance(self.accelerator_config, AcceleratorConfig):
            accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
                "use_configured_state", False
            )
        if accelerator_state_kwargs["use_configured_state"]:
            if PartialState._shared_state == {}:
                raise ValueError(
                    "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
                    "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
                )
            self.distributed_state = PartialState(cpu=self.use_cpu)
        else:
            # Start from a clean slate so the Neuron state created below is the one that sticks.
            AcceleratorState._reset_state(reset_partial_state=True)
            self.distributed_state = None

        if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
            "use_configured_state", False
        ):
            self.distributed_state = NeuronPartialState(**accelerator_state_kwargs)

        self._n_gpu = 0
        # NOTE(review): assumes torch_neuronx registers a "neuron" device type with torch — confirm.
        return torch.device("neuron")

    def _validate_args(self):
        """Run the parent validation, tolerating the bf16 support check on Neuron hosts.

        The stock check raises when it detects no bf16-capable GPU/CPU, but Neuron devices do
        support bf16, so that specific `ValueError` is swallowed when running on Neuron.
        """
        bf16_error_message = (
            "Your setup doesn't support bf16/gpu. You need to assign use_cpu if you want to train the model on CPU."
        )
        try:
            super()._validate_args()
        except ValueError as e:
            if not (is_torch_neuron_available() and bf16_error_message in str(e)):
                # Bare `raise` re-raises the active exception with its original traceback
                # (the previous `raise e` restarted the traceback from here).
                raise


def main(script_args, training_args, model_args, dataset_args):
    """Run supervised fine-tuning (optionally tensor/data parallel) on Neuron hardware.

    Args:
        script_args (`ScriptArguments`): dataset name/config/split selection.
        training_args (`NeuronSFTConfig`): training hyperparameters.
        model_args (`ModelConfig`): model name, revision, dtype, PEFT and quantization options.
        dataset_args (`DatasetMixtureConfig`): optional dataset mixture definition.
    """
    ################
    # Model init kwargs
    ################
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        dtype=model_args.dtype,
    )

    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    # Create model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()

    # Parallel layout: TP size comes from the environment; the remaining ranks shard data-parallel.
    # Assumes WORLD_SIZE is a multiple of TP_SIZE.
    tp_size = int(os.environ.get("TP_SIZE", "1"))
    dp_size = int(os.environ.get("WORLD_SIZE", "1")) // tp_size

    kwargs = {}
    if tp_size > 1:
        # Column-/row-wise sharding plan for LLaMA-style attention and MLP projections.
        tp_plan = {
            "model.layers.*.self_attn.q_proj": "colwise",
            "model.layers.*.self_attn.k_proj": "colwise",
            "model.layers.*.self_attn.v_proj": "colwise",
            "model.layers.*.self_attn.o_proj": "rowwise",
            "model.layers.*.mlp.gate_proj": "colwise",
            "model.layers.*.mlp.up_proj": "colwise",
            "model.layers.*.mlp.down_proj": "rowwise",
        }
        training_args.parallelism_config = ParallelismConfig(
            dp_replicate_size=1,
            dp_shard_size=dp_size,
            tp_size=tp_size,
        )

        def _prepare_tp(self, *args):
            # Preparing the model for tensor parallelism would normally rewrite its layers
            # according to `tp_plan` and the device mesh; with the current approach in
            # Transformers this is no longer required, so the inputs pass through unchanged.
            return args

        # Assign the plain function so Python binds `self` to the Accelerator *instance* at call
        # time; `_prepare_tp.__get__(Accelerator)` would pre-bind `self` to the class object.
        Accelerator._prepare_tp = _prepare_tp

        # Only request tensor-parallel loading when TP is actually enabled; passing a tp_plan to a
        # single-rank run would needlessly steer `from_pretrained` into the TP loading path.
        kwargs["tp_plan"] = tp_plan
        kwargs["tp_size"] = tp_size

    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **kwargs, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **kwargs, **model_kwargs)

    # Load the dataset: a dataset mixture takes precedence over a single named dataset.
    if dataset_args.datasets:
        if script_args.dataset_name:
            logger.warning(
                "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
                "dataset and `dataset_name` will be ignored."
            )
        dataset = get_dataset(dataset_args)
    elif script_args.dataset_name:
        dataset = load_dataset(
            script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
        )
    else:
        raise ValueError("Either `datasets` or `dataset_name` must be provided.")

    # Initialize the SFT trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
    )

    # Train the model
    trainer.train()

    # Log training complete
    trainer.accelerator.print("✅ Training completed.")

    # Save and push to Hub
    trainer.save_model(training_args.output_dir)
    trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: argparse._SubParsersAction | None = None):
    """Create the argument parser for this script.

    When `subparsers` is given (trl CLI mode), the parser is registered as the `sft` subcommand;
    otherwise a standalone `TrlParser` is returned. Either way it parses the same four dataclasses.
    """
    parser_dataclasses = (ScriptArguments, NeuronSFTConfig, ModelConfig, DatasetMixtureConfig)
    if subparsers is None:
        return TrlParser(parser_dataclasses)
    return subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=parser_dataclasses)


if __name__ == "__main__":
    parser = make_parser()
    # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
    # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
    # `return_remaining_strings=True`, then ignore the remaining strings.
    parsed = parser.parse_args_and_config(return_remaining_strings=True)
    script_args, training_args, model_args, dataset_args = parsed[:4]

    main(script_args, training_args, model_args, dataset_args)