diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml deleted file mode 100644 index d450c3b4..00000000 --- a/.github/workflows/build_documentation.yml +++ /dev/null @@ -1,18 +0,0 @@ -name: Build documentation - -on: - push: - branches: - - main - -jobs: - build: - uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main - with: - commit_sha: ${{ github.sha }} - package: alignment-handbook - path_to_docs: alignment-handbook/chapters/ - additional_args: --not_python_module - languages: en - secrets: - hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml deleted file mode 100644 index 96469836..00000000 --- a/.github/workflows/build_pr_documentation.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Build PR Documentation - -on: - pull_request: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - build: - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main - with: - commit_sha: ${{ github.event.pull_request.head.sha }} - pr_number: ${{ github.event.number }} - package: alignment-handbook - path_to_docs: alignment-handbook/chapters/ - additional_args: --not_python_module - languages: en \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 990795fd..2f0dc37b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,6 +26,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install ".[dev, torch]" + python -m pip install ".[dev]" - name: Run unit tests run: HF_TOKEN=$HF_TOKEN pytest -sv tests/ \ No newline at end of file diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml deleted file mode 100644 index d80d92c3..00000000 --- a/.github/workflows/upload_pr_documentation.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: Upload PR Documentation - -on: - workflow_run: - workflows: ["Build PR Documentation"] - types: - - completed - -jobs: - build: - uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main - with: - package_name: alignment-handbook - secrets: - hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} - comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} \ No newline at end of file diff --git a/CITATION.cff b/CITATION.cff index 1e2b2293..fa56b2b1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -26,4 +26,4 @@ authors: family-names: Wolf repository-code: 'https://github.com/huggingface/alignment-handbook' license: Apache-2.0 -version: 0.3.0.dev0 +version: 0.4.0.dev0 diff --git a/README.md b/README.md index c651eb2b..add54301 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155 The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline. ## News πŸ—žοΈ +* **July 24, 2025**: We release the full [post-training recipe](recipes/smollm2/README.md) behind SmolLM3-3B: a state-of-the-art hybrid reasoning model πŸ’­ * **November 21, 2024**: We release the [recipe](recipes/smollm2/README.md) for fine-tuning SmolLM2-Instruct. * **August 18, 2024**: We release SmolLM-Instruct v0.2, along with the [recipe](recipes/smollm/README.md) to fine-tuning small LLMs πŸ’» * **April 12, 2024**: We release Zephyr 141B (A35B), in collaboration with Argilla and Kaist AI, along with the recipe to fine-tune Mixtral 8x22B with ORPO πŸͺ @@ -60,32 +61,35 @@ The initial release of the handbook will focus on the following techniques: ## Installation instructions -To run the code in this project, first, create a Python virtual environment using e.g. Conda: +To run the code in this project, first, create a Python virtual environment using e.g. `uv`: ```shell -conda create -n handbook python=3.10 && conda activate handbook +uv venv handbook --python 3.11 && source handbook/bin/activate && uv pip install --upgrade pip ``` -Next, install PyTorch `v2.1.2` - the precise version is important for reproducibility! Since this is hardware-dependent, we -direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/). +> [!TIP] +> To install `uv`, follow the [UV Installation Guide](https://docs.astral.sh/uv/getting-started/installation/). + +Next, install PyTorch `v2.6.0` + +```shell +uv pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu126 +``` + +Note that the precise version is important for reproducibility! Since this is hardware-dependent, we also direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/). You can then install the remaining package dependencies as follows: ```shell -git clone https://github.com/huggingface/alignment-handbook.git -cd ./alignment-handbook/ -python -m pip install . +uv pip install . ``` You will also need Flash Attention 2 installed, which can be done by running: ```shell -python -m pip install flash-attn --no-build-isolation +uv pip install "flash-attn==2.7.4.post1" --no-build-isolation ``` -> **Note** -> If your machine has less than 96GB of RAM and many CPU cores, reduce the `MAX_JOBS` arguments, e.g. `MAX_JOBS=4 pip install flash-attn --no-build-isolation` - Next, log into your Hugging Face account as follows: ```shell @@ -106,7 +110,6 @@ You can now check out the `scripts` and `recipes` directories for instructions o β”œβ”€β”€ LICENSE β”œβ”€β”€ Makefile <- Makefile with commands like `make style` β”œβ”€β”€ README.md <- The top-level README for developers using this project -β”œβ”€β”€ chapters <- Educational content to render on hf.co/learn β”œβ”€β”€ recipes <- Recipe configs, accelerate configs, slurm scripts β”œβ”€β”€ scripts <- Scripts to train and evaluate chat models β”œβ”€β”€ setup.cfg <- Installation config (mostly used for configuring code quality & tests) @@ -121,10 +124,10 @@ If you find the content of this repo useful in your work, please cite it as foll ```bibtex @software{Tunstall_The_Alignment_Handbook, - author = {Tunstall, Lewis and Beeching, Edward and Lambert, Nathan and Rajani, Nazneen and Huang, Shengyi and Rasul, Kashif and Bartolome, Alvaro and M. Rush, Alexander and Wolf, Thomas}, + author = {Tunstall, Lewis and Beeching, Edward and Lambert, Nathan and Rajani, Nazneen and Huang, Shengyi and Rasul, Kashif and Bartolome, Alvaro, and PatiΓ±o, M. Carlos and M. Rush, Alexander and Wolf, Thomas}, license = {Apache-2.0}, title = {{The Alignment Handbook}}, url = {https://github.com/huggingface/alignment-handbook}, - version = {0.3.0.dev0} + version = {0.4.0.dev0} } ``` diff --git a/chapters/en/_toctree.yml b/chapters/en/_toctree.yml deleted file mode 100644 index e8fc7c0a..00000000 --- a/chapters/en/_toctree.yml +++ /dev/null @@ -1,4 +0,0 @@ -- title: Unit 0. Welcome to the RLHF Handbook! - sections: - - local: chapter0/introduction - title: What is this about? \ No newline at end of file diff --git a/chapters/en/chapter0/introduction.mdx b/chapters/en/chapter0/introduction.mdx deleted file mode 100644 index 26f500f4..00000000 --- a/chapters/en/chapter0/introduction.mdx +++ /dev/null @@ -1,3 +0,0 @@ -# Welcome to the RLHF Handbook! - -Stay tuned for more details πŸ€— \ No newline at end of file diff --git a/recipes/accelerate_configs/multi_gpu.yaml b/recipes/accelerate_configs/ddp.yaml similarity index 100% rename from recipes/accelerate_configs/multi_gpu.yaml rename to recipes/accelerate_configs/ddp.yaml diff --git a/recipes/accelerate_configs/fsdp_qlora.yaml b/recipes/accelerate_configs/fsdp_qlora.yaml deleted file mode 100644 index f28a0f10..00000000 --- a/recipes/accelerate_configs/fsdp_qlora.yaml +++ /dev/null @@ -1,25 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -fsdp_config: - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_backward_prefetch: BACKWARD_PRE - fsdp_cpu_ram_efficient_loading: true - fsdp_forward_prefetch: false - fsdp_offload_params: true - fsdp_sharding_strategy: FULL_SHARD - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_sync_module_states: true - fsdp_use_orig_params: false -machine_rank: 0 -main_training_function: main -mixed_precision: 'no' -num_machines: 1 -num_processes: 2 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file diff --git a/recipes/accelerate_configs/deepspeed_zero3.yaml b/recipes/accelerate_configs/zero3.yaml similarity index 96% rename from recipes/accelerate_configs/deepspeed_zero3.yaml rename to recipes/accelerate_configs/zero3.yaml index b5a1201f..21ab374f 100644 --- a/recipes/accelerate_configs/deepspeed_zero3.yaml +++ b/recipes/accelerate_configs/zero3.yaml @@ -19,4 +19,4 @@ same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false -use_cpu: false +use_cpu: false \ No newline at end of file diff --git a/recipes/constitutional-ai/README.md b/recipes/constitutional-ai/README.md index 08f4520a..8b3285c1 100644 --- a/recipes/constitutional-ai/README.md +++ b/recipes/constitutional-ai/README.md @@ -11,10 +11,10 @@ This repo includes the recipe for training the following models: You will require 8 GPUs (80GB of VRAM) to train the full model. ```shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/constitutional-ai/sft/config_{grok,anthropic}.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/constitutional-ai/sft/config_{grok,anthropic}.yaml # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/constitutional-ai/dpo/config_anthropic.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/constitutional-ai/dpo/config_anthropic.yaml # Note that we did not include the DPO recipe for grok, as that model's seems overtrained and too snarky. ``` diff --git a/recipes/constitutional-ai/dpo/config_anthropic.yaml b/recipes/constitutional-ai/dpo/config_anthropic.yaml index 48f57676..46189468 100644 --- a/recipes/constitutional-ai/dpo/config_anthropic.yaml +++ b/recipes/constitutional-ai/dpo/config_anthropic.yaml @@ -4,13 +4,39 @@ torch_dtype: null # Data training arguments # For definitions, see: src/h4/training/config.py -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 - HuggingFaceH4/cai-conversation-harmless: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/cai-conversation-harmless + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/cai-conversation-harmless + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 3000 + seed: 0 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/constitutional-ai/sft/config_anthropic.yaml b/recipes/constitutional-ai/sft/config_anthropic.yaml index 6724de0c..cfaba96e 100644 --- a/recipes/constitutional-ai/sft/config_anthropic.yaml +++ b/recipes/constitutional-ai/sft/config_anthropic.yaml @@ -6,13 +6,23 @@ attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" -dataset_mixer: - HuggingFaceH4/cai-conversation-harmless: 1.0 - HuggingFaceH4/ultrachat_200k: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/cai-conversation-harmless + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/ultrachat_200k + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 1000 + seed: 0 +dataset_num_proc: 12 # SFT trainer config bf16: true diff --git a/recipes/constitutional-ai/sft/config_grok.yaml b/recipes/constitutional-ai/sft/config_grok.yaml index c79031dc..681fd36f 100644 --- a/recipes/constitutional-ai/sft/config_grok.yaml +++ b/recipes/constitutional-ai/sft/config_grok.yaml @@ -6,13 +6,23 @@ attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" -dataset_mixer: - HuggingFaceH4/grok-conversation-harmless: 0.15 - HuggingFaceH4/ultrachat_200k: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/grok-conversation-harmless + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/ultrachat_200k + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 1000 + seed: 0 +dataset_num_proc: 12 # SFT trainer config bf16: true diff --git a/recipes/gpt2-nl/README.md b/recipes/gpt2-nl/README.md deleted file mode 100644 index 68eccfc8..00000000 --- a/recipes/gpt2-nl/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Language Adaptation through Continued Pretraining - -This directory shows a base example of how to use continued pretraining and further tuning to adapt a language model to new data (e.g. a new language or domain). - -Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example, we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size. - -## Continued pretraining - -This step will further pretrain the original `gpt2` model on plain Dutch text. Note that the script will by default use the `text` column in the dataset but you can change that by specifying `text_column` in the yaml file or on the command-line. - -```shell -ACCELERATE_LOG_LEVEL=info accelerate launch \ - --config_file recipes/accelerate_configs/multi_gpu.yaml \ - --num_processes 4 \ - scripts/run_cpt.py \ - recipes/gpt2-nl/cpt/config_full.yaml -``` - -## Supervised finetuning - -As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model, we'll make use of the output of the previous step. - -```shell -ACCELERATE_LOG_LEVEL=info accelerate launch \ - --config_file recipes/accelerate_configs/multi_gpu.yaml \ - --num_processes 4 \ - scripts/run_sft.py recipes/gpt2-nl/sft/config_full.yaml -``` - -## Direct preference optimisation - -Finally, to align the model better with feedback, we can finetune the SFT output with the DPO algorithm. This should improve the quality of the chat capabilities of the model. - -```shell -ACCELERATE_LOG_LEVEL=info accelerate launch \ - --config_file recipes/accelerate_configs/multi_gpu.yaml \ - --num_processes 4 \ - scripts/run_dpo.py recipes/gpt2-nl/dpo/config_full.yaml -``` - -## Conclusion - -With the steps above you can adapt an LM to a new domain, more data, or even a different language. Then, with sft and dpo, you can end up building a powerful chatbot, too! All within just three simple commands. It should be obvious that all of these follow a very similar approach, which makes them suitable to apply in parameterized slurm jobs. The neat part is that you can easily overwrite arguments in the yaml files by specifying the overwriting argument as a command-line argument, so the adaptability is also great. diff --git a/recipes/gpt2-nl/cpt/config_full.yaml b/recipes/gpt2-nl/cpt/config_full.yaml deleted file mode 100644 index 9c7056cf..00000000 --- a/recipes/gpt2-nl/cpt/config_full.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# Model arguments -model_name_or_path: gpt2 -model_revision: main -torch_dtype: bfloat16 - -# Data training arguments -dataset_mixer: - yhavinga/mc4_nl_cleaned: 1.0 -dataset_splits: - - train -dataset_configs: - - tiny -preprocessing_num_workers: 12 - -# SFT trainer config -bf16: true -do_eval: False -eval_strategy: "no" -gradient_accumulation_steps: 1 -gradient_checkpointing: true -gradient_checkpointing_kwargs: - use_reentrant: False -hub_model_id: gpt2-cpt-dutch -hub_strategy: every_save -learning_rate: 2.0e-04 -log_level: info -logging_steps: 5 -logging_strategy: steps -lr_scheduler_type: cosine -max_seq_length: 1024 -max_steps: -1 -num_train_epochs: 1 -output_dir: data/gpt2-cpt-dutch -overwrite_output_dir: true -per_device_eval_batch_size: 8 -per_device_train_batch_size: 16 -push_to_hub: true -remove_unused_columns: true -report_to: -- wandb -save_strategy: "steps" -save_steps: 100 -save_total_limit: 1 -seed: 42 -warmup_ratio: 0.1 diff --git a/recipes/gpt2-nl/dpo/config_full.yaml b/recipes/gpt2-nl/dpo/config_full.yaml deleted file mode 100644 index 976c2537..00000000 --- a/recipes/gpt2-nl/dpo/config_full.yaml +++ /dev/null @@ -1,44 +0,0 @@ -# Model arguments -model_name_or_path: BramVanroy/gpt2-sft-dutch -model_revision: main -torch_dtype: bfloat16 - -# Data training arguments -# For definitions, see: src/h4/training/config.py -dataset_mixer: - BramVanroy/ultra_feedback_dutch: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 - -# DPOTrainer arguments -bf16: true -beta: 0.1 -do_eval: true -eval_strategy: steps -eval_steps: 100 -gradient_accumulation_steps: 8 -gradient_checkpointing: true -gradient_checkpointing_kwargs: - use_reentrant: False -hub_model_id: gpt2-dpo-dutch -learning_rate: 5.0e-7 -log_level: info -logging_steps: 10 -lr_scheduler_type: cosine -max_length: 1024 -max_prompt_length: 512 -num_train_epochs: 1 -optim: adamw_torch -output_dir: data/gpt2-dpo-dutch -per_device_train_batch_size: 8 -per_device_eval_batch_size: 8 -push_to_hub: true -save_strategy: "steps" -save_steps: 100 -save_total_limit: 1 -seed: 42 -warmup_ratio: 0.1 -report_to: -- wandb diff --git a/recipes/gpt2-nl/sft/config_full.yaml b/recipes/gpt2-nl/sft/config_full.yaml deleted file mode 100644 index f80d8efc..00000000 --- a/recipes/gpt2-nl/sft/config_full.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# Model arguments -model_name_or_path: BramVanroy/gpt2-cpt-dutch -model_revision: main -torch_dtype: bfloat16 - -# Data training arguments -chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" -dataset_mixer: - BramVanroy/ultrachat_200k_dutch: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 - -# SFT trainer config -bf16: true -do_eval: true -eval_strategy: epoch -gradient_accumulation_steps: 1 -gradient_checkpointing: true -gradient_checkpointing_kwargs: - use_reentrant: False -hub_model_id: gpt2-sft-dutch -hub_strategy: every_save -learning_rate: 2.0e-05 -log_level: info -logging_steps: 5 -logging_strategy: steps -lr_scheduler_type: cosine -max_seq_length: 1024 -max_steps: -1 -num_train_epochs: 1 -output_dir: data/gpt2-sft-dutch -overwrite_output_dir: true -per_device_eval_batch_size: 8 -per_device_train_batch_size: 8 -push_to_hub: true -remove_unused_columns: true -report_to: -- wandb -save_strategy: "steps" -save_steps: 100 -save_total_limit: 1 -seed: 42 -warmup_ratio: 0.1 diff --git a/recipes/launch.slurm b/recipes/launch.slurm index d90fdae9..167f1863 100644 --- a/recipes/launch.slurm +++ b/recipes/launch.slurm @@ -1,34 +1,97 @@ #!/bin/bash +#SBATCH --job-name=handbook #SBATCH --ntasks-per-node=1 #SBATCH --exclusive #SBATCH --gres=gpu:8 #SBATCH --partition=hopper-prod # Adjust this for your cluster -#SBATCH --output=/fsx/h4/logs/%x-%j.out # Adjust this for your cluster -#SBATCH --err=/fsx/h4/logs/%x-%j.err # Adjust this for your cluster +#SBATCH --output=./logs/%x-%j.out +#SBATCH --error=./logs/%x-%j.err +#SBATCH --requeue +#SBATCH --time=3-00:00:00 +if [[ "$*" == *"--help"* ]]; then + echo "Usage: sbatch recipes/launch.slurm [options]" + echo "Options:" + echo " --model MODEL Model name" + echo " --task TASK Task name (e.g. sft, grpo)" + echo " --config SUFFIX Configuration suffix (e.g. demo, v00.00)" + echo " --accelerator CONFIG Accelerator configuration name (e.g. zero3)" + echo " --dp N Data parallelism for vLLM server (default: 1)" + echo " --tp N Tensor parallelism for vLLM server (default: 1)" + echo " --args \"ARGS\" Optional arguments to pass to the training script" + exit 0 +fi + +# Specific configuration optimized for the Hugging Face Compute Cluster +module load cuda/12.9 set -x -e source ~/.bashrc -conda activate handbook +source handbook/bin/activate +START_TIME=$(date +%s) echo "START TIME: $(date)" -MODEL=$1 -TASK=$2 -PRECISION=$3 -ACCELERATOR=$4 -OPTIONAL_ARGS=$5 +# Default values +MODEL="" +TASK="" +CONFIG_SUFFIX="" +ACCELERATOR="" +DP=1 +TP=1 +OPTIONAL_ARGS="" -# Training setup -NUM_NODES=$SLURM_NNODES -GPUS_PER_NODE=8 -WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE)) -# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match -CONFIG_FILE=recipes/$MODEL/$TASK/config_$PRECISION.yaml +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --task) + TASK="$2" + shift 2 + ;; + --config) + CONFIG_SUFFIX="$2" + shift 2 + ;; + --accelerator) + ACCELERATOR="$2" + shift 2 + ;; + --dp) + DP="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --args) + OPTIONAL_ARGS="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Validate required arguments +if [[ -z "$MODEL" || -z "$TASK" || -z "$CONFIG_SUFFIX" || -z "$ACCELERATOR" ]]; then + echo "Error: Missing required arguments" + echo "Run with --help for usage information" + exit 1 +fi + + +CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}') # Split the string into individual arguments IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS" - # Loop through the arguments and find the one with "--gradient_accumulation_steps" for arg in "${ARGS[@]}"; do if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then @@ -39,48 +102,75 @@ for arg in "${ARGS[@]}"; do done echo "Gradient accumulation steps: $GRAD_ACC_STEPS" -# so processes know who to talk to -MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) + +MODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}') +REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}') + +# Distributed configuration +NUM_NODES=$SLURM_NNODES +GPUS_PER_NODE=8 +WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE)) +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) +MASTER_ADDR=${NODELIST[0]} # First node for main process MASTER_PORT=6000 +TRAIN_NODES=("${NODELIST[@]}") + +USE_VLLM="false" +if [[ -f "$CONFIG_FILE" ]] && grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE"; then + USE_VLLM="true" +fi +# if using vllm +if [[ "$USE_VLLM" == "true" ]]; then + TRAIN_NODES=("${NODELIST[@]:0:$((NUM_NODES - 1))}") + VLLM_NODE=${NODELIST[-1]} # Last node + WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE)) + NUM_NODES=$((NUM_NODES - 1)) + srun --nodes=1 --ntasks=1 --nodelist=$VLLM_NODE trl vllm-serve --model $MODEL --revision $REVISION --tensor_parallel_size $TP --data_parallel_size $DP & + + OPTIONAL_ARGS="$OPTIONAL_ARGS --vllm_server_host=$VLLM_NODE" +fi + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 export CMD=" \ - scripts/run_$TASK.py $CONFIG_FILE $OPTIONAL_ARGS + scripts/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS " - -export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ +export LAUNCHER="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ --config_file recipes/accelerate_configs/$ACCELERATOR.yaml \ --gradient_accumulation_steps $GRAD_ACC_STEPS \ --num_machines $NUM_NODES \ --num_processes $WORLD_SIZE \ --main_process_ip $MASTER_ADDR \ --main_process_port $MASTER_PORT \ - --machine_rank \$SLURM_PROCID \ - --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend=c10d \ --max_restarts 1 \ - --role \$(hostname -s): \ --tee 3 \ " - -# force crashing on nccl issues like hanging broadcast -export NCCL_ASYNC_ERROR_HANDLING=1 -# export NCCL_DEBUG=INFO -# export NCCL_DEBUG_SUBSYS=COLL -# export NCCL_SOCKET_NTHREADS=1 -# export NCCL_NSOCKS_PERTHREAD=1 -# export CUDA_LAUNCH_BLOCKING=1 - -# Specific configuration optimized for the Hugging Face Compute Cluster -# Be ye warned this may not work on other clusters! -module load cuda/12.1 - # srun error handling: # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +NODELIST=$(IFS=,; echo "${TRAIN_NODES[*]}") + SRUN_ARGS=" \ --wait=60 \ --kill-on-bad-exit=1 \ + --nodes=$NUM_NODES \ + --ntasks=$NUM_NODES \ + --nodelist=$NODELIST " +clear; srun $SRUN_ARGS bash -c "$LAUNCHER $CMD" 2>&1 -clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1 - -echo "END TIME: $(date)" \ No newline at end of file +END_TIME=$(date +%s) +echo "END TIME: $(date)" +ELAPSED_SECONDS=$((END_TIME - START_TIME)) +HOURS=$((ELAPSED_SECONDS / 3600)) +MINUTES=$(( (ELAPSED_SECONDS % 3600) / 60 )) +SECONDS=$((ELAPSED_SECONDS % 60)) +echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)" diff --git a/recipes/pref_align_scan/README.md b/recipes/pref_align_scan/README.md index f9c81a51..bebae849 100644 --- a/recipes/pref_align_scan/README.md +++ b/recipes/pref_align_scan/README.md @@ -1,4 +1,5 @@ # Comparing Preference Alignment Algorithms + This directory contains various comparisons for three algorithms: DPO, IPO, and KTO. Each algorithm has been run in different hyperparameter configurations to study their performance. Two different models and datasets have been used to compare the performance of each algorithm: - zephyr-beta-sft and Ultrafeedback @@ -35,7 +36,7 @@ for config in "${configs[@]}"; do model_revision="${loss_type}-${beta}" # Submit the job - sbatch --job-name=${job_name} recipes/launch.slurm pref_align_scan dpo $config deepspeed_zero3 \ + sbatch --job-name=${job_name} recipes/launch.slurm pref_align_scan dpo $config zero3 \ "--beta=${beta} --loss_type=${loss_type} --output_dir=data/$config-7b-align-scan-${loss_type}-beta-${beta} --hub_model_revision=${model_revision}" done done diff --git a/recipes/pref_align_scan/dpo/config_openhermes.yaml b/recipes/pref_align_scan/dpo/config_openhermes.yaml index 43e8a230..18cd6942 100644 --- a/recipes/pref_align_scan/dpo/config_openhermes.yaml +++ b/recipes/pref_align_scan/dpo/config_openhermes.yaml @@ -3,12 +3,25 @@ model_name_or_path: teknium/OpenHermes-2.5-Mistral-7B torch_dtype: null # Data training arguments -dataset_mixer: - HuggingFaceH4/orca_dpo_pairs: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/orca_dpo_pairs + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/orca_dpo_pairs + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 500 + seed: 0 +dataset_num_proc: 12 # Training arguments with sensible defaults bf16: true diff --git a/recipes/pref_align_scan/dpo/config_zephyr.yaml b/recipes/pref_align_scan/dpo/config_zephyr.yaml index 0dd6d379..6384fca8 100644 --- a/recipes/pref_align_scan/dpo/config_zephyr.yaml +++ b/recipes/pref_align_scan/dpo/config_zephyr.yaml @@ -3,12 +3,25 @@ model_name_or_path: alignment-handbook/zephyr-7b-sft-full torch_dtype: null # Data training arguments -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 2000 + seed: 0 +dataset_num_proc: 12 # Training arguments with sensible defaults bf16: true diff --git a/recipes/pref_align_scan/launch_scan.sh b/recipes/pref_align_scan/launch_scan.sh index 334b9472..ee050527 100644 --- a/recipes/pref_align_scan/launch_scan.sh +++ b/recipes/pref_align_scan/launch_scan.sh @@ -17,7 +17,7 @@ for config in "${configs[@]}"; do model_revision="${loss_type}-${beta}" # Submit the job - sbatch --job-name=${job_name} recipes/launch.slurm pref_align_scan dpo $config deepspeed_zero3 \ + sbatch --job-name=${job_name} recipes/launch.slurm pref_align_scan dpo $config zero3 \ "--beta=${beta} --loss_type=${loss_type} --output_dir=data/$config-7b-align-scan-${loss_type}-beta-${beta} --hub_model_revision=${model_revision}" done done diff --git a/recipes/smollm/README.md b/recipes/smollm/README.md index d636ed3f..467fe236 100644 --- a/recipes/smollm/README.md +++ b/recipes/smollm/README.md @@ -15,5 +15,5 @@ Follow the installation instructions in https://github.com/huggingface/alignment We train the models on 8 GPUs using the following command: ```shell -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/smollm/sft/config.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/smollm/sft/config.yaml ``` diff --git a/recipes/smollm/sft/config.yaml b/recipes/smollm/sft/config.yaml index 2462191c..b816e688 100644 --- a/recipes/smollm/sft/config.yaml +++ b/recipes/smollm/sft/config.yaml @@ -1,22 +1,76 @@ # Model arguments model_name_or_path: HuggingFaceTB/SmolLM-360M model_revision: main -tokenizer_name_or_path: HuggingFaceTB/SmolLM-360M-Instruct # Custom tokenizer with <|im_start|> and <|im_end|> tokens torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments -dataset_mixer: - HuggingFaceTB/Magpie-Pro-300K-Filtered-H4: 1.0 - HuggingFaceTB/self-oss-instruct-sc2-H4: 1.0 - HuggingFaceTB/OpenHermes-2.5-H4: 0.001 - HuggingFaceTB/everyday-conversations-llama3.1-2k: 1.0 - HuggingFaceTB/instruct-data-basics-smollm-H4: 1.0 - -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 36 +chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +dataset_mixture: + datasets: + - id: HuggingFaceTB/Magpie-Pro-300K-Filtered-H4 + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/Magpie-Pro-300K-Filtered-H4 + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/self-oss-instruct-sc2-H4 + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/self-oss-instruct-sc2-H4 + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/OpenHermes-2.5-H4 + config: default + split: train_sft + columns: + - messages + weight: 0.001 + - id: HuggingFaceTB/OpenHermes-2.5-H4 + config: default + split: test_sft + columns: + - messages + weight: 0.001 + - id: HuggingFaceTB/everyday-conversations-llama3.1-2k + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/everyday-conversations-llama3.1-2k + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/instruct-data-basics-smollm-H4 + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceTB/instruct-data-basics-smollm-H4 + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 0.05 + seed: 0 +dataset_num_proc: 24 # SFT trainer config bf16: true @@ -24,7 +78,7 @@ dataset_kwargs: add_special_tokens: false # We already wrap and in the chat template append_concat_token: false # No need to add across samples do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/smollm2/README.md b/recipes/smollm2/README.md index 2afc8844..ee655433 100644 --- a/recipes/smollm2/README.md +++ b/recipes/smollm2/README.md @@ -12,17 +12,17 @@ We train the 1.7B on 8 GPUs using the following command: ```shell # SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/smollm2/sft/config.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/smollm2/sft/config.yaml # DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/smollm2/dpo/config.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/smollm2/dpo/config.yaml ``` For the 135M and 360M we use [smol-smoltalk](https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk) dataset for SFT and UltraFeedback for DPO: ```shell # SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/smollm2/sft/config_smol.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/smollm2/sft/config_smol.yaml # DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/smollm2/dpo/config_smol.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/smollm2/dpo/config_smol.yaml ``` \ No newline at end of file diff --git a/recipes/smollm2/dpo/config.yaml b/recipes/smollm2/dpo/config.yaml index 1f35f8dc..6ad6e47b 100644 --- a/recipes/smollm2/dpo/config.yaml +++ b/recipes/smollm2/dpo/config.yaml @@ -3,13 +3,25 @@ model_name_or_path: loubnabnl/smollm2-1.7B-sft torch_dtype: bfloat16 # Data training arguments -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 - -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 2000 + seed: 0 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/smollm2/dpo/config_smol.yaml b/recipes/smollm2/dpo/config_smol.yaml index b629bc3a..7679a7d1 100644 --- a/recipes/smollm2/dpo/config_smol.yaml +++ b/recipes/smollm2/dpo/config_smol.yaml @@ -3,13 +3,25 @@ model_name_or_path: loubnabnl/smollm2-360M-sft # we use this script for the 135M torch_dtype: bfloat16 # Data training arguments -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 - -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 2000 + seed: 0 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/smollm2/sft/config.yaml b/recipes/smollm2/sft/config.yaml index 6f6cd516..78c9b77f 100644 --- a/recipes/smollm2/sft/config.yaml +++ b/recipes/smollm2/sft/config.yaml @@ -1,26 +1,21 @@ # Model arguments model_name_or_path: HuggingFaceTB/SmolLM2-1.7B model_revision: main -tokenizer_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct # Custom tokenizer with <|im_start|> and <|im_end|> tokens torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments -dataset_mixer: - HuggingFaceTB/smoltalk: 1.0 - -dataset_configs: -- all - -dataset_splits: -- train -- test -preprocessing_num_workers: 36 +chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +dataset_name: HuggingFaceTB/smoltalk +dataset_config: all +dataset_train_split: train +dataset_test_split: test +dataset_num_proc: 24 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/smollm2/sft/config_smol.yaml b/recipes/smollm2/sft/config_smol.yaml index 70be48cc..b8285f99 100644 --- a/recipes/smollm2/sft/config_smol.yaml +++ b/recipes/smollm2/sft/config_smol.yaml @@ -1,23 +1,20 @@ # Model arguments model_name_or_path: HuggingFaceTB/SmolLM2-360M # we use this script for the 135M model too model_revision: main -tokenizer_name_or_path: HuggingFaceTB/SmolLM2-360M-Instruct # Custom tokenizer with <|im_start|> and <|im_end|> tokens torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments -dataset_mixer: - HuggingFaceTB/smol-smoltalk: 1.0 - -dataset_splits: -- train -- test -preprocessing_num_workers: 36 +chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +dataset_name: HuggingFaceTB/smol-smoltalk +dataset_train_split: train +dataset_test_split: test +dataset_num_proc: 24 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/smollm3/dpo/apo.yaml b/recipes/smollm3/dpo/apo.yaml index 33925468..bea42f30 100644 --- a/recipes/smollm3/dpo/apo.yaml +++ b/recipes/smollm3/dpo/apo.yaml @@ -1,7 +1,7 @@ # Config for 2 node, with GBS 32 # Model arguments -model_name_or_path: HuggingFaceTB/SmolLM3-SFT -model_revision: v70.00-step-000000928 +model_name_or_path: HuggingFaceTB/SmolLM3-3B-checkpoints +model_revision: it-SFT torch_dtype: bfloat16 attn_implementation: flash_attention_2 trust_remote_code: true @@ -36,10 +36,8 @@ gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true -hub_model_id: HuggingFaceTB/SmolLM3-DPO -hub_model_revision: v20.00 -output_dir: data/SmolLM3-DPO-v20.00 -run_name: SmolLM3-DPO-v20.00 +output_dir: data/SmolLM3-DPO +run_name: SmolLM3-DPO hub_strategy: every_save learning_rate: 1.0e-06 @@ -65,7 +63,4 @@ save_steps: 0.5 save_total_limit: 1 seed: 42 use_liger_kernel: true -warmup_ratio: 0.1 -wandb_entity: huggingface -wandb_project: SmolLM3 -wandb_run_group: SmolLM3-DPO \ No newline at end of file +warmup_ratio: 0.1 \ No newline at end of file diff --git a/recipes/smollm3/sft/mid.yaml b/recipes/smollm3/sft/mid.yaml index 24638f7f..84619ad5 100644 --- a/recipes/smollm3/sft/mid.yaml +++ b/recipes/smollm3/sft/mid.yaml @@ -1,6 +1,6 @@ # Config for 8 nodes, with GBS 128 # Model arguments -model_name_or_path: HuggingFaceTB/smollm3-3B-base-final-remote-code +model_name_or_path: HuggingFaceTB/SmolLM3-3B-Base model_revision: main torch_dtype: bfloat16 attn_implementation: flash_attention_2 @@ -23,7 +23,7 @@ dataset_mixture: - messages weight: 1.0 seed: 0 -dataset_num_proc: 12 +dataset_num_proc: 48 eos_token: <|im_end|> # SFT trainer config @@ -35,9 +35,7 @@ gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true -hub_model_id: HuggingFaceTB/SmolLM3-SFT -hub_model_revision: v14.00 -output_dir: data/SmolLM3-SFT-v14.00 +output_dir: data/SmolLM3-Mid hub_strategy: every_save learning_rate: 2.0e-05 log_level: info @@ -62,8 +60,5 @@ save_total_limit: 1 seed: 42 use_liger_kernel: true warmup_ratio: 0.03 -wandb_entity: huggingface -wandb_project: SmolLM3 -wandb_run_group: SmolLM3-SFT -run_name: v14.00 +run_name: smollm3-midtraining average_tokens_across_devices: true \ No newline at end of file diff --git a/recipes/smollm3/sft/sft.yaml b/recipes/smollm3/sft/sft.yaml index 30666b79..d0be5024 100644 --- a/recipes/smollm3/sft/sft.yaml +++ b/recipes/smollm3/sft/sft.yaml @@ -1,7 +1,7 @@ # Config for 8 nodes # Model arguments -model_name_or_path: HuggingFaceTB/SmolLM3-SFT -model_revision: v14.01-step-000033664 +model_name_or_path: HuggingFaceTB/SmolLM3-3B-checkpoints +model_revision: it-mid-training torch_dtype: bfloat16 attn_implementation: flash_attention_2 trust_remote_code: true @@ -199,9 +199,7 @@ gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true -hub_model_id: HuggingFaceTB/SmolLM3-SFT -hub_model_revision: v70.00 -output_dir: data/SmolLM3-SFT-v70.00 +output_dir: data/SmolLM3-SFT hub_strategy: every_save learning_rate: 2.0e-05 log_level: info @@ -226,8 +224,5 @@ save_total_limit: 1 seed: 42 use_liger_kernel: true warmup_ratio: 0.03 -wandb_entity: huggingface -wandb_project: SmolLM3 -wandb_run_group: SmolLM3-SFT -run_name: v70.00 +run_name: smollm3-sft-training average_tokens_across_devices: true \ No newline at end of file diff --git a/recipes/starchat2-15b/README.md b/recipes/starchat2-15b/README.md index 06e807f1..5eb55b21 100644 --- a/recipes/starchat2-15b/README.md +++ b/recipes/starchat2-15b/README.md @@ -14,8 +14,8 @@ You will require 8 GPUs (80GB of VRAM) to train the full model - alternatively, ```shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/starchat2-15b/sft/config_v0.1.yaml # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/starchat2-15b/dpo/config_v0.1.yaml ``` diff --git a/recipes/starchat2-15b/dpo/config_v0.1.yaml b/recipes/starchat2-15b/dpo/config_v0.1.yaml index cf0ddb3f..a33a349b 100644 --- a/recipes/starchat2-15b/dpo/config_v0.1.yaml +++ b/recipes/starchat2-15b/dpo/config_v0.1.yaml @@ -4,13 +4,39 @@ torch_dtype: bfloat16 # Data training arguments # For definitions, see: src/h4/training/config.py -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 - HuggingFaceH4/orca_dpo_pairs: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/orca_dpo_pairs + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/orca_dpo_pairs + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 2000 + seed: 0 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/starchat2-15b/sft/config_v0.1.yaml b/recipes/starchat2-15b/sft/config_v0.1.yaml index f5892de5..9faecabf 100644 --- a/recipes/starchat2-15b/sft/config_v0.1.yaml +++ b/recipes/starchat2-15b/sft/config_v0.1.yaml @@ -6,16 +6,71 @@ attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" -dataset_mixer: - HuggingFaceH4/airoboros-3.2: 1.0 - HuggingFaceH4/Code-Feedback: 1.0 - HuggingFaceH4/orca-math-word-problems-200k: 1.0 - HuggingFaceH4/SystemChat: 1.0 - HuggingFaceH4/capybara: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 24 +dataset_mixture: + datasets: + - id: HuggingFaceH4/airoboros-3.2 + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/Code-Feedback + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/orca-math-word-problems-200k + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/SystemChat + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/capybara + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/airoboros-3.2 + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/Code-Feedback + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/orca-math-word-problems-200k + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/SystemChat + config: default + split: test_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/capybara + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 1000 + seed: 0 +dataset_num_proc: 24 # SFT trainer config bf16: true diff --git a/recipes/zephyr-141b-A35b/README.md b/recipes/zephyr-141b-A35b/README.md index 203cd14c..befc757b 100644 --- a/recipes/zephyr-141b-A35b/README.md +++ b/recipes/zephyr-141b-A35b/README.md @@ -19,5 +19,5 @@ Under the hood, this calls the following script which can be adapted to other mo ```shell -ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file recipes/accelerate_configs/fsdp.yaml scripts/run_orpo.py recipes/zephyr-141b-A35b/orpo/config_full.yaml +ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file recipes/accelerate_configs/fsdp.yaml scripts/run_orpo.py --config recipes/zephyr-141b-A35b/orpo/config_full.yaml ``` \ No newline at end of file diff --git a/recipes/zephyr-141b-A35b/orpo/config_full.yaml b/recipes/zephyr-141b-A35b/orpo/config_full.yaml index b5210132..d02511bb 100644 --- a/recipes/zephyr-141b-A35b/orpo/config_full.yaml +++ b/recipes/zephyr-141b-A35b/orpo/config_full.yaml @@ -6,11 +6,9 @@ attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" -dataset_mixer: - argilla/distilabel-capybara-dpo-7k-binarized: 1.0 -dataset_splits: -- train -preprocessing_num_workers: 8 +dataset_name: argilla/distilabel-capybara-dpo-7k-binarized +dataset_train_split: train +dataset_num_proc: 8 # ORPOTrainer arguments bf16: true diff --git a/recipes/zephyr-7b-beta/README.md b/recipes/zephyr-7b-beta/README.md index 8c082f17..585372c2 100644 --- a/recipes/zephyr-7b-beta/README.md +++ b/recipes/zephyr-7b-beta/README.md @@ -15,10 +15,10 @@ See below for commands to train these models using either DeepSpeed ZeRO-3 or Lo You will require 8 GPUs (80GB of VRAM) to train the full model. ```shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/zephyr-7b-beta/sft/config_full.yaml # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/zephyr-7b-beta/dpo/config_full.yaml ``` ## QLoRA training examples @@ -26,10 +26,10 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con Train faster with flash-attention 2 (GPU supporting FA2: A100, H100, etc) ```````shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/ddp.yaml --num_processes=1 scripts/sft.py --config recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/ddp.yaml --num_processes=1 scripts/dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml ``````` P.S. Using Flash Attention also allows you to drastically increase the batch size (x2 in my case) @@ -37,8 +37,8 @@ P.S. Using Flash Attention also allows you to drastically increase the batch siz Train without flash-attention (i.e. via PyTorch's scaled dot product attention): ```````shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --attn_implementation=sdpa +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/ddp.yaml --num_processes=1 scripts/sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --attn_implementation=sdpa # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --attn_implementation=sdpa +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/ddp.yaml --num_processes=1 scripts/dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --attn_implementation=sdpa ``````` \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/dpo/config_full.yaml b/recipes/zephyr-7b-beta/dpo/config_full.yaml index 12b47b18..e3df56d3 100644 --- a/recipes/zephyr-7b-beta/dpo/config_full.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_full.yaml @@ -3,13 +3,25 @@ model_name_or_path: alignment-handbook/zephyr-7b-sft-full torch_dtype: null # Data training arguments -# For definitions, see: src/h4/training/config.py -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: train_prefs + columns: + - chosen + - rejected + weight: 1.0 + - id: HuggingFaceH4/ultrafeedback_binarized + config: default + split: test_prefs + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 2000 + seed: 0 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml index 46fbccd9..a125db77 100644 --- a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml @@ -25,7 +25,7 @@ dataset_mixer: dataset_splits: - train_prefs - test_prefs -preprocessing_num_workers: 12 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/zephyr-7b-beta/sft/config_full.yaml b/recipes/zephyr-7b-beta/sft/config_full.yaml index f1e8457d..8c6c7e97 100644 --- a/recipes/zephyr-7b-beta/sft/config_full.yaml +++ b/recipes/zephyr-7b-beta/sft/config_full.yaml @@ -6,12 +6,23 @@ attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" -dataset_mixer: - HuggingFaceH4/ultrachat_200k: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrachat_200k + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/ultrachat_200k + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 1000 + seed: 0 +dataset_num_proc: 12 # SFT trainer config bf16: true @@ -38,7 +49,7 @@ per_device_train_batch_size: 16 push_to_hub: true remove_unused_columns: true report_to: -- tensorboard +- wandb save_strategy: "steps" save_steps: 100 save_total_limit: 1 diff --git a/recipes/zephyr-7b-beta/sft/config_qlora.yaml b/recipes/zephyr-7b-beta/sft/config_qlora.yaml index 4881757c..4809f899 100644 --- a/recipes/zephyr-7b-beta/sft/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/sft/config_qlora.yaml @@ -21,12 +21,23 @@ lora_target_modules: # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" -dataset_mixer: - HuggingFaceH4/ultrachat_200k: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: HuggingFaceH4/ultrachat_200k + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/ultrachat_200k + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 1000 + seed: 0 +dataset_num_proc: 12 # SFT trainer config bf16: true diff --git a/recipes/zephyr-7b-gemma/README.md b/recipes/zephyr-7b-gemma/README.md index 416462e3..c0b06298 100644 --- a/recipes/zephyr-7b-gemma/README.md +++ b/recipes/zephyr-7b-gemma/README.md @@ -14,8 +14,8 @@ You will require 8 GPUs (80GB of VRAM) to train the full model - alternatively, ```shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-gemma/sft/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/zephyr-7b-gemma/sft/config_full.yaml # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-gemma/dpo/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/zephyr-7b-gemma/dpo/config_full.yaml ``` diff --git a/recipes/zephyr-7b-gemma/dpo/config_full.yaml b/recipes/zephyr-7b-gemma/dpo/config_full.yaml index f17ac683..5204194a 100644 --- a/recipes/zephyr-7b-gemma/dpo/config_full.yaml +++ b/recipes/zephyr-7b-gemma/dpo/config_full.yaml @@ -3,13 +3,25 @@ model_name_or_path: HuggingFaceH4/zephyr-7b-gemma-sft-v0.1 torch_dtype: bfloat16 # Data training arguments -# For definitions, see: src/h4/training/config.py -dataset_mixer: - argilla/dpo-mix-7k: 1.0 -dataset_splits: -- train -- test -preprocessing_num_workers: 12 +dataset_mixture: + datasets: + - id: argilla/dpo-mix-7k + config: default + split: train + columns: + - chosen + - rejected + weight: 1.0 + - id: argilla/dpo-mix-7k + config: default + split: test + columns: + - chosen + - rejected + weight: 1.0 + test_split_size: 500 + seed: 0 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/recipes/zephyr-7b-gemma/sft/config_full.yaml b/recipes/zephyr-7b-gemma/sft/config_full.yaml index 03226ab3..7f97256f 100644 --- a/recipes/zephyr-7b-gemma/sft/config_full.yaml +++ b/recipes/zephyr-7b-gemma/sft/config_full.yaml @@ -1,17 +1,28 @@ # Model arguments model_name_or_path: google/gemma-7b model_revision: main -tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens torch_dtype: bfloat16 attn_implementation: flash_attention_2 # Data training arguments -dataset_mixer: - HuggingFaceH4/deita-10k-v0-sft: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 +chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" +dataset_mixture: + datasets: + - id: HuggingFaceH4/deita-10k-v0-sft + config: default + split: train_sft + columns: + - messages + weight: 1.0 + - id: HuggingFaceH4/deita-10k-v0-sft + config: default + split: test_sft + columns: + - messages + weight: 1.0 + test_split_size: 1000 + seed: 0 +dataset_num_proc: 12 # SFT trainer config bf16: true diff --git a/scripts/README.md b/scripts/README.md index 79d2e195..48f7c31e 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -13,19 +13,19 @@ In practice, we find comparable performance for both full and QLoRA fine-tuning, ```shell # Full training with ZeRO-3 on 8 GPUs -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/run_{task}.py --config recipes/{model_name}/{task}/config_full.yaml # QLoRA 4-bit training on a single GPU -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/ddp.yaml --num_processes=1 scripts/run_{task}.py --config recipes/{model_name}/{task}/config_qlora.yaml # LoRA training on a single GPU -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/ddp.yaml --num_processes=1 scripts/run_{task}.py --config recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false # LoRA training with ZeRO-3 on two or more GPUs -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py --config recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false # QLoRA training with FSDP on two or more GPUs -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py --config recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 ``` Here `{task}` refers to the type of training you wish to run. Currently, the following tasks are supported: @@ -38,10 +38,10 @@ Here `{task}` refers to the type of training you wish to run. Currently, the fol ```shell # Step 1 - train SFT policy -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py --config recipes/zephyr-7b-beta/sft/config_full.yaml # Step 2 - align with DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/dpo.py --config recipes/zephyr-7b-beta/dpo/config_full.yaml ``` **πŸ’‘ Tip:** If you scale up/down the number of GPUs, we recommend also scaling up the per-device batch size or number of gradient accumulation steps to keep the global batch size constant (and thus replicate our results). @@ -50,14 +50,14 @@ By default, these scripts will push each model to your Hugging Face Hub username ```shell # Change batch size, number of epochs etc -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --per_device_train_batch_size=42 --num_train_epochs=5 +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/run_{task}.py --config recipes/{model_name}/{task}/config_full.yaml --per_device_train_batch_size=42 --num_train_epochs=5 ``` ## Logging with Weights and Biases By default, all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g. ```shell -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/run_{task}.py --config recipes/{model_name}/{task}/config_full.yaml --report_to=wandb ``` ## Launching jobs on a Slurm cluster @@ -72,7 +72,7 @@ Here `{model_name}` and `{task}` are defined as above, while `{precision}` refer ```shell # Launch on Slurm and override default hyperparameters -sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b-beta sft full deepspeed_zero3 '--per_device_train_batch_size=42 --num_train_epochs=5' +sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b-beta sft full zero3 '--per_device_train_batch_size=42 --num_train_epochs=5' ``` You can scale the number of nodes by increasing the `--nodes` flag. diff --git a/scripts/dpo.py b/scripts/dpo.py new file mode 100644 index 00000000..07df54c7 --- /dev/null +++ b/scripts/dpo.py @@ -0,0 +1,159 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Full training +python scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-7 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns + +# LoRA: +python scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-6 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import logging +import os +import sys + +import datasets +import torch +import transformers +from transformers import set_seed +from transformers.trainer_utils import get_last_checkpoint + +from alignment import DPOConfig, ScriptArguments, get_dataset, get_model, get_tokenizer +from trl import DPOTrainer, ModelConfig, TrlParser, get_peft_config + + +logger = logging.getLogger(__name__) + + +def main(script_args, training_args, model_args): + # Set seed for reproducibility + set_seed(training_args.seed) + + ############### + # Setup logging + ############### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.info(f"Model parameters {model_args}") + logger.info(f"Script parameters {script_args}") + logger.info(f"Training parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + + ################### + # Model & Tokenizer + ################### + model = get_model(model_args, training_args) + ref_model = get_model(model_args, training_args) + tokenizer = get_tokenizer(model_args, training_args) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + ######### + # Dataset + ######### + dataset = get_dataset(script_args) + for split in dataset: + if "messages" in dataset[split].column_names: + dataset[split] = dataset[split].remove_columns("messages") + + ########## + # Training + ########## + trainer = DPOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + logger.info("*** Train ***") + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/scripts/orpo.py b/scripts/orpo.py new file mode 100644 index 00000000..6c44a86c --- /dev/null +++ b/scripts/orpo.py @@ -0,0 +1,158 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Full training +python scripts/orpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --max_length 2048 \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-7 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-ORPO \ + --no_remove_unused_columns + +# LoRA: +python scripts/orpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-6 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-ORPO \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import logging +import os +import sys + +import datasets +import torch +import transformers +from transformers import set_seed +from transformers.trainer_utils import get_last_checkpoint + +from alignment import ORPOConfig, ScriptArguments, get_dataset, get_model, get_tokenizer +from trl import ModelConfig, ORPOTrainer, TrlParser, get_peft_config + + +logger = logging.getLogger(__name__) + + +def main(script_args, training_args, model_args): + # Set seed for reproducibility + set_seed(training_args.seed) + + ############### + # Setup logging + ############### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.info(f"Model parameters {model_args}") + logger.info(f"Script parameters {script_args}") + logger.info(f"Training parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + + ################### + # Model & Tokenizer + ################### + model = get_model(model_args, training_args) + tokenizer = get_tokenizer(model_args, training_args) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + ######### + # Dataset + ######### + dataset = get_dataset(script_args) + for split in dataset: + if "messages" in dataset[split].column_names: + dataset[split] = dataset[split].remove_columns("messages") + + ########## + # Training + ########## + trainer = ORPOTrainer( + model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + logger.info("*** Train ***") + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, ORPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/scripts/run_cpt.py b/scripts/run_cpt.py deleted file mode 100644 index 06c9e9d9..00000000 --- a/scripts/run_cpt.py +++ /dev/null @@ -1,209 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Continued pretraining script for decoder language models. -""" - -import logging -import random -import sys - -import datasets -import torch -import transformers -from transformers import set_seed - -from alignment import ( - DataArguments, - H4ArgumentParser, - ModelArguments, - SFTConfig, - get_checkpoint, - get_datasets, - get_kbit_device_map, - get_peft_config, - get_quantization_config, - get_tokenizer, -) -from trl import SFTTrainer - - -logger = logging.getLogger(__name__) - - -def main(): - parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) - model_args, data_args, training_args = parser.parse() - - # Set seed for reproducibility - set_seed(training_args.seed) - - ############### - # Setup logging - ############### - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process a small summary - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Model parameters {model_args}") - logger.info(f"Data parameters {data_args}") - logger.info(f"Training/evaluation parameters {training_args}") - - # Check for last checkpoint - last_checkpoint = get_checkpoint(training_args) - if last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") - - ############### - # Load datasets - ############### - raw_datasets = get_datasets( - data_args, - splits=data_args.dataset_splits, - configs=data_args.dataset_configs, - columns_to_keep=[data_args.text_column], - ) - - logger.info( - f"Training on the following datasets and their proportions:" - f" {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" - ) - - train_dataset = raw_datasets["train"] if "train" in raw_datasets else None - eval_dataset = raw_datasets["test"] if "test" in raw_datasets else None - - if train_dataset is None: - raise ValueError( - "Training set must be included (so make sure that your dataset has a split with" " 'train' in the name)." - ) - - if training_args.do_eval and eval_dataset is None: - raise ValueError("'--do_eval' enabled so make sure that your dataset has a split with 'test' in the name.") - - ################ - # Load tokenizer - ################ - tokenizer = get_tokenizer(model_args, data_args, auto_set_chat_template=False) - - with training_args.main_process_first(desc="Log a few random samples from the processed training set"): - for index in random.sample(range(len(raw_datasets["train"])), 3): - logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}") - - ####################### - # Load pretrained model - ####################### - logger.info("*** Load pretrained model ***") - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - - model_kwargs = dict( - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - - ######################## - # Initialize the Trainer - ######################## - trainer = SFTTrainer( - model=model_args.model_name_or_path, - model_init_kwargs=model_kwargs, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - dataset_text_field=data_args.text_column, - max_seq_length=training_args.max_seq_length, - tokenizer=tokenizer, - packing=True, - peft_config=get_peft_config(model_args), - dataset_kwargs=training_args.dataset_kwargs, - ) - - ############### - # Training loop - ############### - logger.info("*** Train ***") - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - metrics["train_samples"] = len(train_dataset) - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - ################################## - # Save model and create model card - ################################## - logger.info("*** Save model ***") - trainer.save_model(training_args.output_dir) - logger.info(f"Model saved to {training_args.output_dir}") - - # Save everything else on main process - kwargs = { - "finetuned_from": model_args.model_name_or_path, - "dataset": list(data_args.dataset_mixer.keys()), - "dataset_tags": list(data_args.dataset_mixer.keys()), - "tags": ["alignment-handbook"], - } - if trainer.accelerator.is_main_process: - trainer.create_model_card(**kwargs) - # Restore k,v cache for fast inference - trainer.model.config.use_cache = True - trainer.model.config.save_pretrained(training_args.output_dir) - - ########## - # Evaluate - ########## - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - metrics["eval_samples"] = len(eval_dataset) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.push_to_hub is True: - logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) - - logger.info("*** Training complete ***") - - -if __name__ == "__main__": - main() diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py deleted file mode 100644 index 972d969a..00000000 --- a/scripts/run_dpo.py +++ /dev/null @@ -1,261 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import random -import sys - -import torch -import transformers -from transformers import AutoModelForCausalLM, set_seed - -from alignment import ( - DataArguments, - DPOConfig, - H4ArgumentParser, - ModelArguments, - apply_chat_template, - decontaminate_humaneval, - get_checkpoint, - get_datasets, - get_kbit_device_map, - get_peft_config, - get_quantization_config, - get_tokenizer, - is_adapter_model, -) -from peft import PeftConfig, PeftModel -from trl import DPOTrainer - - -logger = logging.getLogger(__name__) - - -def main(): - parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig)) - model_args, data_args, training_args = parser.parse() - - ####### - # Setup - ####### - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.info(f"Model parameters {model_args}") - logger.info(f"Data parameters {data_args}") - logger.info(f"Training/evaluation parameters {training_args}") - - # Check for last checkpoint - last_checkpoint = get_checkpoint(training_args) - if last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") - - # Set seed for reproducibility - set_seed(training_args.seed) - - ############### - # Load datasets - ############### - raw_datasets = get_datasets( - data_args, - splits=data_args.dataset_splits, - configs=data_args.dataset_configs, - columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"], - ) - logger.info( - f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" - ) - column_names = list(raw_datasets["train"].features) - - ##################################### - # Load tokenizer and process datasets - ##################################### - data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn - tokenizer = get_tokenizer(model_args, data_args) - - ##################### - # Apply chat template - ##################### - raw_datasets = raw_datasets.map( - apply_chat_template, - fn_kwargs={ - "tokenizer": tokenizer, - "task": "dpo", - "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, - }, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - desc="Formatting comparisons with prompt template", - ) - - ########################## - # Decontaminate benchmarks - ########################## - num_raw_train_samples = len(raw_datasets["train"]) - raw_datasets = raw_datasets.filter( - decontaminate_humaneval, - fn_kwargs={"text_column": "text_chosen"}, - batched=True, - batch_size=10_000, - num_proc=1, - desc="Decontaminating HumanEval samples", - ) - num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"]) - logger.info( - f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set." - ) - - # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected - for split in ["train", "test"]: - raw_datasets[split] = raw_datasets[split].rename_columns( - {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"} - ) - - # Log a few random samples from the training set: - for index in random.sample(range(len(raw_datasets["train"])), 3): - logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}") - logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}") - logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}") - - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - - model_kwargs = dict( - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - - model = model_args.model_name_or_path - if is_adapter_model(model, model_args.model_revision) is True: - logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}") - peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision) - model_kwargs = dict( - revision=model_args.base_model_revision, - trust_remote_code=model_args.trust_remote_code, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - base_model = AutoModelForCausalLM.from_pretrained( - peft_config.base_model_name_or_path, - **model_kwargs, - ) - model = PeftModel.from_pretrained( - base_model, - model_args.model_name_or_path, - revision=model_args.model_revision, - ) - model_kwargs = None - - ref_model = model - ref_model_kwargs = model_kwargs - - if model_args.use_peft is True: - ref_model = None - ref_model_kwargs = None - - ######################### - # Instantiate DPO trainer - ######################### - trainer = DPOTrainer( - model, - ref_model, - model_init_kwargs=model_kwargs, - ref_model_init_kwargs=ref_model_kwargs, - args=training_args, - beta=training_args.beta, - train_dataset=raw_datasets["train"], - eval_dataset=raw_datasets["test"], - tokenizer=tokenizer, - max_length=training_args.max_length, - max_prompt_length=training_args.max_prompt_length, - peft_config=get_peft_config(model_args), - loss_type=training_args.loss_type, - ) - - ############### - # Training loop - ############### - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - metrics["train_samples"] = len(raw_datasets["train"]) - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - logger.info("*** Training complete ***") - - ################################## - # Save model and create model card - ################################## - logger.info("*** Save model ***") - trainer.save_model(training_args.output_dir) - logger.info(f"Model saved to {training_args.output_dir}") - - # Save everything else on main process - kwargs = { - "finetuned_from": model_args.model_name_or_path, - "dataset": list(data_args.dataset_mixer.keys()), - "dataset_tags": list(data_args.dataset_mixer.keys()), - "tags": ["alignment-handbook"], - } - if trainer.accelerator.is_main_process: - trainer.create_model_card(**kwargs) - # Restore k,v cache for fast inference - trainer.model.config.use_cache = True - trainer.model.config.save_pretrained(training_args.output_dir) - - ########## - # Evaluate - ########## - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - metrics["eval_samples"] = len(raw_datasets["test"]) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.push_to_hub is True: - logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) - - logger.info("*** Training complete! ***") - - -if __name__ == "__main__": - main() diff --git a/scripts/run_orpo.py b/scripts/run_orpo.py deleted file mode 100644 index ce864d31..00000000 --- a/scripts/run_orpo.py +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import random -import sys -from typing import Any, Dict - -import torch -import transformers -from transformers import AutoModelForCausalLM, set_seed - -from alignment import ( - DataArguments, - H4ArgumentParser, - ModelArguments, - apply_chat_template, - decontaminate_humaneval, - get_checkpoint, - get_datasets, - get_kbit_device_map, - get_peft_config, - get_quantization_config, - get_tokenizer, -) -from trl import ORPOConfig, ORPOTrainer, setup_chat_format - - -logger = logging.getLogger(__name__) - - -def main(): - parser = H4ArgumentParser((ModelArguments, DataArguments, ORPOConfig)) - model_args, data_args, training_args = parser.parse() - - ####### - # Setup - ####### - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.info(f"Model parameters {model_args}") - logger.info(f"Data parameters {data_args}") - logger.info(f"Training/evaluation parameters {training_args}") - - # Check for last checkpoint - last_checkpoint = get_checkpoint(training_args) - if last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") - - # Set seed for reproducibility - set_seed(training_args.seed) - - ############### - # Load datasets - ############### - raw_datasets = get_datasets( - data_args, - splits=data_args.dataset_splits, - configs=data_args.dataset_configs, - columns_to_keep=[ - "prompt", - "chosen", - "rejected", - ], - ) - logger.info( - f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" - ) - column_names = list(raw_datasets["train"].features) - - ##################################### - # Load tokenizer and process datasets - ##################################### - data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn - tokenizer = get_tokenizer(model_args, data_args) - - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - - # For ChatML we need to add special tokens and resize the embedding layer - if "<|im_start|>" in tokenizer.chat_template: - model, tokenizer = setup_chat_format(model, tokenizer) - - ##################### - # Apply chat template - ##################### - raw_datasets = raw_datasets.map( - apply_chat_template, - fn_kwargs={ - "tokenizer": tokenizer, - "task": "orpo", - "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, - }, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - desc="Formatting comparisons with prompt template", - ) - - ############################# - # Filter out seq > max_length - ############################# - if training_args.max_prompt_length is not None: - unfiltered_train_samples = len(raw_datasets["train"]) - if "test" in raw_datasets: - unfiltered_test_samples = len(raw_datasets["test"]) - - def filter_fn(sample: Dict[str, Any]) -> Dict[str, Any]: - prompt_length = tokenizer( - sample["text_prompt"], - return_tensors="pt", - )[ - "input_ids" - ].size(dim=-1) - - return prompt_length < training_args.max_prompt_length - - raw_datasets = raw_datasets.filter( - filter_fn, - desc="Filtering out the samples where len(text_prompt) > max_prompt_length", - ) - - filtered_train_samples = unfiltered_train_samples - len(raw_datasets["train"]) - logger.info( - f"Filtered out {filtered_train_samples} training samples out of the {unfiltered_train_samples} samples." - ) - if "test" in raw_datasets: - filtered_test_samples = unfiltered_test_samples - len(raw_datasets["test"]) - logger.info( - f"Filtered out {filtered_test_samples} test samples out of the {unfiltered_test_samples} samples." - ) - - ########################## - # Decontaminate benchmarks - ########################## - num_raw_train_samples = len(raw_datasets["train"]) - raw_datasets = raw_datasets.filter( - decontaminate_humaneval, - fn_kwargs={"text_column": "text_chosen"}, - batched=True, - batch_size=10_000, - num_proc=1, - desc="Decontaminating HumanEval samples", - ) - num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"]) - logger.info( - f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set." - ) - - # Replace column names with what TRL needs, text_prompt -> prompt, text_chosen -> chosen and text_rejected -> rejected - for split in raw_datasets.keys(): - raw_datasets[split] = raw_datasets[split].rename_columns( - { - "text_prompt": "prompt", - "text_chosen": "chosen", - "text_rejected": "rejected", - } - ) - - # Log a few random samples from the training set: - for index in random.sample(range(len(raw_datasets["train"])), 3): - logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}") - logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}") - logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}") - - ########################## - # Instantiate ORPO trainer - ########################## - trainer = ORPOTrainer( - model, - args=training_args, - train_dataset=raw_datasets["train"], - eval_dataset=raw_datasets["test"] if "test" in raw_datasets else None, - tokenizer=tokenizer, - peft_config=get_peft_config(model_args), # type: ignore - ) - - ############### - # Training loop - ############### - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - metrics["train_samples"] = len(raw_datasets["train"]) - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - logger.info("*** Training complete ***") - - ################################## - # Save model and create model card - ################################## - logger.info("*** Save model ***") - if trainer.is_fsdp_enabled: - trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - trainer.save_model(training_args.output_dir) - logger.info(f"Model saved to {training_args.output_dir}") - - # Save everything else on main process - kwargs = { - "finetuned_from": model_args.model_name_or_path, - "dataset": list(data_args.dataset_mixer.keys()), - "dataset_tags": list(data_args.dataset_mixer.keys()), - "tags": ["alignment-handbook"], - } - if trainer.accelerator.is_main_process: - trainer.create_model_card(**kwargs) - # Restore k,v cache for fast inference - trainer.model.config.use_cache = True - trainer.model.config.save_pretrained(training_args.output_dir) - - ########## - # Evaluate - ########## - if training_args.do_eval and "test" in raw_datasets: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - metrics["eval_samples"] = len(raw_datasets["test"]) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.push_to_hub is True: - logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) - - logger.info("*** Training complete! ***") - - -if __name__ == "__main__": - main() diff --git a/scripts/run_sft.py b/scripts/run_sft.py deleted file mode 100644 index 60a2dfdb..00000000 --- a/scripts/run_sft.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Supervised fine-tuning script for decoder language models. -""" - -import logging -import random -import sys - -import datasets -import torch -import transformers -from transformers import AutoModelForCausalLM, set_seed - -from alignment import ( - DataArguments, - H4ArgumentParser, - ModelArguments, - SFTConfig, - apply_chat_template, - decontaminate_humaneval, - get_checkpoint, - get_datasets, - get_kbit_device_map, - get_peft_config, - get_quantization_config, - get_tokenizer, -) -from trl import SFTTrainer, setup_chat_format - - -logger = logging.getLogger(__name__) - - -def main(): - parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) - model_args, data_args, training_args = parser.parse() - - # Set seed for reproducibility - set_seed(training_args.seed) - - ############### - # Setup logging - ############### - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process a small summary - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Model parameters {model_args}") - logger.info(f"Data parameters {data_args}") - logger.info(f"Training/evaluation parameters {training_args}") - - # Check for last checkpoint - last_checkpoint = get_checkpoint(training_args) - if last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") - - ############### - # Load datasets - ############### - raw_datasets = get_datasets( - data_args, - splits=data_args.dataset_splits, - configs=data_args.dataset_configs, - columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"], - ) - logger.info( - f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" - ) - column_names = list(raw_datasets["train"].features) - - ################ - # Load tokenizer - ################ - tokenizer = get_tokenizer(model_args, data_args) - - ####################### - # Load pretrained model - ####################### - logger.info("*** Load pretrained model ***") - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - - model_kwargs = dict( - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - - model = model_args.model_name_or_path - # For ChatML we need to add special tokens and resize the embedding layer - if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path: - model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) - model, tokenizer = setup_chat_format(model, tokenizer) - model_kwargs = None - - ##################### - # Apply chat template - ##################### - raw_datasets = raw_datasets.map( - apply_chat_template, - fn_kwargs={ - "tokenizer": tokenizer, - "task": "sft", - "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, - }, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - desc="Applying chat template", - ) - - ########################## - # Decontaminate benchmarks - ########################## - num_raw_train_samples = len(raw_datasets["train"]) - raw_datasets = raw_datasets.filter(decontaminate_humaneval, batched=True, batch_size=10_000, num_proc=1) - num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"]) - logger.info( - f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set." - ) - - train_dataset = raw_datasets["train"] - eval_dataset = raw_datasets["test"] - - with training_args.main_process_first(desc="Log a few random samples from the processed training set"): - for index in random.sample(range(len(raw_datasets["train"])), 3): - logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}") - - ######################## - # Initialize the Trainer - ######################## - trainer = SFTTrainer( - model=model, - model_init_kwargs=model_kwargs, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - dataset_text_field="text", - max_seq_length=training_args.max_seq_length, - tokenizer=tokenizer, - packing=True, - peft_config=get_peft_config(model_args), - dataset_kwargs=training_args.dataset_kwargs, - ) - - ############### - # Training loop - ############### - logger.info("*** Train ***") - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - metrics["train_samples"] = len(train_dataset) - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - ################################## - # Save model and create model card - ################################## - logger.info("*** Save model ***") - trainer.save_model(training_args.output_dir) - logger.info(f"Model saved to {training_args.output_dir}") - - # Save everything else on main process - kwargs = { - "finetuned_from": model_args.model_name_or_path, - "dataset": list(data_args.dataset_mixer.keys()), - "dataset_tags": list(data_args.dataset_mixer.keys()), - "tags": ["alignment-handbook"], - } - if trainer.accelerator.is_main_process: - trainer.create_model_card(**kwargs) - # Restore k,v cache for fast inference - trainer.model.config.use_cache = True - trainer.model.config.save_pretrained(training_args.output_dir) - - ########## - # Evaluate - ########## - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - metrics["eval_samples"] = len(eval_dataset) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.push_to_hub is True: - logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) - - logger.info("*** Training complete ***") - - -if __name__ == "__main__": - main() diff --git a/scripts/sft.py b/scripts/sft.py new file mode 100644 index 00000000..80dbc641 --- /dev/null +++ b/scripts/sft.py @@ -0,0 +1,174 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Supervised fine-tuning script for decoder language models. + +Usage: + +# One 1 node of 8 x H100s +accelerate launch --config_file recipes/accelerate_configs/zero3.yaml scripts/sft.py \ + --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --max_seq_length 4096 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --bf16 true \ + --logging_steps 5 \ + --eval_strategy steps \ + --eval_steps 100 \ + --output_dir data/Qwen2.5-1.5B-SFT +""" + +import logging +import os +import sys + +import datasets +import transformers +from transformers import set_seed +from transformers.trainer_utils import get_last_checkpoint + +from alignment import ScriptArguments, SFTConfig, get_dataset, get_model, get_tokenizer +from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format + + +logger = logging.getLogger(__name__) + + +def main(script_args, training_args, model_args): + # Set seed for reproducibility + set_seed(training_args.seed) + + ############### + # Setup logging + ############### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.info(f"Model parameters {model_args}") + logger.info(f"Script parameters {script_args}") + logger.info(f"Training parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + + ################ + # Load datasets + ################ + dataset = get_dataset(script_args) + ################ + # Load tokenizer + ################ + tokenizer = get_tokenizer(model_args, training_args) + ############ + # Load model + ############ + logger.info("*** Loading model ***") + model = get_model(model_args, training_args) + + if tokenizer.chat_template is None: + logger.info("No chat template provided, using ChatML.") + model, tokenizer = setup_chat_format(model, tokenizer, format="chatml") + + ############################ + # Initialize the SFT Trainer + ############################ + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None), + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + ############### + # Training loop + ############### + logger.info("*** Train ***") + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + ################################## + # Save model and create model card + ################################## + logger.info("*** Save model ***") + # Align the model's generation config with the tokenizer's eos token + # to avoid unbounded generation in the transformers `pipeline()` function + trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id + trainer.model.config.eos_token_id = tokenizer.eos_token_id + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + + # Save everything else on main process + kwargs = { + "model_name": training_args.hub_model_id if training_args.push_to_hub else None, + "dataset_name": script_args.dataset_name, + "tags": ["alignment-handbook"], + } + if trainer.accelerator.is_main_process: + trainer.create_model_card(**kwargs) + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.model.config.save_pretrained(training_args.output_dir) + + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(dataset[script_args.dataset_test_split]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + ############# + # push to hub + ############# + if training_args.push_to_hub: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/setup.py b/setup.py index 8bb46178..7259bd35 100644 --- a/setup.py +++ b/setup.py @@ -41,34 +41,35 @@ # IMPORTANT: all dependencies should be listed here with their version requirements, if any. # * If a dependency is fast-moving (e.g. transformers), pin to the exact version _deps = [ - "accelerate>=0.29.2", - "bitsandbytes>=0.43.0", + "accelerate>=1.9.0", + "bitsandbytes>=0.46.1", "black>=24.4.2", - "datasets>=2.18.0", - "deepspeed>=0.14.4", - "einops>=0.6.1", + "datasets>=4.0.0", + "deepspeed>=0.17.2", + "einops>=0.8.1", "evaluate==0.4.0", "flake8>=6.0.0", "hf-doc-builder>=0.4.0", - "hf_transfer>=0.1.4", - "huggingface-hub>=0.19.2,<1.0", + "huggingface-hub>=0.33.4,<1.0", "isort>=5.12.0", + "liger-kernel>=0.6.0", "ninja>=1.11.1", "numpy>=1.24.2", "packaging>=23.0", "parameterized>=0.9.0", - "peft>=0.9.0", + "peft>=0.16.0", "protobuf<=3.20.2", # Needed to avoid conflicts with `transformers` "pytest", - "safetensors>=0.3.3", - "sentencepiece>=0.1.99", + "safetensors>=0.5.3", + "sentencepiece>=0.2.0", "scipy", "tensorboard", - "torch>=2.1.2", - "transformers>=4.39.3", - "trl>=0.9.6,<0.13.0", + "torch>=2.6.0", + "transformers>=4.53.3", + "trl>=0.19.1", "jinja2>=3.0.0", "tqdm>=4.64.1", + "wandb", ] # this is a lookup table with items like: @@ -99,9 +100,9 @@ def deps_list(*pkgs): deps["evaluate"], deps["datasets"], deps["deepspeed"], - deps["hf_transfer"], deps["huggingface-hub"], deps["jinja2"], + deps["liger-kernel"], deps["ninja"], deps["numpy"], deps["packaging"], # utilities from PyPA to e.g., compare versions @@ -114,6 +115,7 @@ def deps_list(*pkgs): deps["tqdm"], # progress bars in model download and training scripts deps["transformers"], deps["trl"], + deps["wandb"], ] setup( diff --git a/src/alignment/__init__.py b/src/alignment/__init__.py index 2fafb542..0836e213 100644 --- a/src/alignment/__init__.py +++ b/src/alignment/__init__.py @@ -1,31 +1,16 @@ -__version__ = "0.3.0.dev0" +__version__ = "0.4.0.dev0" -from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig -from .data import apply_chat_template, get_datasets -from .decontaminate import decontaminate_humaneval -from .model_utils import ( - get_checkpoint, - get_kbit_device_map, - get_peft_config, - get_quantization_config, - get_tokenizer, - is_adapter_model, -) +from .configs import DPOConfig, ORPOConfig, ScriptArguments, SFTConfig +from .data import get_dataset +from .model_utils import get_model, get_tokenizer __all__ = [ - "DataArguments", + "ScriptArguments", "DPOConfig", - "H4ArgumentParser", - "ModelArguments", "SFTConfig", - "apply_chat_template", - "get_datasets", - "decontaminate_humaneval", - "get_checkpoint", - "get_kbit_device_map", - "get_peft_config", - "get_quantization_config", + "ORPOConfig", + "get_dataset", "get_tokenizer", - "is_adapter_model", + "get_model", ] diff --git a/src/alignment/configs.py b/src/alignment/configs.py index aff07920..e8f15baa 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -1,5 +1,19 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,260 +26,132 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses -import os -import sys -from dataclasses import dataclass, field -from typing import Any, Dict, List, NewType, Optional, Tuple -from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser +from dataclasses import dataclass, field +from typing import Any, Optional import trl -MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -DataClassType = NewType("DataClassType", Any) - - -class H4ArgumentParser(HfArgumentParser): - def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]: - """ - Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. - - Args: - yaml_arg (`str`): - The path to the config file used - other_args (`List[str]`, *optional`): - A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. - - Returns: - [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line - """ - arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) - - outputs = [] - # strip other args list into dict of key-value pairs - other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args} - used_args = {} - - # overwrite the default/loaded value with the value provided to the command line - # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 - for data_yaml, data_class in zip(arg_list, self.dataclass_types): - keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} - inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} - for arg, val in other_args.items(): - # add only if in keys - - if arg in keys: - base_type = data_yaml.__dataclass_fields__[arg].type - inputs[arg] = val - - # cast type for ints, floats (default to strings) - if base_type in [int, float]: - inputs[arg] = base_type(val) - - if base_type == List[str]: - inputs[arg] = [str(v) for v in val.split(",")] - - # bool of a non-empty string is True, so we manually check for bools - if base_type is bool: - if val in ["true", "True"]: - inputs[arg] = True - else: - inputs[arg] = False - - # add to used-args so we can check if double add - if arg not in used_args: - used_args[arg] = val - else: - raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior") +@dataclass +class DatasetConfig: + """Configuration for a dataset in a mixture.""" - obj = data_class(**inputs) - outputs.append(obj) + id: str + config: Optional[str] = None + split: str = "train" + columns: Optional[list[str]] = None + weight: Optional[float] = None - return outputs - def parse(self) -> DataClassType | Tuple[DataClassType]: - if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): - # If we pass only one argument to the script and it's the path to a YAML file, - # let's parse it to get our arguments. - output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) - # parse command line args and yaml file - elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): - output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) - # parse command line args only - else: - output = self.parse_args_into_dataclasses() +@dataclass +class DatasetMixtureConfig: + """Configuration for a mixture of datasets.""" - if len(output) == 1: - output = output[0] - return output + datasets: list[DatasetConfig] + seed: int = 0 + test_split_size: Optional[float] = None @dataclass -class ModelArguments: +class ScriptArguments(trl.ScriptArguments): """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune. + Extended version of ScriptArguments with support for dataset mixtures. + + Args: + dataset_mixture (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Configuration for creating dataset mixtures with advanced options. + Format: + dataset_mixture: + datasets: + - id: dataset_id1 + config: config_name + columns: + - col1 + - col2 + weight: 0.5 + - id: dataset_id2 + config: config_name + columns: + - col1 + - col2 + weight: 0.5 + seed: 42 + test_split_size: 0.1 """ - base_model_revision: Optional[str] = field( - default=None, - metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")}, - ) - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." - ) - }, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"}) - torch_dtype: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " - "dtype will be automatically derived from the model's weights." - ), - "choices": ["auto", "bfloat16", "float16", "float32"], - }, - ) - tokenizer_name_or_path: Optional[str] = field( + dataset_mixture: Optional[dict[str, Any]] = field( default=None, - metadata={ - "help": ( - "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`." - ) - }, - ) - trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) - attn_implementation: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Which attention implementation to use; you can use --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" - ) - }, - ) - use_peft: bool = field( - default=False, - metadata={"help": ("Whether to use PEFT or not for training.")}, - ) - lora_r: Optional[int] = field( - default=16, - metadata={"help": ("LoRA R value.")}, - ) - lora_alpha: Optional[int] = field( - default=32, - metadata={"help": ("LoRA alpha.")}, - ) - lora_dropout: Optional[float] = field( - default=0.05, - metadata={"help": ("LoRA dropout.")}, - ) - lora_target_modules: Optional[List[str]] = field( - default=None, - metadata={"help": ("LoRA target modules.")}, - ) - lora_modules_to_save: Optional[List[str]] = field( - default=None, - metadata={"help": ("Model layers to unfreeze & train")}, - ) - load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"}) - load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"}) - - bnb_4bit_quant_type: Optional[str] = field( - default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"} - ) - use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) - bnb_4bit_quant_storage: Optional[str] = field( - default="uint8", - metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}, + metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."}, ) def __post_init__(self): - if self.load_in_8bit and self.load_in_4bit: - raise ValueError("You can't use 8 bit and 4 bit precision at the same time") + if self.dataset_name is None and self.dataset_mixture is None: + raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided") + + if self.dataset_mixture is not None: + if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture: + raise ValueError( + "dataset_mixture must be a dictionary with a 'datasets' key. " + "Expected format: {'datasets': [...], 'seed': int}" + ) + + datasets_list = [] + datasets_data = self.dataset_mixture.get("datasets", []) + + if isinstance(datasets_data, list): + for dataset_config in datasets_data: + datasets_list.append( + DatasetConfig( + id=dataset_config.get("id"), + config=dataset_config.get("config"), + split=dataset_config.get("split", "train"), + columns=dataset_config.get("columns"), + weight=dataset_config.get("weight", 1.0), + ) + ) + else: + raise ValueError("'datasets' must be a list of dataset configurations") + + self.dataset_mixture = DatasetMixtureConfig( + datasets=datasets_list, + seed=self.dataset_mixture.get("seed", 0), + test_split_size=self.dataset_mixture.get("test_split_size", None), + ) + + # Check that column names are consistent across all dataset configs + columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None] + if columns_sets: + first_columns = columns_sets[0] + if not all(columns == first_columns for columns in columns_sets): + raise ValueError( + "Column names must be consistent across all dataset configurations in a mixture. " + f"Found different column sets: {[list(cols) for cols in columns_sets]}" + ) @dataclass -class DataArguments: +class SFTConfig(trl.SFTConfig): """ - Arguments pertaining to what data we are going to input our model for training and eval. + args for callbacks, benchmarks etc """ chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) - dataset_mixer: Optional[Dict[str, float]] = field( - default=None, - metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")}, - ) - text_column: Optional[str] = field( - default="text", - metadata={"help": "The column name to use for the text in the dataset (only used for continued pretraining)."}, - ) - dataset_splits: Optional[List[str]] = field( - default_factory=lambda: ["train", "test"], - metadata={"help": ("List of train test splits to use in the dataset")}, - ) - dataset_configs: Optional[List[str]] = field( - default=None, - metadata={"help": "List of dataset config names. If given must be the same length as 'dataset_mixer' keys."}, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - truncation_side: Optional[str] = field( - default=None, metadata={"help": "Truncation side to use for the tokenizer."} - ) - auto_insert_empty_system_msg: bool = field( - default=True, - metadata={ - "help": ( - "Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template." - ) - }, - ) @dataclass -class SFTConfig(trl.SFTConfig): +class DPOConfig(trl.DPOConfig): """ - Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments - Also used for the continued pretraining task. + args for callbacks, benchmarks etc """ - hub_model_revision: Optional[str] = field( - default="main", - metadata={"help": ("The Hub model branch to push the model to.")}, - ) - logging_first_step: bool = field( - default=True, - metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, - ) + chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) @dataclass -class DPOConfig(trl.DPOConfig): +class ORPOConfig(trl.ORPOConfig): """ - Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments + args for callbacks, benchmarks etc """ - hub_model_revision: Optional[str] = field( - default="main", - metadata={"help": ("The Hub model branch to push the model to.")}, - ) - logging_first_step: bool = field( - default=True, - metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, - ) - optim: Optional[str] = field(default="rmsprop") - remove_unused_columns: bool = field(default=False) + chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) diff --git a/src/alignment/data.py b/src/alignment/data.py index 56a4af62..8d3bb314 100644 --- a/src/alignment/data.py +++ b/src/alignment/data.py @@ -1,5 +1,4 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,244 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Any, List, Literal, Optional +import logging -from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk -from datasets.builder import DatasetGenerationError +import datasets +from datasets import DatasetDict, concatenate_datasets -from .configs import DataArguments +from .configs import ScriptArguments -DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" +logger = logging.getLogger(__name__) -def maybe_insert_system_message(messages, tokenizer): - if messages[0]["role"] == "system": - return - - # chat template can be one of two attributes, we check in order - chat_template = tokenizer.chat_template - if chat_template is None: - chat_template = tokenizer.get_chat_template() - - # confirm the jinja template refers to a system message before inserting - if "system" in chat_template or "<|im_start|>" in chat_template: - messages.insert(0, {"role": "system", "content": ""}) - - -def apply_chat_template( - example, - tokenizer, - task: Literal["sft", "generation", "rm", "dpo"], - auto_insert_empty_system_msg: bool = True, -): - if task in ["sft", "generation"]: - messages = example["messages"] - # We add an empty system message if there is none - if auto_insert_empty_system_msg: - maybe_insert_system_message(messages, tokenizer) - example["text"] = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True if task == "generation" else False, - ) - elif task == "rm": - if all(k in example.keys() for k in ("chosen", "rejected")): - chosen_messages = example["chosen"] - rejected_messages = example["rejected"] - # We add an empty system message if there is none - if auto_insert_empty_system_msg: - maybe_insert_system_message(chosen_messages, tokenizer) - maybe_insert_system_message(rejected_messages, tokenizer) - - example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) - example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) - else: - raise ValueError( - f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" - ) - elif task in ["dpo", "orpo"]: - if all(k in example.keys() for k in ("chosen", "rejected")): - if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]): - raise ValueError( - f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages" - ) - - # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue - # We therefore need to extract the N-1 turns to form the prompt - if "prompt" in example and is_openai_format(example["prompt"]): - prompt_messages = example["prompt"] - chosen_messages = example["chosen"] - rejected_messages = example["rejected"] - else: - prompt_messages = example["chosen"][:-1] - # Now we extract the final turn to define chosen/rejected responses - chosen_messages = example["chosen"][-1:] - rejected_messages = example["rejected"][-1:] - - # Prepend a system message if the first message is not a system message - if auto_insert_empty_system_msg: - maybe_insert_system_message(prompt_messages, tokenizer) - - example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False) - example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) - example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) - else: - raise ValueError( - f"Could not format example as dialogue for `{task}` task! Require either the " - f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}" - ) - else: - raise ValueError( - f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']" - ) - return example - - -def is_openai_format(messages: Any) -> bool: - """ - Check if the input messages are in OpenAI format. - Args: - messages (`Any`): - Messages to check. - Returns: - `bool`: Whether the messages are in OpenAI format. - """ - if isinstance(messages, list) and all(isinstance(message, dict) for message in messages): - return all("role" in message and "content" in message for message in messages) - return False - - -def get_datasets( - data_config: DataArguments | dict, - splits: Optional[List[str]] = None, - configs: Optional[List[str]] = None, - columns_to_keep: Optional[List[str]] = None, - shuffle: bool = True, -) -> DatasetDict: - """ - Loads one or more datasets with varying training set proportions. +def get_dataset(args: ScriptArguments) -> DatasetDict: + """Load a dataset or a mixture of datasets based on the configuration. Args: - data_config (`DataArguments` or `dict`): - Dataset configuration and split proportions. - splits (`List[str]`, *optional*, defaults to `['train', 'test']`): - Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. - configs (Optional[List[str]], *optional*, defaults to `None`): - List of dataset config names. If given must be the same length as 'data_config' keys. - columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): - Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, - and for cpt this should be (at least) the text column. - shuffle (`bool`, *optional*, defaults to `True`): - Whether to shuffle the training and testing/validation data. - - Returns - [`DatasetDict`]: The dataset dictionary containing the loaded datasets. - """ - if type(data_config) is DataArguments: - # Structure of the config to read the datasets and their mix - # datasets_mixer: - # - 'dataset1': 0.5 - # - 'dataset2': 0.3 - # - 'dataset3': 0.2 - dataset_mixer = data_config.dataset_mixer - elif isinstance(data_config, dict): - # Structure of the input is: - # dataset_mixer = { - # "dataset1": 0.5, - # "dataset1": 0.3, - # "dataset1": 0.2, - # } - dataset_mixer = data_config - else: - raise ValueError(f"Data config {data_config} not recognized.") + args (ScriptArguments): Script arguments containing dataset configuration. - raw_datasets = mix_datasets( - dataset_mixer, - splits=splits, - configs=configs, - columns_to_keep=columns_to_keep, - shuffle=shuffle, - ) - return raw_datasets - - -def mix_datasets( - dataset_mixer: dict, - splits: Optional[List[str]] = None, - configs: Optional[List[str]] = None, - columns_to_keep: Optional[List[str]] = None, - shuffle=True, -) -> DatasetDict: - """ - Loads and mixes datasets according to proportions specified in `dataset_mixer`. - - Args: - dataset_mixer (`dict`): - Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1. - splits (Optional[List[str]], *optional*, defaults to `None`): - Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. - configs (Optional[List[str]], *optional*, defaults to `None`): - List of dataset config names. If given must be the same length as 'dataset_mixer' keys. - columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): - Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, - and for cpt this should be (at least) the text column. - shuffle (`bool`, *optional*, defaults to `True`): - Whether to shuffle the training and testing/validation data. + Returns: + DatasetDict: The loaded datasets. """ - splits = ["train", "test"] if splits is None else splits - configs = [None] * len(dataset_mixer) if not configs else configs - columns_to_keep = [] if columns_to_keep is None else columns_to_keep + if args.dataset_name and not args.dataset_mixture: + logger.info(f"Loading dataset: {args.dataset_name}") + return datasets.load_dataset(args.dataset_name, args.dataset_config) + elif args.dataset_mixture: + logger.info(f"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets") + seed = args.dataset_mixture.seed + datasets_list = [] + + for dataset_config in args.dataset_mixture.datasets: + logger.info(f"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})") + ds = datasets.load_dataset( + dataset_config.id, + dataset_config.config, + split=dataset_config.split, + ) + if dataset_config.columns is not None: + ds = ds.select_columns(dataset_config.columns) + if dataset_config.weight is not None: + ds = ds.shuffle(seed=seed).select(range(int(len(ds) * dataset_config.weight))) + logger.info( + f"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples" + ) - if configs is not None and len(configs) != len(dataset_mixer): - raise ValueError("The number of given dataset config names must be the same as the given number of datasets.") + datasets_list.append(ds) - raw_datasets = DatasetDict() - raw_train_datasets = [] - raw_val_datasets = [] - fracs = [] - for (ds, frac), ds_config in zip(dataset_mixer.items(), configs): - fracs.append(frac) - for split in splits: - try: - # Try first if dataset on a Hub repo - dataset = load_dataset(ds, ds_config, split=split) - except DatasetGenerationError: - # If not, check local dataset - dataset = load_from_disk(os.path.join(ds, split)) + if datasets_list: + combined_dataset = concatenate_datasets(datasets_list) + combined_dataset = combined_dataset.shuffle(seed=seed) + logger.info(f"Created dataset mixture with {len(combined_dataset)} examples") - # Remove redundant columns to avoid schema conflicts on load - dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep]) - if "train" in split: - raw_train_datasets.append(dataset) - elif "test" in split: - raw_val_datasets.append(dataset) + if args.dataset_mixture.test_split_size is not None: + combined_dataset = combined_dataset.train_test_split( + test_size=args.dataset_mixture.test_split_size, seed=seed + ) + logger.info( + f"Split dataset into train and test sets with test size: {args.dataset_mixture.test_split_size}" + ) + return combined_dataset else: - raise ValueError(f"Split type {split} not recognized as one of test or train.") - - if any(frac < 0 for frac in fracs): - raise ValueError("Dataset fractions cannot be negative.") - - if len(raw_train_datasets) > 0: - train_subsets = [] - for dataset, frac in zip(raw_train_datasets, fracs): - train_subset = dataset.select(range(int(frac * len(dataset)))) - train_subsets.append(train_subset) - if shuffle: - raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42) - else: - raw_datasets["train"] = concatenate_datasets(train_subsets) - # No subsampling for test datasets to enable fair comparison across models - if len(raw_val_datasets) > 0: - if shuffle: - raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42) + return DatasetDict({"train": combined_dataset}) else: - raw_datasets["test"] = concatenate_datasets(raw_val_datasets) + raise ValueError("No datasets were loaded from the mixture configuration") - if len(raw_datasets) == 0: - raise ValueError( - f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted." - ) - - return raw_datasets + else: + raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided") diff --git a/src/alignment/decontaminate.py b/src/alignment/decontaminate.py deleted file mode 100644 index 45cba95c..00000000 --- a/src/alignment/decontaminate.py +++ /dev/null @@ -1,91 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, List - -from datasets import load_dataset - - -# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset -HUMAN_EVAL_STRINGS_OK = ["return x + y", "return len(string)", "return n**2", "return " ".join(strings)"] - - -def extract_docstring(prompt: str) -> str: - if '"""' in prompt: - if prompt.count('"""') == 2: - return prompt.split('"""')[1].strip() - elif prompt.count('"""') == 4: - return prompt.split('"""')[3].strip() - else: - raise ValueError() - elif "'''" in prompt: - assert prompt.count("'''") == 2 - return prompt.split("'''")[1].strip() - else: - raise ValueError() - - -def human_eval_docstrings() -> List[str]: - ds = load_dataset("openai_humaneval", split="test") - docstrings = [extract_docstring(v["prompt"]) for v in ds] - return docstrings - - -def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]: - ds = load_dataset(dataset, split=split, name=name) - res = [sample[column].strip() for sample in ds] - # Only return non-empty strings - return [sample for sample in res if len(sample) > 0] - - -FILTER_OUT = { - "human_eval_docstrings": human_eval_docstrings(), - "human_eval_solutions": [ - s - for s in load_dataset_column("openai_humaneval", "canonical_solution", "test") - if s not in HUMAN_EVAL_STRINGS_OK - ], -} - - -def normalize_whitespace(text: str) -> str: - return " ".join(text.split()) - - -def decontaminate_humaneval( - samples: List[Dict[str, Any]], text_column: str = "text", filter_out: Dict[str, List[str]] = FILTER_OUT -) -> List[Dict[str, Any]]: - """ - filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be - filtered-out. - Return a list where each element is True if the corresponding file should be included in the dataset. - Otherwise, the element is False. - """ - output = [] - - for content in samples[text_column]: - content = normalize_whitespace(content.lower()) - matched = False - for _, substrings in filter_out.items(): - for substring in substrings: - if normalize_whitespace(substring.lower()) in content: - matched = True - break - if matched: - break - # we keep files that are not matched - output.append(not matched) - - return output diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index 650517ee..267a7d90 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -1,5 +1,4 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,117 +11,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from pathlib import Path -from typing import Dict import torch -from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer -from transformers.trainer_utils import get_last_checkpoint +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer -from accelerate import Accelerator -from huggingface_hub import list_repo_files -from huggingface_hub.errors import RepositoryNotFoundError -from huggingface_hub.utils._validators import HFValidationError -from peft import LoraConfig, PeftConfig +from trl import ModelConfig, get_kbit_device_map, get_quantization_config -from .configs import DataArguments, DPOConfig, ModelArguments, SFTConfig -from .data import DEFAULT_CHAT_TEMPLATE +from .configs import SFTConfig -def get_current_device() -> int: - """Get the current device. For GPU we return the local process index to enable multiple GPU training.""" - return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" - - -def get_kbit_device_map() -> Dict[str, int] | None: - """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`""" - return {"": get_current_device()} if torch.cuda.is_available() else None - - -def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None: - if model_args.load_in_4bit: - compute_dtype = torch.float16 - if model_args.torch_dtype not in {"auto", None}: - compute_dtype = getattr(torch, model_args.torch_dtype) - - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=compute_dtype, - bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, - bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, - ).to_dict() - elif model_args.load_in_8bit: - quantization_config = BitsAndBytesConfig( - load_in_8bit=True, - ).to_dict() - else: - quantization_config = None - - return quantization_config - - -def get_tokenizer( - model_args: ModelArguments, data_args: DataArguments, auto_set_chat_template: bool = True -) -> PreTrainedTokenizer: +def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig) -> PreTrainedTokenizer: """Get the tokenizer for the model.""" tokenizer = AutoTokenizer.from_pretrained( - ( - model_args.model_name_or_path - if model_args.tokenizer_name_or_path is None - else model_args.tokenizer_name_or_path - ), + model_args.model_name_or_path, revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, ) - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - - if data_args.truncation_side is not None: - tokenizer.truncation_side = data_args.truncation_side - - # Set reasonable default for models without max length - if tokenizer.model_max_length > 100_000: - tokenizer.model_max_length = 2048 - if data_args.chat_template is not None: - tokenizer.chat_template = data_args.chat_template - elif auto_set_chat_template and tokenizer.get_chat_template() is None: - tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE + if training_args.chat_template is not None: + tokenizer.chat_template = training_args.chat_template return tokenizer -def get_peft_config(model_args: ModelArguments) -> PeftConfig | None: - if model_args.use_peft is False: - return None - - peft_config = LoraConfig( - r=model_args.lora_r, - lora_alpha=model_args.lora_alpha, - lora_dropout=model_args.lora_dropout, - bias="none", - task_type="CAUSAL_LM", - target_modules=model_args.lora_target_modules, - modules_to_save=model_args.lora_modules_to_save, +def get_model(model_args: ModelConfig, training_args: SFTConfig) -> AutoModelForCausalLM: + """Get the model""" + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + **model_kwargs, ) - return peft_config - - -def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: - try: - # Try first if model on a Hub repo - repo_files = list_repo_files(model_name_or_path, revision=revision) - except (HFValidationError, RepositoryNotFoundError): - # If not, check local repo - repo_files = os.listdir(model_name_or_path) - return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files - - -def get_checkpoint(training_args: SFTConfig | DPOConfig) -> Path | None: - last_checkpoint = None - if os.path.isdir(training_args.output_dir): - last_checkpoint = get_last_checkpoint(training_args.output_dir) - return last_checkpoint + return model diff --git a/tests/fixtures/config_dpo_full.yaml b/tests/fixtures/config_dpo_full.yaml index 9ed13873..d23d1d1c 100644 --- a/tests/fixtures/config_dpo_full.yaml +++ b/tests/fixtures/config_dpo_full.yaml @@ -8,7 +8,7 @@ dataset_mixer: dataset_splits: - train_prefs - test_prefs -preprocessing_num_workers: 12 +dataset_num_proc: 12 # DPOTrainer arguments bf16: true diff --git a/tests/fixtures/config_sft_full.yaml b/tests/fixtures/config_sft_full.yaml index 297dc06a..cc0fa337 100644 --- a/tests/fixtures/config_sft_full.yaml +++ b/tests/fixtures/config_sft_full.yaml @@ -10,7 +10,7 @@ dataset_mixer: dataset_splits: - train_sft - test_sft -preprocessing_num_workers: 12 +dataset_num_proc: 12 # SFT trainer config bf16: true diff --git a/tests/test_configs.py b/tests/test_configs.py deleted file mode 100644 index 2a4a7a6d..00000000 --- a/tests/test_configs.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest - -from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig - - -class H4ArgumentParserTest(unittest.TestCase): - def setUp(self): - self.parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) - self.yaml_file_path = "tests/fixtures/config_sft_full.yaml" - - def test_load_yaml(self): - model_args, data_args, training_args = self.parser.parse_yaml_file(os.path.abspath(self.yaml_file_path)) - self.assertEqual(model_args.model_name_or_path, "mistralai/Mistral-7B-v0.1") - - def test_load_yaml_and_args(self): - command_line_args = [ - "--model_name_or_path=test", - "--use_peft=true", - "--lora_r=16", - "--lora_dropout=0.5", - ] - model_args, data_args, training_args = self.parser.parse_yaml_and_args( - os.path.abspath(self.yaml_file_path), command_line_args - ) - self.assertEqual(model_args.model_name_or_path, "test") - self.assertEqual(model_args.use_peft, True) - self.assertEqual(model_args.lora_r, 16) - self.assertEqual(model_args.lora_dropout, 0.5) diff --git a/tests/test_data.py b/tests/test_data.py index bcf600f5..276a0b66 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -13,184 +13,111 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -from copy import deepcopy import pytest -from datasets import Dataset -from transformers import AutoTokenizer -from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer -from alignment.data import maybe_insert_system_message - - -class GetDatasetsTest(unittest.TestCase): - """Each of these test datasets has 100 examples""" - - def test_loading_data_args(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.5, - "HuggingFaceH4/testing_self_instruct_small": 0.3, - "HuggingFaceH4/testing_codealpaca_small": 0.2, +from alignment import ScriptArguments, get_dataset + + +class GetDatasetTest(unittest.TestCase): + """Test the new get_dataset() method with dataset_mixture API""" + + def test_loading_dataset_mixture(self): + dataset_mixture = { + "datasets": [ + {"id": "HuggingFaceH4/testing_alpaca_small", "columns": ["prompt", "completion"], "weight": 0.5}, + { + "id": "HuggingFaceH4/testing_self_instruct_small", + "columns": ["prompt", "completion"], + "weight": 0.3, + }, + {"id": "HuggingFaceH4/testing_codealpaca_small", "columns": ["prompt", "completion"], "weight": 0.2}, + ], + "seed": 42, + "test_split_size": 0.1, } - data_args = DataArguments(dataset_mixer=dataset_mixer) - datasets = get_datasets(data_args, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 100) - self.assertEqual(len(datasets["test"]), 300) - - def test_loading_data_dict(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.5, - "HuggingFaceH4/testing_self_instruct_small": 0.3, - "HuggingFaceH4/testing_codealpaca_small": 0.2, + args = ScriptArguments(dataset_mixture=dataset_mixture) + datasets = get_dataset(args) + # With weights 0.5, 0.3, 0.2 on 100-sample datasets and test_split_size=0.1 + # Total samples = 50 + 30 + 20 = 100 + # Train: 90, Test: 10 + self.assertEqual(len(datasets["train"]), 90) + self.assertEqual(len(datasets["test"]), 10) + + def test_loading_dataset_mixture_no_test_split(self): + dataset_mixture = { + "datasets": [ + {"id": "HuggingFaceH4/testing_alpaca_small", "columns": ["prompt", "completion"], "weight": 0.5}, + { + "id": "HuggingFaceH4/testing_self_instruct_small", + "columns": ["prompt", "completion"], + "weight": 0.3, + }, + {"id": "HuggingFaceH4/testing_codealpaca_small", "columns": ["prompt", "completion"], "weight": 0.2}, + ], + "seed": 42, } - datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) + args = ScriptArguments(dataset_mixture=dataset_mixture) + datasets = get_dataset(args) + # Total samples = 50 + 30 + 20 = 100 (all in train split) self.assertEqual(len(datasets["train"]), 100) - self.assertEqual(len(datasets["test"]), 300) - - def test_loading_with_unit_fractions(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 1.0, - "HuggingFaceH4/testing_self_instruct_small": 1.0, - "HuggingFaceH4/testing_codealpaca_small": 1.0, + self.assertNotIn("test", datasets) + + def test_loading_with_unit_weights(self): + dataset_mixture = { + "datasets": [ + {"id": "HuggingFaceH4/testing_alpaca_small", "columns": ["prompt", "completion"], "weight": 1.0}, + { + "id": "HuggingFaceH4/testing_self_instruct_small", + "columns": ["prompt", "completion"], + "weight": 1.0, + }, + {"id": "HuggingFaceH4/testing_codealpaca_small", "columns": ["prompt", "completion"], "weight": 1.0}, + ], + "seed": 42, + "test_split_size": 0.1, } - datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 300) - self.assertEqual(len(datasets["test"]), 300) - - def test_loading_with_fractions_greater_than_unity(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.7, - "HuggingFaceH4/testing_self_instruct_small": 0.4, - } - datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 70 + 40) - self.assertEqual(len(datasets["test"]), 200) - - def test_loading_fails_with_negative_fractions(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.7, - "HuggingFaceH4/testing_self_instruct_small": -0.3, - } - with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."): - get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - - def test_loading_single_split_with_unit_fractions(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 1.0, + args = ScriptArguments(dataset_mixture=dataset_mixture) + datasets = get_dataset(args) + # Total samples = 100 + 100 + 100 = 300 + # Train: 270, Test: 30 + self.assertEqual(len(datasets["train"]), 270) + self.assertEqual(len(datasets["test"]), 30) + + def test_loading_with_fractional_weights(self): + dataset_mixture = { + "datasets": [ + {"id": "HuggingFaceH4/testing_alpaca_small", "columns": ["prompt", "completion"], "weight": 0.7}, + { + "id": "HuggingFaceH4/testing_self_instruct_small", + "columns": ["prompt", "completion"], + "weight": 0.4, + }, + ], + "seed": 42, + "test_split_size": 0.1, } - datasets = get_datasets(dataset_mixer, splits=["test"], columns_to_keep=["prompt", "completion"]) + args = ScriptArguments(dataset_mixture=dataset_mixture) + datasets = get_dataset(args) + # Total samples = 70 + 40 = 110 + # Train: 99, Test: 11 + self.assertEqual(len(datasets["train"]), 99) + self.assertEqual(len(datasets["test"]), 11) + + def test_loading_fails_with_invalid_dataset_mixture(self): + # Test that invalid dataset_mixture configuration raises error + with pytest.raises(ValueError, match=r"'datasets' must be a list"): + _ = ScriptArguments(dataset_mixture={"datasets": "invalid"}) + + with pytest.raises(ValueError, match=r"dataset_mixture must be a dictionary"): + _ = ScriptArguments(dataset_mixture="invalid") + + def test_loading_single_dataset(self): + # Test loading a single dataset using dataset_name instead of dataset_mixture + args = ScriptArguments(dataset_name="HuggingFaceH4/testing_alpaca_small") + datasets = get_dataset(args) + # Single dataset should have both train and test splits + self.assertIn("train", datasets) + self.assertEqual(len(datasets["train"]), 100) + self.assertIn("test", datasets) self.assertEqual(len(datasets["test"]), 100) - self.assertRaises(KeyError, lambda: datasets["train"]) - - -class ApplyChatTemplateTest(unittest.TestCase): - def setUp(self): - model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha") - data_args = DataArguments() - self.tokenizer = get_tokenizer(model_args, data_args) - self.dataset = Dataset.from_dict( - { - "prompt": ["Hello!"], - "messages": [ - [ - {"role": "system", "content": "You are a happy chatbot"}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Bonjour!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I am doing well, thanks!"}, - ] - ], - "chosen": [ - [ - {"role": "system", "content": "You are a happy chatbot"}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Bonjour!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I am doing well, thanks!"}, - ] - ], - "rejected": [ - [ - {"role": "system", "content": "You are a happy chatbot"}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Bonjour!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "Not so good tbh"}, - ] - ], - } - ) - - def test_maybe_insert_system_message(self): - # Chat template that does not accept system prompt. Use community checkpoint since it has no HF token requirement - tokenizer_sys_excl = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3") - # Chat template that accepts system prompt - tokenizer_sys_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct") - messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}] - messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}] - - messages_proc_excl = deepcopy(messages_sys_excl) - message_proc_incl = deepcopy(messages_sys_excl) - maybe_insert_system_message(messages_proc_excl, tokenizer_sys_excl) - maybe_insert_system_message(message_proc_incl, tokenizer_sys_incl) - - # output from mistral should not have a system message, output from llama should - self.assertEqual(messages_proc_excl, messages_sys_excl) - self.assertEqual(message_proc_incl, messages_sys_incl) - - def test_sft(self): - dataset = self.dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nI am doing well, thanks!\n" - }, - ) - - def test_generation(self): - # Remove last turn from messages - dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]}) - dataset = dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\n" - }, - ) - - def test_rm(self): - dataset = self.dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text_chosen": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nI am doing well, thanks!\n", - "text_rejected": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nNot so good tbh\n", - }, - ) - - def test_dpo(self): - dataset = self.dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text_prompt": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n", - "text_chosen": "<|assistant|>\nI am doing well, thanks!\n", - "text_rejected": "<|assistant|>\nNot so good tbh\n", - }, - ) diff --git a/tests/test_decontaminate.py b/tests/test_decontaminate.py deleted file mode 100644 index 334501ef..00000000 --- a/tests/test_decontaminate.py +++ /dev/null @@ -1,48 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest import TestCase - -from datasets import Dataset -from transformers import AutoTokenizer - -from alignment import apply_chat_template, decontaminate_humaneval - - -class DecontamintateHumanEvalTest(TestCase): - """Test we decontaminate HumanEval samples correctly""" - - def setUp(self) -> None: - # Create a dataset with a HumanEval sample wrapped in some fake text - dataset = Dataset.from_dict( - { - "messages": [ - [{"content": "Hello", "role": "user"}], - [ - { - "content": 'Hello, I am\nfrom\n\n typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n', - "role": "assistant", - } - ], - ] - } - ) - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - self.dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"}) - - def test_decontamination(self): - """Test we decontaminate HumanEval samples correctly""" - decontaminated_dataset = self.dataset.filter(decontaminate_humaneval, batched=True) - # Check we recover just the first message - self.assertEqual(decontaminated_dataset[0]["text"], self.dataset[0]["text"]) diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py deleted file mode 100644 index 16ada923..00000000 --- a/tests/test_model_utils.py +++ /dev/null @@ -1,88 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from alignment import ( - DataArguments, - ModelArguments, - get_peft_config, - get_quantization_config, - get_tokenizer, - is_adapter_model, -) -from alignment.data import DEFAULT_CHAT_TEMPLATE - - -class GetQuantizationConfigTest(unittest.TestCase): - def test_4bit(self): - model_args = ModelArguments(load_in_4bit=True) - quantization_config = get_quantization_config(model_args) - self.assertTrue(quantization_config["load_in_4bit"]) - self.assertEqual(quantization_config["bnb_4bit_compute_dtype"], "float16") - self.assertEqual(quantization_config["bnb_4bit_quant_type"], "nf4") - self.assertFalse(quantization_config["bnb_4bit_use_double_quant"]) - - def test_8bit(self): - model_args = ModelArguments(load_in_8bit=True) - quantization_config = get_quantization_config(model_args) - self.assertTrue(quantization_config["load_in_8bit"]) - - def test_no_quantization(self): - model_args = ModelArguments() - quantization_config = get_quantization_config(model_args) - self.assertIsNone(quantization_config) - - -class GetTokenizerTest(unittest.TestCase): - def setUp(self) -> None: - self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha") - - def test_right_truncation_side(self): - tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right")) - self.assertEqual(tokenizer.truncation_side, "right") - - def test_left_truncation_side(self): - tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left")) - self.assertEqual(tokenizer.truncation_side, "left") - - def test_default_chat_template(self): - tokenizer = get_tokenizer(self.model_args, DataArguments()) - self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE) - - def test_chatml_chat_template(self): - chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" - tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template)) - self.assertEqual(tokenizer.chat_template, chat_template) - - -class GetPeftConfigTest(unittest.TestCase): - def test_peft_config(self): - model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99) - peft_config = get_peft_config(model_args) - self.assertEqual(peft_config.r, 42) - self.assertEqual(peft_config.lora_alpha, 0.66) - self.assertEqual(peft_config.lora_dropout, 0.99) - - def test_no_peft_config(self): - model_args = ModelArguments(use_peft=False) - peft_config = get_peft_config(model_args) - self.assertIsNone(peft_config) - - -class IsAdapterModelTest(unittest.TestCase): - def test_is_adapter_model_calls_listdir(self): - # Assert that for an invalid repo name it gets to the point where it calls os.listdir, - # which is expected to raise a FileNotFoundError - self.assertRaises(FileNotFoundError, is_adapter_model, "nonexistent/model")