diff --git a/.gitignore b/.gitignore index 646848ade7..7ab517ff20 100644 --- a/.gitignore +++ b/.gitignore @@ -84,6 +84,7 @@ ENV/ logs/ # data-juicer +tmp/ outputs/ # agentscope runs/ diff --git a/README.md b/README.md index e9d12666db..c8c50999f7 100644 --- a/README.md +++ b/README.md @@ -148,8 +148,11 @@ pip install -e .\[dev\] # Install flash-attn after all dependencies are installed # Note: flash-attn will take a long time to compile, please be patient. -pip install flash-attn -v -# Try the following command if you encounter errors during installation +# for bash +pip install -e .[flash_attn] +# for zsh +pip install -e .\[flash_attn\] +# Try the following command if you encounter errors during flash-attn installation # pip install flash-attn -v --no-build-isolation ``` @@ -263,7 +266,7 @@ Then, for command-line users, run the RFT process with the following command: trinity run --config ``` -> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm: +> For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm: > ```shell > trinity run --config examples/grpo_gsm8k/gsm8k.yaml > ``` @@ -276,7 +279,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol + [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md) + [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md) + [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md) -+ [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md) ++ [Offline learning by DPO or SFT](./docs/sphinx_doc/source/tutorial/example_dpo.md) + [Advanced data processing / human-in-the-loop](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md) diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py index 4842a34557..ffaabf72c9 100644 --- a/docs/sphinx_doc/source/conf.py +++ b/docs/sphinx_doc/source/conf.py @@ -22,12 +22,13 @@ "sphinx.ext.napoleon", "sphinx.ext.autosectionlabel", "myst_parser", + "sphinx.ext.mathjax", ] source_suffix = { ".rst": "restructuredtext", ".md": "markdown", } -myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"] # Prefix document path to section labels, otherwise autogenerated labels would # look like 'heading' rather than 'path/to/file:heading' diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst index 5604faa15d..4b4cab2aa9 100644 --- a/docs/sphinx_doc/source/index.rst +++ b/docs/sphinx_doc/source/index.rst @@ -14,7 +14,7 @@ Welcome to Trinity-RFT's documentation! :maxdepth: 1 :glob: :hidden: - :caption: Tutorial + :caption: Examples tutorial/example_reasoning_basic.md tutorial/example_reasoning_advanced.md @@ -22,8 +22,16 @@ Welcome to Trinity-RFT's documentation! tutorial/example_multi_turn.md tutorial/example_dpo.md tutorial/example_data_functionalities.md - tutorial/trinity_configs.md + +.. toctree:: + :maxdepth: 2 + :glob: + :hidden: + :caption: Guidelines + tutorial/trinity_programming_guide.md + tutorial/trinity_configs.md + tutorial/example_mix_algo.md .. toctree:: :maxdepth: 1 @@ -33,6 +41,7 @@ Welcome to Trinity-RFT's documentation! build_api/trinity.buffer build_api/trinity.explorer build_api/trinity.trainer + build_api/trinity.algorithm build_api/trinity.manager build_api/trinity.common build_api/trinity.utils diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 7494212374..c7d9f6385b 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -84,15 +84,18 @@ e.g., utilizing NCCL (when feasible) for model weight synchronization, sequence ## Getting started - -*Note: this project is currently under active development; comments and suggestions are welcome!* - +```{note} +Note: This project is currently under active development; comments and suggestions are welcome! +``` ### Step 1: preparations - +Trinity-RFT requires +Python version >= 3.10, +CUDA version >= 12.4, +and at least 2 GPUs. Installation from source (recommended): @@ -146,11 +149,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest . docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest ``` -Trinity-RFT requires -Python version >= 3.10, -CUDA version >= 12.4, -and at least 2 GPUs. - ### Step 2: prepare dataset and model @@ -243,7 +241,7 @@ trinity run --config -For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm: +For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm: ```shell trinity run --config examples/grpo_gsm8k/gsm8k.yaml @@ -251,7 +249,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml -More example config files can be found in `examples`. +More example config files can be found in [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/). @@ -260,7 +258,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol + [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md) + [Asynchronous mode of RFT](tutorial/example_async_mode.md) + [Multi-turn tasks](tutorial/example_multi_turn.md) -+ [Offline learning by DPO](tutorial/example_dpo.md) ++ [Offline learning by DPO or SFT](tutorial/example_dpo.md) + [Advanced data processing / human-in-the-loop](tutorial/example_data_functionalities.md) diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md index a565145d83..1f9a9c8665 100644 --- a/docs/sphinx_doc/source/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md @@ -1,6 +1,6 @@ # Asynchronous RFT -This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset. +This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset. Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes. diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md index d62f56de3f..27b5fb26bf 100644 --- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md +++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md @@ -1,80 +1,97 @@ # Data Processing -## Example: reasoning task +## Example: Data Processor for Task Pipeline -In this example, you will learn how to apply the data module of Trinity-RFT to prepare the dataset before exploring and training. This example takes GSM-8K dataset as the example dataset to figure out: +In this example, you will learn how to apply the data processor of Trinity-RFT to prepare and prioritize the dataset before task exploring and training. This example takes GSM-8K dataset as the example dataset to figure out: -1. how to prepare the data module -2. how to configure the data module -3. what the data module can do +1. how to prepare the data processor +2. how to configure the data processor +3. what the data processor can do -Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md), and you need to install [postgresql](https://www.postgresql.org/docs/current/tutorial-install.html) as well. +Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md). ### Data Preparation -#### Prepare the Data Module +#### Prepare the Data Processor -As the overall framework of Trinity-RFT shows, the data module is one of the high-level functions. Trinity-RFT encapsulates the data module as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server. +As the overall framework of Trinity-RFT shows, the data processor is one of the high-level functions. Trinity-RFT encapsulates the data processor as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server. ```shell -# prepare split environments, including the one of data module +# prepare split environments, including the one of data processor python scripts/install.py # start all split servers python scripts/start_servers.py ``` -### Configure the Data Module +### Configure the Data Processor -Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data` section in the config file. +Trinity-RFT uses a unified config file to manage all config items. For the data processor, you need to focus on the `data_processor` section in the config file. In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example: ```yaml data_processor: - # basic info - source_data_path: '/path/to/gsm8k' - load_kwargs: - split: 'train' # only need the train split - format: # set the field mappings - prompt_key: 'question' - response_key: 'answer' - # database related. The result dataset will be stored in the database. - db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # task pipeline related + task_pipeline: + # I/O buffers + input_buffers: + - name: 'raw_input' + path: /PATH/TO/GSM8K/ + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: /PATH/TO/OUTPUT/JSONL/FILE + storage_type: 'file' + # format mapping + format: + prompt_key: 'question' + response_key: 'answer' ``` -Here you can set the basic information for the GSM-8K dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training: +Here you can set the basic buffers for the GSM-8K dataset input and output and some other items about downstream dataset loading for exploring and training: -+ `source_data_path`: the path to the raw dataset. -+ `load_kwargs`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library. -+ `format`: some dataset format config items, which are used to map original data field names to unified ones. -+ `db_url`: the URL of the postgresql database to store the result dataset. ++ `data_processor_url`: the URL of the data processor service, which is started in the previous step. ++ `task_pipeline`: the configs for the task pipeline. Task pipeline is used to process the raw dataset. It consists of several inner configs: + + `input_buffers`: the input buffers for the task pipeline. We usually load from raw dataset files in this pipeline, thus we need to the dataset `path` and set the `storage_type` to "file" and set `raw` to True. It allows multiple input buffers. We can name each buffer with the `name` field. + + `output_buffer`: the output buffer for the task pipeline. We usually store the processed dataset in files as well, thus we need to set the `storage_type` to "file". + + `format`: some dataset format config items, which are used to map original data field names to unified ones. -In addition, there are several config items related to the data active iterator, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer. +In addition, there are several config items related to the data active iterator in `task_pipeline` part, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer. #### Not familiar with Data-Juicer -If you are not familiar with Data-Juicer, the data module provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you. Here is an example: +If you are not familiar with Data-Juicer, the data processor provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you. Here is an example: ```yaml data_processor: - # basic info - source_data_path: '/path/to/gsm8k' - load_kwargs: - split: 'train' # only need the train split - format: # set the field mappings - prompt_key: 'question' - response_key: 'answer' - # database related. The result dataset will be stored in the database. - db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' - - #### new part about data active iterator - dj_process_desc: 'Please compute difficulty scores for these math questions.' - agent_model_name: 'qwen-max' - agent_model_config: - config_name: 'my-qwen-instruction' - model_type: 'dashscope_chat' - model_name: 'qwen2.5-72b-instruct' - clean_strategy: 'iterative' + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # task pipeline related + task_pipeline: + # I/O buffers + input_buffers: + - name: 'raw_input' + path: /PATH/TO/GSM8K/ + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: /PATH/TO/OUTPUT/JSONL/FILE + storage_type: 'file' + # format mapping + format: + prompt_key: 'question' + response_key: 'answer' + + #### new part about data active iterator + dj_process_desc: 'Please compute difficulty scores for these math questions.' + agent_model_name: 'qwen-max' + agent_model_config: + config_name: 'my-qwen-instruction' + model_type: 'dashscope_chat' + model_name: 'qwen2.5-72b-instruct' + clean_strategy: 'iterative' ``` You can write your demand description in config item `dj_process_desc`, and set the model name and configs used for the agent in config items `agent_model_name` and `agent_model_config`. Here we use Qwen2.5-72b-Instruct as our recipe managing agent. And you can set the `clean_strategy` to 'iterative' to get a better dataset. @@ -99,19 +116,27 @@ After preparing the Data-Juicer data processing recipe, you can set the `dj_conf ```yaml data_processor: - # basic info - source_data_path: '/path/to/gsm8k' - load_kwargs: - split: 'train' # only need the train split - format: # set the field mappings - prompt_key: 'question' - response_key: 'answer' - # database related. The result dataset will be stored in the database. - db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' - - #### new part about data active iterator - dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml' - clean_strategy: 'iterative' + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # task pipeline related + task_pipeline: + # I/O buffers + input_buffers: + - name: 'raw_input' + path: /PATH/TO/GSM8K/ + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: /PATH/TO/OUTPUT/JSONL/FILE + storage_type: 'file' + # format mapping + format: + prompt_key: 'question' + response_key: 'answer' + + #### new part about data active iterator + dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml' + clean_strategy: 'iterative' ``` And you can set the `clean_strategy` to 'iterative' to get a better dataset. @@ -123,7 +148,7 @@ All config items in the `data` section can be found [here](trinity_configs.md). ```{note} -Only when one of `dj_process_desc` and `dj_config_path` is provided, the data module and the data active iterator will be activated. Otherwise, this part will be skipped and it will enter into the exploring stage directly. +Only when one of `xxx_pipeline` is provided, and one of `dj_process_desc` and `dj_config_path` in the pipeline config is provided, the data processor and the data active iterator will be activated. Otherwise, this part will be skipped and it will enter into the exploring stage directly. ``` ### Exploring & Training @@ -140,54 +165,59 @@ ray start --address= trinity run --config ``` -If you follow the steps above, Trinity-RFT will send a request to the data module server, the data active iterator will be activated and compute difficulty scores for each sample in the raw dataset. After that, the data module server stores the result dataset into the database, when exploring begins, it will load the prepared dataset and continue the downstream steps. +If you follow the steps above, Trinity-RFT will send a request to the data processor server, the data active iterator will be activated, compute difficulty scores for each sample in the raw dataset, and rank the dataset according to difficulty scores. After that, the data processor server stores the result dataset into the output buffer, when exploring begins, it will load the prepared dataset and continue the downstream steps. - - -## Example: human in the loop +## Example: Human in the Loop Sometimes, you might need to involve human feedbacks for some raw data. In this example, you will learn how to annotate raw data to get a better dataset before training. This example takes an example Q&A dataset and tries to select the chosen and rejected ones for DPO method. -Before getting started, you need to prepare the main environment of Trinity-RFT according to the installation section of the README file, install postgresql, and [start a label-studio server](https://github.com/modelscope/data-juicer/tree/main/tools/humanops) from Data-Juicer from source. +Before getting started, you need to prepare the main environment of Trinity-RFT according to the installation section of the README file, and [start a label-studio server](https://github.com/modelscope/data-juicer/tree/main/tools/humanops) from Data-Juicer from source. ### Data Preparation -#### Prepare the Data Module +#### Prepare the Data Processor -As the overall framework of Trinity-RFT shows, the data module is one of the high-level functions. Trinity-RFT encapsulates the data module as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server. +As the overall framework of Trinity-RFT shows, the data processor is one of the high-level functions. Trinity-RFT encapsulates the data processor as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server. ```shell -# prepare split environments, including the one of data module +# prepare split environments, including the one of data processor python scripts/install.py # start all split servers python scripts/start_servers.py ``` -### Configure the Data Module +### Configure the Data Processor -Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data` section in the config file. +Trinity-RFT uses a unified config file to manage all config items. For the data processor, you need to focus on the `data_processor` section in the config file. -In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example: +In this example, assume that you need to select the chosen and rejected responses for DPO method. So you can set these config items like the following example: ```yaml data_processor: - # basic info - source_data_path: 'tests/test_data/test_human_annotator' - load_kwargs: - split: 'train' # only need the train split - format: # set the field mappings - prompt_key: 'prompt' - chosen_key: 'chosen' - rejected_key: 'rejected' - #### new part about data active iterator - dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml' - # database related. The result dataset will be stored in the database. - db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # task pipeline related + task_pipeline: + # I/O buffers + input_buffers: + - name: 'raw_input' + path: 'tests/test_data/test_human_annotator' + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl' + storage_type: 'file' + format: # set the field mappings + prompt_key: 'prompt' + chosen_key: 'chosen' + rejected_key: 'rejected' + #### new part about data active iterator + dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml' ``` Here you can set the basic information for the example dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training, which is similar to the example above. -For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in `tests/test_configs/human_annotator_test_dj_cfg.yaml` that includes an OP of `human_preference_annotation_mapper`. For example: +For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in [`tests/test_configs/human_annotator_test_dj_cfg.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/tests/test_configs/human_annotator_test_dj_cfg.yaml) that includes an OP of `human_preference_annotation_mapper`. For example: ```yaml project_name: 'demo-human-annotator' @@ -223,7 +253,7 @@ You can set more config items for this OP (e.g. notification when annotation is ### Start Running -When you start running with the RFT config, the data module will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server. +When you start running with the RFT config, the data processor will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server. ![](../../assets/data-projects.png) diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md index a6f70f5e62..cd0c214725 100644 --- a/docs/sphinx_doc/source/tutorial/example_dpo.md +++ b/docs/sphinx_doc/source/tutorial/example_dpo.md @@ -1,12 +1,12 @@ -# Offline DPO +# Offline DPO and SFT -This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset). +This example describes DPO and SFT based on the Qwen2.5-1.5B-Instruct model. ## Step 1: Model and Data Preparation ### Model Preparation -Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`: +Download the Qwen2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`: ```shell # Using Modelscope @@ -20,7 +20,7 @@ More details of model downloading are referred to [ModelScope](https://modelscop ### Data Preparation -Download the Human-Like-DPO-Dataset dataset to the local directory `$DATASET_PATH/human_like_dpo_dataset`: +For DPO, we download the [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset) to the local directory `$DATASET_PATH/human_like_dpo_dataset`: ```shell # Using Modelscope @@ -34,9 +34,11 @@ More details of dataset downloading are referred to [ModelScope](https://modelsc Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pass the proper keys to the config. -## Step 2: Setup Configuration and Run Experiment +For SFT, we download the dataset to the local directory `/PATH/TO/SFT_DATASET/`, which usually contains message-based data. -### Configuration +## Step 2: Setup Configuration + +### Configuration for DPO We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following: @@ -48,9 +50,12 @@ name: mode: train algorithm: algorithm_type: dpo + kl_loss_fn: k1 + kl_loss_fn_args: + kl_coef: 0.1 # value of beta in DPO checkpoint_root_dir: /PATH/TO/CHECKPOINT/ model: - model_path: /PATH/TO/MODEL/ + model_path: $MODEL_PATH/Qwen2.5-1.5B-Instruct cluster: node_num: 1 gpu_per_node: 8 @@ -59,9 +64,9 @@ buffer: batch_size: 64 trainer_input: experience_buffer: - name: dpo_buffer + name: human_like_dpo storage_type: file - path: /PATH/TO/DATASET/ + path: $DATASET_PATH/human_like_dpo_dataset format: prompt_type: plaintext # plaintext/messages/chatpair prompt_key: prompt @@ -70,14 +75,50 @@ buffer: trainer: trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml' save_interval: 30 - actor_use_kl_loss: True - actor_kl_loss_coef: 0.1 # value of beta in DPO ``` -### Run the Experiment +### Configuration for SFT + +We set the `algorithm_type` as `sft` to run SFT process. Then we modify the config file `sft.yaml` with the following changes: + +```yaml +project: +name: +mode: train +algorithm: + algorithm_type: sft +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +model: + model_path: /PATH/TO/MODEL/ +cluster: + node_num: 1 + gpu_per_node: 2 +buffer: + total_epochs: 5 + batch_size: 64 + trainer_input: + experience_buffer: + name: + storage_type: file + path: /PATH/TO/SFT_DATASET/ + split: train + format: + prompt_type: messages + messages_key: messages +trainer: + trainer_config_path: /PATH/TO/TRAIN_CONFIG_YAML/ + save_interval: 50 +``` + +## Step 3: Run the Experiment -Run RFT process with the following command: +Run DPO process with the following command: ```shell trinity run --config examples/dpo_humanlike/dpo.yaml ``` +or, for SFT: + +```shell +trinity run --config /PATH/TO/sft.yaml +``` diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md new file mode 100644 index 0000000000..b106293eed --- /dev/null +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -0,0 +1,310 @@ +# Algorithm Development + +```{note} +This guide is an advanced version of the {ref}`Algorithms ` section in the Developer Guide. +``` + +This guide introduces how to integrate a new algorithm to Trinity-RFT. +As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX , which optimizes the following policy objective: + +$$ +\mathcal{J}_{\text{Mix}}(\theta) = +(1-\mu) \mathcal{J}_{\text{GRPO}}(\theta) ++ +\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'} +\left[ + \frac{1}{T'_b} \sum_{t=1}^{T'_b} + \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b, }, + { "role": "user", "content": "What is the sum of 4 and 12?" }, + { "role": "assistant", "content": "thinking process...\n16" } ] +}, +... +``` +The path to expert data is passed to `buffer.trainer_input.sft_warmup_dataset` for later use. + + +## Step 1: Define the Algorithm + +In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`. + +```python +@ALGORITHM_TYPE.register_module("mix") +class MIXAlgorithm(AlgorithmType): + """MIX algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 8, + "policy_loss_fn": "mix", + "advantage_fn": "grpo", + "sample_strategy": "mix", + } +``` + + +## Step 2: Define the Sampling Strategy + +We need to read two kinds of experiences: usual experiences and expert experiences in each step. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`. + + +```python +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore + ) + + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) + + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config + ) + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) + + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") +``` + +We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type. + +```diff ++ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), ++ "is_expert_mask": is_expert_mask, + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) +``` + + +## Step 3: Define the Policy Loss Function + +We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively. + +```python +@POLICY_LOSS_FN.register_module("mix") +class MIXPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + mu: float = 0.1, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_dynamic_bsz: Optional[bool] = None, + repeat_times: Optional[int] = None, + ppo_mini_batch_size: Optional[int] = None, + ppo_micro_batch_size_per_gpu: Optional[int] = None, + ngpus_trainer: Optional[int] = None, + read_batch_size_usual: Optional[int] = None, + read_batch_size_expert: Optional[int] = None, + use_token_level_loss_in_sft: bool = True, + ) -> None: + super().__init__(backend=backend) + self.mu = mu + self.use_dynamic_bsz = use_dynamic_bsz + self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore + self.gradient_accumulation = ( + ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore + ) + self.read_batch_size_usual = read_batch_size_usual + self.read_batch_size_expert = read_batch_size_expert + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + is_expert_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + assert ( + len(is_expert_mask) == logprob.shape[0] + ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + + n_usual_exp = torch.sum(~is_expert_mask).item() + n_expert_exp = torch.sum(is_expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.experience_per_gpu / ( + logprob.shape[0] * self.read_batch_size_usual + ) + per_micro_batch_weight_expert = self.experience_per_gpu / ( + logprob.shape[0] * self.read_batch_size_expert + ) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~is_expert_mask], + old_logprob[~is_expert_mask], + action_mask[~is_expert_mask], + advantages[~is_expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } + else: + grpo_loss = torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[is_expert_mask], + action_mask[is_expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item()}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "mu": 0.1, + "clip_range": 0.2, + } +``` + +## Step 4: Run the Experiment + +With the above newly-defined classes and functions, we can run the experiments without modifying other process. +An example showing some important configurations is shown below, including the weighting factor $\mu$ as `algorithm.policy_loss_fn_args['mu']` and the batch size of expert experiences $B'$, calculated as the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']` and `algorithm.repeat_times`. +For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml). + +```yaml +algorithm: + algorithm_type: mix + repeat_times: 8 + sample_strategy_args: + expert_data_ratio: 0.25 + policy_loss_fn_args: + mu: 0.1 + clip_range: 0.2 + use_token_level_loss_in_sft: False + use_dynamic_bsz: False + repeat_times: 8 + ppo_mini_batch_size: 32 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + read_batch_size_expert: 64 + read_batch_size_usual: 192 +``` + +With the above configurations, the experiment can be run with the following command: + +```bash +trinity run --config examples/mix_math/mix_math.yaml +``` diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md index 46cc4ab32e..3cf5b89145 100644 --- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md +++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md @@ -15,8 +15,8 @@ To run the ALFworld and WebShop env, you need to setup the corresponding environ - WebShop is a simulated online shopping environment where AI agents learn to shop based on user requirements. The platform allows agents to browse products, compare options, and make purchase decisions, mimicking real-world e-commerce interactions. You may refer to their original environment to complete the setup. -- For ALFworld, refer to: https://github.com/alfworld/alfworld -- For WebShop, refer to: https://github.com/princeton-nlp/WebShop +- For ALFWorld, refer to the [ALFWorld](https://github.com/alfworld/alfworld) repository. +- For WebShop, refer to the [WebShop](https://github.com/princeton-nlp/WebShop) repository. ### Data Preparation Our dataset follows the format in Huggingface datasets library, so we should correspondingly convert our env dataset. @@ -36,7 +36,7 @@ The task is described as an environment instead of a single prompt. ## Step 2: Config preparation and run the experiment -You can refer to `example_reasoning_basic` to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively. +You can refer to [Quick Start](./example_reasoning_basic.md) to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively. You may revise the configurations properly and run the experiment! ```bash @@ -104,7 +104,7 @@ class AlfworldWorkflow(MultiTurnWorkflow): ... ``` -and include them in the init files in `trinity/common/workflows/__init__.py` +and include it in the init file `trinity/common/workflows/__init__.py` ```diff # -*- coding: utf-8 -*- @@ -120,7 +120,7 @@ and include them in the init files in `trinity/common/workflows/__init__.py` ] ``` -Then you are all set! It should be pretty simple😄, and both environments converge. +Then you are all set! It should be pretty simple😄, and the training processes in both environments converge. ![](../../assets/alfworld_reward_curve.png) ![](../../assets/webshop_reward_curve.png) diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md index a80032bc12..aa4439e866 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md @@ -6,7 +6,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a - +(OPMD)= ## OPMD: a native off-policy RL algorithm diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md index dc45994e98..8d8309a913 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md @@ -37,6 +37,12 @@ pip install flash-attn -v # pip install flash-attn -v --no-build-isolation ``` +Installation using pip: + +```shell +pip install trinity-rft +``` + Installation from docker: We provided a dockerfile for Trinity-RFT. @@ -60,7 +66,7 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v 1. +- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. +- `advantage_fn`: The advantage function used for computing advantages. +- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward. +- `kl_loss_fn`: The KL loss function used for computing KL loss. +- `entropy_loss_fn`: The entropy loss function used for computing entropy loss. + --- @@ -100,8 +110,8 @@ monitor: ``` - `monitor_type`: Type of monitoring system. Options: - - `wandb`: Logs to Weights & Biases. Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs. - - `tensorboard`: Logs to TensorBoard. Files are saved under `///monitor/tensorboard`. + - `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs. + - `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `///monitor/tensorboard`. --- @@ -111,13 +121,13 @@ Defines the model paths and token limits. ```yaml model: - model_path: '/PATH/TO/MODEL/CHECKPOINT/' + model_path: /PATH/TO/MODEL/ critic_model_path: '' max_prompt_tokens: 4096 max_response_tokens: 16384 ``` -- `model_path`: Path to the model checkpoint being trained. +- `model_path`: Path to the model being trained. - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`. - `max_prompt_tokens`: Maximum number of tokens allowed in input prompts. - `max_response_tokens`: Maximum number of tokens allowed in generated responses. @@ -164,8 +174,8 @@ buffer: default_reward_fn_type: 'countdown_reward' ``` -- `batch_size`: Number of samples used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*. -- `total_epochs`: Total number of training epochs. Not applicable for streaming datasets (e.g., queue-based buffers). +- `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*. +- `total_epochs`: Total number of training epochs. ### Explorer Input @@ -216,6 +226,8 @@ The configuration for each task dataset is defined as follows: - For `file` storage type, the path is the path to the directory that contains the task dataset files. - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here. - For `sql` storage type, the path is the path to the sqlite database file. +- `subset_name`: The subset name of the task dataset. Default is `None`. +- `split`: The split of the task dataset. Default is `train`. - `format`: Defines keys for prompts and responses in the dataset. - `prompt_key`: Specifies which column in the dataset contains the prompt data. - `response_key`: Specifies which column in the dataset contains the response data. @@ -291,9 +303,9 @@ synchronizer: ``` - `sync_method`: Method of synchronization. Options: - - `nccl`: Uses NCCL for fast synchronization. - - `checkpoint`: Loads latest model from disk. -- `sync_interval`: Interval (in steps) between synchronizations. + - `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode. + - `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode. +- `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer. - `sync_timeout`: Timeout duration for synchronization. --- @@ -313,7 +325,7 @@ trainer: - `trainer_type`: Trainer backend implementation. Currently only supports `verl`. - `save_interval`: Frequency (in steps) at which to save model checkpoints. - `trainer_config_path`: The path to the trainer configuration file. -- `train_config`: The configuration of the trainer. Only one needs to be set for `trainer.trainer_config` and `trainer.trainer_config_path` +- `trainer_config`: The trainer configuration provided inline. Only one of `trainer_config_path` and `trainer_config` should be specified. --- @@ -323,7 +335,7 @@ Configures preprocessing and data cleaning pipelines. ```yaml data_processor: - source_data_path: '/PATH/TO/DATASET' + source_data_path: /PATH/TO/DATASET load_kwargs: split: 'train' format: @@ -334,7 +346,7 @@ data_processor: db_url: 'postgresql://{username}@localhost:5432/{db_name}' ``` -- `source_data_path`: Path to the raw dataset. +- `source_data_path`: Path to the task dataset. - `load_kwargs`: Arguments passed to HuggingFace’s `load_dataset()`. - `dj_config_path`: Path to Data-Juicer configuration for cleaning. - `clean_strategy`: Strategy for iterative data cleaning. @@ -365,11 +377,6 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -388,10 +395,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -436,22 +439,6 @@ critic: grad_clip: 1.0 cliprange_value: 0.5 -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - norm_adv_by_std_in_grpo: True - use_kl_in_reward: False - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - horizon: 10000 - target_kl: 0.1 - trainer: balance_batch: True # total_training_steps: null @@ -472,11 +459,7 @@ trainer: - `actor_rollout_ref.model.use_remove_padding`: Whether to remove pad tokens, which will reduce training time. - `actor_rollout_ref.actor.use_dynamic_bsz`: Whether to reorganize the batch data, specifically to splice the shorter data to reduce the batch size in the actual training process. - `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: Batch size for one GPU in one forward pass. -- `actor_rollout_ref.actor.kl_loss_type`: How to compute kl loss, optional value is `kl`, `abs`, `mse` or `low_var_kl`. - `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size. -- `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy. -- `actor_rollout_ref.actor.opmd_baseline`: mean / logavgexp, applicable to opmd. -- `actor_rollout_ref.actor.use_uid`: True / False, applicable to pairwise_opmd. - `actor_rollout_ref.actor.optim.lr`: Learning rate for actor model. - `actor_rollout_ref.actor.optim.lr_warmup_steps_ratio`: Ratio of warmup steps for learning rate. - `actor_rollout_ref.actor.optim.warmup_style`: Warmup style for learning rate. @@ -494,8 +477,6 @@ trainer: - `critic.grad_clip`: Gradient clip for critic model training. - `critic.cliprange_value`: Used for compute value loss. -- `algorithm`: Training algorithm settings. - - `trainer.balance_batch`: Whether to balance batch size between GPUs during training. - `trainer.resume_mode`: Resume mode for training. Support `disable`, `auto` and `resume_path`. - `trainer.resume_from_path`: Path to resume from. diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 1b8e3fc56b..e07e6bb3dc 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -1,6 +1,16 @@ # Developer Guide -This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines. +This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines. + +Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**. +We decouple the RL pipeline into three modules to make it easier to customize and extend. +Below is a table summarizing the modules and components that developers with different tragets need to focus on. + +| Development Target | Core Module | Key Component | +|--------------------|-------------|---------------| +| Apply existing RL algorithms to new environments. | *Explorer* | `Workflow` | +| Design new RL algorithms. | *Trainer* | `Algorithm` | +| Enhance the RL process from the data perspective. | *Buffer* | Data Processing Module (Coming soon) | ```{note} Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code. @@ -8,9 +18,10 @@ Trinity-RFT is still under development, and the following interfaces may change. --- -## Creating New Workflows +## Workflows (For RL Environment Developers) -Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow: +In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. +A qualified workflow needs to use the trained model to complete the specified task and obtain feedback information (reward) from the environment. Below are the steps to create a new workflow: --- @@ -18,19 +29,16 @@ Trinity-RFT allows developers to register new workflows (e.g., for multi-turn in Before starting development, it's important to understand several core concepts: - - **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type: - **Math problems**: A `Task` contains the problem description and the golden answer. - **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information. - -- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows: +- **Workflow** ({class}`trinity.common.workflows.Workflow`): Describes how a `Task` is executed. It defines the interaction flow between Agents and Environments, including logic similar to *Rollout* and *Reward* calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows: - `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios, submits problems to LLM, parses LLM responses, and calculates scores (rewards). - `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios, it contains multi-turn interaction with environment. - `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results. - ... - - **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc. --- @@ -40,12 +48,12 @@ Before starting development, it's important to understand several core concepts: The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file. To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields. - - **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file. - - **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field. - - **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields. - - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`. - - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`. - - **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field. +- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file. +- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field. +- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields. +- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`. +- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`. +- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field. ```{tip} `workflow`, `workflow_args` and `raw_task` provide different levels of customization. @@ -82,7 +90,6 @@ buffer: In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task` and use the `rollout_args` to generate the response. - --- ### Step 2: Implement a New Workflow @@ -106,8 +113,7 @@ class Workflow(ABC): """Run the workflow and return a list of Experiences.""" ``` - -#### Initializing Your Workflow +#### Initialize Your Workflow During initialization, `Workflow` receives the following parameters: @@ -115,7 +121,6 @@ During initialization, `Workflow` receives the following parameters: - `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset. - `auxiliary_models`(`List[openai.OpenAI]`):A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs. - ```{tip} You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. ``` @@ -143,7 +148,6 @@ We first call the model to generate multiple response using the provided questio Then we calculate the reward for each response using the `calculate_reward` function. Finally, we construct a list of `Experience` with the responses and rewards and return it. - ```python class ExampleWorkflow(Workflow): @@ -215,7 +219,6 @@ For workflows that are not intended to be contributed to Trinity-RFT project, yo You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `/trinity/plugins` as the default directory. ``` - #### Avoid Re-initialization For heavy workflows, re-initializing every time can incurs extra computational costs. @@ -235,7 +238,6 @@ class ExampleWorkflow(Workflow): self.answer = task.raw_task.get("answer") ``` - #### Full Code Example ```python @@ -289,7 +291,6 @@ class ExampleWorkflow(Workflow): self.answer = task.raw_task.get("answer") ``` - --- ### Step 3: Use Your Workflow @@ -314,6 +315,198 @@ trinity run --config --- +(Algorithms)= +## Algorithms (For RL Algorithm Developers) + +Trinity-RFT provides a standardized process for implementing new algorithms. + +### Step 0: Basic Concepts of Algorithm Module + +In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss to update models based on this data. +To avoid implementing a new Trainer class each time a new algorithm is added, we have decomposed the representative PPO algorithm process into multiple sub-modules to adapt to various algorithms. + +- **Sample Strategy** ({class}`trinity.algorithm.SampleStrategy`): Responsible for sampling experience data from the buffer module. By customizing this module, you can implement functionalities like filtering experience data or mixed sampling from multiple data sources. +- **Advantage Fn**({class}`trinity.algorithm.AdvantageFn`): Responsible for calculating the Advantage and Returns of experience data. +- **Policy Loss Fn**({class}`trinity.algorithm.PolicyLossFn`): Responsible for calculating the core training loss of the policy network. +- **KL Fn**({class}`trinity.algorithm.KLFn`): Responsible for calculating KL Divergence, which is generally used in two places in existing RL algorithms: Reward Penalty and Actor Loss. +- **Entropy Loss Fn**({class}`trinity.algorithm.EntropyLossFn`): Responsible for calculating the entropy loss of the policy network. + +We provide several implementations of above modules in `trinity/algorithm`. + +--- + +### Step 1: Implement Algorithm Components + + +Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the {ref}`OPMD ` algorithm as an example. + +The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. Therefore, only new Advantage Fn and Policy Loss Fn modules need to be implemented. + +--- + +#### Step 1.1: Implement `AdvantageFn` + +Developers need to implement the {class}`trinity.algorithm.AdvantageFn` interface, which mainly includes two methods: + +- `__call__`: Calculates advantages and returns based on input experience data, records observable metrics during the calculation process, and returns the experience data containing advantages and returns as well as a metrics dictionary. The input experience data format is [verl](https://github.com/volcengine/verl)'s `DataProto`. +- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file. + +After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name. + +Here's an implementation example for the OPMD algorithm's Advantage Fn: + +```python +# trinity/algorithm/advantage_fn/opmd.py +# import some modules +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("opmd") +class OPMDAdvantageFn(AdvantageFn): + """OPMD advantage computation""" + + def __init__( + self, + opmd_baseline: str = "mean", + tau: float = 1.0, + ) -> None: + self.opmd_baseline = opmd_baseline + self.tau = tau + + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + # calculate advantages and returns based on the exps + + # record some metrics + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "opmd_baseline": "mean", + "tau": 1.0, + } +``` + +#### Step 1.2: Implement `PolicyLossFn` + +Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interface, which is similar to `AdvantageFn` and includes two methods: + +- `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`. +- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file. + +Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`. + +Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method: + + +```python +@POLICY_LOSS_FN.register_module("opmd") +class OPMDPolicyLossFn(PolicyLossFn): + def __init__(self, tau: float = 1.0) -> None: + self.tau = tau + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + pg_losses = -advantages * logprob + opmd_loss = masked_mean(pg_losses, action_mask) + opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta) + return opmd_loss, {"opmd_loss": opmd_loss.detach().item()} + + @classmethod + def default_args(cls) -> Dict: + return {"tau": 1.0} +``` + +--- + +### Step 2: Register Your Algorithm + +The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect. + +To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration. + +The `AlgorithmType` class includes the following attributes and methods: + +- `use_critic`: Whether to use the Critic model +- `use_reference`: Whether to use the Reference model +- `use_advantage`: Whether to calculate Advantage; if False, the `AdvantageFn` call will be skipped +- `can_balance_batch`: Whether the algorithm allows automatic balancing when splitting a batch into microbatches (which permute the order of samples) +- `schema`: The format of experience data corresponding to the algorithm +- `default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE` + +Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`. + +Below is the implementation for the OPMD algorithm. +Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`. +The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss. + +```python +@ALGORITHM_TYPE.register_module("opmd") +class OPMDAlgorithm(AlgorithmType): + """OPMD algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "opmd", + "advantage_fn": "opmd", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } +``` + +--- + +### Step 3: Use Your Algorithm + +After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file. + +For default configurations, you just need to add the following content to your `config.yaml` file: + +```yaml +# some other configs +algorithm: + algorithm_type: "opmd" +# some other configs +``` + +If you need to modify certain parameters, you can simply add the corresponding parameters within the `algorithm` section. For example, if you need to modify `repeat_times` and the initialization parameters of `AdvantageFn` and `PolicyLossFn`, the modified `config.yaml` file would be as follows: + +```yaml +# some other configs +algorithm: + algorithm_type: "opmd" + repeat_times: 8 + advantage_fn_args: + opmd_baseline: "logavgexp" + tau: 0.99 + policy_loss_fn_args: + tau: 0.99 +# some other configs +``` + +--- + ## Adding New Config Entries for the Config Generator (Advanced) ### Step 0: Understanding Streamlit @@ -344,11 +537,11 @@ The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=conf ``` For `train_batch_size`, we will use the following settings: + - Default value: 96 - Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0` - Additional config: `{"_train_batch_size_per_gpu": 16}` - Here's the complete code for the `train_batch_size` parameter: ```python @@ -408,6 +601,7 @@ To successfully integrate new parameters into the `config_manager.py` file, plea Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class. Example: + ```python class ConfigManager: def _expert_buffer_part(self): @@ -421,6 +615,7 @@ To successfully integrate new parameters into the `config_manager.py` file, plea Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML. Example: + ```python class ConfigManager: def _gen_buffer_config(self): @@ -447,7 +642,7 @@ pip install -e .[dev] # pip install -e .\[dev\] # Run code style checks -pre-commit --all-files +pre-commit run --all-files # Commit the code after all checks pass git commit -am "create example workflow" diff --git a/environments/data.yaml b/environments/data.yaml index 6acdf04dc9..d43ece076b 100644 --- a/environments/data.yaml +++ b/environments/data.yaml @@ -6,10 +6,5 @@ dependencies: - pip: - py-data-juicer - agentscope - - flask - - omegaconf - - sqlalchemy - - psycopg2 - - networkx - transformers - "-e ..[dev]" diff --git a/examples/async_gsm8k/verl_config.yaml b/examples/async_gsm8k/verl_config.yaml index de1b08f590..fc44fdad94 100644 --- a/examples/async_gsm8k/verl_config.yaml +++ b/examples/async_gsm8k/verl_config.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 8cd3dbe0c8..0a0864b8ef 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -3,6 +3,9 @@ name: "trinity_dpo" mode: train algorithm: algorithm_type: dpo + kl_loss_fn: k1 + kl_loss_fn_args: + kl_coef: 0.1 checkpoint_root_dir: /PATH/TO/CHECKPOINT/ model: model_path: /PATH/TO/MODEL @@ -34,5 +37,3 @@ trainer: trainer_type: 'verl' trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml' save_interval: 30 - actor_use_kl_loss: True - actor_kl_loss_coef: 0.1 diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml index 8ffc68b397..d5074848b0 100644 --- a/examples/dpo_humanlike/train_dpo.yaml +++ b/examples/dpo_humanlike/train_dpo.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True - kl_loss_coef: 0.1 # NOTE: beta for DPO - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -26,8 +21,7 @@ actor_rollout_ref: min_lr_ratio: 0.1 # only useful for warmup with cosine warmup_style: cosine # select from constant/cosine total_training_steps: 783 # - beta1: 0.9 - beta2: 0.95 + betas: [0.9, 0.95] fsdp_config: wrap_policy: # transformer_layer_cls_to_wrap: None @@ -47,18 +41,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: False total_training_steps: 783 # diff --git a/examples/grpo_alfworld/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml index 215b1817ab..5b73ec7403 100644 --- a/examples/grpo_alfworld/train_alfworld.yaml +++ b/examples/grpo_alfworld/train_alfworld.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -44,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 2a87ef288b..0763586457 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -4,19 +4,6 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT/ algorithm: algorithm_type: grpo repeat_times: 8 -data_processor: - # basic info - source_data_path: 'openai/gsm8k' - # data active iterator related - dj_process_desc: 'Please compute difficulty scores for these math questions.' - agent_model_name: 'qwen-max' - agent_model_config: - config_name: 'my-qwen-instruction' - model_type: 'dashscope_chat' - model_name: 'qwen2.5-72b-instruct' - clean_strategy: 'iterative' - # db related - db_url: '' model: model_path: /PATH/TO/MODEL/ @@ -41,9 +28,7 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - n: 8 temperature: 1.0 - logprobs: 0 eval_tasksets: - name: gsm8k-eval storage_type: file diff --git a/examples/grpo_gsm8k/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml index de1b08f590..fc44fdad94 100644 --- a/examples/grpo_gsm8k/train_gsm8k.yaml +++ b/examples/grpo_gsm8k/train_gsm8k.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/grpo_gsm8k_task_pipeline/README.md b/examples/grpo_gsm8k_task_pipeline/README.md new file mode 100644 index 0000000000..ead6a56185 --- /dev/null +++ b/examples/grpo_gsm8k_task_pipeline/README.md @@ -0,0 +1,7 @@ +# GRPO on GSM8K dataset with Task Pipeline + +This example shows the usage of GRPO on the GSM8K dataset, with a task pipeline to prioritize the raw dataset before training. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_data_functionalities.md). + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml new file mode 100644 index 0000000000..36514e0e01 --- /dev/null +++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml @@ -0,0 +1,95 @@ +project: "Trinity-RFT-gsm8k-task-pipeline" +name: "qwen2.5-1.5B-gsm8k-task-pipeline" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 +data_processor: + data_processor_url: 'http://127.0.0.1:5005/data_processor' + # task pipeline related + task_pipeline: + # I/O buffers + input_buffers: + - name: 'raw_input' + path: 'openai/gsm8k' + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl' + storage_type: 'file' + # format mapping + format: + prompt_key: 'question' + response_key: 'answer' + # data active iterator related + dj_process_desc: 'Please compute difficulty scores for these math questions.' + agent_model_name: 'qwen-max' + agent_model_config: + config_name: 'my-qwen-instruction' + model_type: 'dashscope_chat' + model_name: 'qwen2.5-72b-instruct' + clean_strategy: 'iterative' + +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 256 + max_response_tokens: 1024 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: './outputs/task_pipeline_output/' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 + # sft_warmup_dataset: # Uncomment these to enable sft warmup + # name: warmup_data + # storage_type: file + # path: '/PATH/TO/WARMUP_DATA/' +explorer: + eval_interval: 50 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml' + save_interval: 100 diff --git a/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml new file mode 100644 index 0000000000..fc44fdad94 --- /dev/null +++ b/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml @@ -0,0 +1,50 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml index 78bcb862c6..0a46bd1788 100644 --- a/examples/grpo_math/train_math.yaml +++ b/examples/grpo_math/train_math.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.0001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.0001 - trainer: balance_batch: True # auto: find the last ckpt to resume. If can't find, start from scratch diff --git a/examples/grpo_sciworld/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml index 215b1817ab..5b73ec7403 100644 --- a/examples/grpo_sciworld/train_sciworld.yaml +++ b/examples/grpo_sciworld/train_sciworld.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -44,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/grpo_webshop/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml index 215b1817ab..5b73ec7403 100644 --- a/examples/grpo_webshop/train_webshop.yaml +++ b/examples/grpo_webshop/train_webshop.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -44,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md new file mode 100644 index 0000000000..2ef160b0f2 --- /dev/null +++ b/examples/mix_math/README.md @@ -0,0 +1,7 @@ +# Example: MIX on MATH dataset + +This example shows the usage of a new algorithm MIX on the MATH dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md). + +The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml). diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml new file mode 100644 index 0000000000..b92edd4b25 --- /dev/null +++ b/examples/mix_math/mix_math.yaml @@ -0,0 +1,87 @@ +project: "mix_math" +name: "expert0.25_mu0.1" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: mix + repeat_times: 8 + sample_strategy_args: + expert_data_ratio: 0.25 + policy_loss_fn_args: + mu: 0.1 + clip_range: 0.2 + use_token_level_loss_in_sft: False + use_dynamic_bsz: False + repeat_times: 8 + ppo_mini_batch_size: 32 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + read_batch_size_expert: 64 + read_batch_size_usual: 192 +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 1024 + max_response_tokens: 10240 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 40 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: math_train + storage_type: file + path: /PATH/TO/DATASET/ + split: 'train' + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: + - name: math_eval + storage_type: file + path: /PATH/TO/DATASET/ + split: 'test' + format: + prompt_key: 'problem' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + path: /PATH/TO/BUFFER/ + sft_warmup_dataset: + name: math_sft + storage_type: file + algorithm_type: sft + path: /PATH/TO/EXPERT_DATA/ + split: 'train' + format: + prompt_type: messages + messages_key: 'messages' +explorer: + eval_interval: 10 + runner_num: 16 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/mix_math/train_mix_math.yaml' + save_interval: 50 +monitor: + monitor_type: wandb diff --git a/examples/mix_math/train_mix_math.yaml b/examples/mix_math/train_mix_math.yaml new file mode 100644 index 0000000000..ca072b78f6 --- /dev/null +++ b/examples/mix_math/train_mix_math.yaml @@ -0,0 +1,49 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml index 326904d987..5ddd5124ee 100644 --- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml @@ -15,8 +15,8 @@ # entropy_coeff: default to 0.0 for now # # optimizer: -# beta1, beta2: 0.0, 0.95 # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift -# lr: set smaller to account for beta1 = 0.0 +# betas: [0.0, 0.95] # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift +# lr: set smaller to account for betas[0] = 0.0 # # misc: # adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when algorithm_type is opmd @@ -36,11 +36,6 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.000 - use_kl_loss: True - kl_loss_coef: 0.001 - kl_loss_type: mse ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -50,8 +45,7 @@ actor_rollout_ref: # min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program - beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval) - beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval) + betas: [0.0, 0.95] # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval) fsdp_config: wrap_policy: # transformer_layer_cls_to_wrap: None @@ -59,10 +53,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 4.0 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -74,18 +64,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.000 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/ppo_countdown/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml index ae16122ef7..191c345b90 100644 --- a/examples/ppo_countdown/train_countdown.yaml +++ b/examples/ppo_countdown/train_countdown.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -35,10 +30,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -82,18 +73,6 @@ critic: grad_clip: 1.0 cliprange_value: 0.5 -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/pyproject.toml b/pyproject.toml index a295c74fed..fd7ca1f3c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "trinity-rft" -version = "0.1.1" +version = "0.2.0.dev0" authors = [ {name="Trinity-RFT Team", email="trinity-rft@outlook.com"}, ] @@ -21,9 +21,9 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - "verl==0.3.0.post1", + "verl==0.4.0", "ray[default]>=2.45.0", - "vllm==0.8.5.post1", + "vllm==0.9.1", "tensordict==0.6.2", "wandb", "omegaconf", @@ -39,6 +39,7 @@ dependencies = [ "requests", "tensorboard", "openai", + "jsonlines", ] [project.scripts] @@ -69,6 +70,10 @@ doc = [ "myst-parser", ] +flash_attn = [ + "flash-attn==2.8.0.post2" +] + [tool.setuptools.packages.find] where = ["."] include = ["trinity*"] diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py new file mode 100644 index 0000000000..ba88feb2d7 --- /dev/null +++ b/tests/algorithm/policy_loss_test.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +"""Test for policy loss functions""" + +import unittest + +import torch +from verl import DataProto + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN + + +class VerlPolicyLossTest(unittest.TestCase): + def setUp(self): + seed = 42 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + shape = (5, 20) + self.logprob = 2 * torch.rand(shape) - 1 + self.input_data = DataProto.from_dict( + { + "old_log_probs": 2 * torch.rand(shape) - 1, + "ref_log_prob": 2 * torch.rand(shape) - 1, + "response_mask": torch.rand(shape) > 0.5, + "advantages": 2 * torch.rand(shape) - 1, + "is_expert_mask": torch.rand(shape[0]) > 0.5, + } + ) + + def test_ppo_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + ppo_loss = torch.tensor(0.28560468554496765) + pg_clipfrac = torch.tensor(0.3541666567325592) + ppo_kl = torch.tensor(-0.21663446724414825) + self.assertTrue(torch.allclose(loss, ppo_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss)) + + def test_sft_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("sft") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + sft_loss = torch.tensor(-0.07560186833143234) + self.assertTrue(torch.allclose(loss, sft_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss)) + + def test_dpo_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + dpo_loss = torch.tensor(0.5406752228736877) + chosen_reward = torch.tensor(0.7082431316375732) + rejected_reward = torch.tensor(0.3757950782775879) + accuracy_mean = torch.tensor(1.0) + self.assertTrue(torch.allclose(loss, dpo_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward)) + self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward)) + self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean)) + self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss)) + + def test_opmd_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + opmd_loss = torch.tensor(-0.009589947760105133) + self.assertTrue(torch.allclose(loss, opmd_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss)) + + def test_mix_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("mix") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + mix_loss = torch.tensor(0.6581965088844299) + pg_clipfrac = torch.tensor(0.7777777910232544) + ppo_kl = torch.tensor(-1.0737695693969727) + pg_loss = torch.tensor(0.7236452102661133) + sft_loss = torch.tensor(0.06915830634534359) + self.assertTrue(torch.allclose(loss, mix_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac)) + self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl)) + self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss)) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 363a4939ad..e53669a850 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -9,12 +9,54 @@ get_unittest_dataset_config, ) from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer +from trinity.buffer.reader.file_reader import RawDataReader from trinity.buffer.utils import default_storage_path +from trinity.buffer.writer.file_writer import JSONWriter from trinity.common.config import StorageConfig from trinity.common.constants import StorageType class TestFileBuffer(unittest.TestCase): + temp_output_path = "tmp/test_file_buffer/" + + @classmethod + def setUpClass(cls): + super().setUpClass() + os.makedirs(cls.temp_output_path, exist_ok=True) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + if os.path.exists(cls.temp_output_path): + os.system(f"rm -rf {cls.temp_output_path}") + + def test_file_buffer(self): + meta = StorageConfig( + name="test_buffer", + path=os.path.join(self.temp_output_path, "buffer.jsonl"), + storage_type=StorageType.FILE, + raw=True, + ) + data = [ + {"key1": 1, "key2": 2}, + {"key1": 3, "key2": 4}, + {"key1": 5, "key2": 6}, + {"key1": 7, "key2": 8}, + ] + + # test writer + writer = JSONWriter(meta, None) + writer.write(data) + writer.finish() + + # test reader + meta.path = self.temp_output_path + reader = RawDataReader(meta, None) + loaded_data = reader.read() + self.assertEqual(len(loaded_data), 4) + self.assertEqual(loaded_data, data) + self.assertRaises(StopIteration, reader.read) + def test_file_reader(self): """Test file reader.""" reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 4f3947c795..03e96e4291 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -6,7 +6,7 @@ from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.constants import StorageType from trinity.common.experience import Experience BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl") @@ -19,7 +19,7 @@ def test_queue_buffer(self): read_batch_size = 4 meta = StorageConfig( name="test_buffer", - algorithm_type=AlgorithmType.PPO, + algorithm_type="ppo", storage_type=StorageType.QUEUE, path=BUFFER_FILE_PATH, ) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 2146794ebd..56305be671 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -7,7 +7,7 @@ from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.constants import StorageType from trinity.common.experience import Experience db_path = os.path.join(os.path.dirname(__file__), "test.db") @@ -20,7 +20,7 @@ def test_create_sql_buffer(self) -> None: read_batch_size = 4 meta = StorageConfig( name="test_buffer", - algorithm_type=AlgorithmType.PPO, + algorithm_type="ppo", path=f"sqlite:///{db_path}", storage_type=StorageType.SQL, wrap_in_ray=True, diff --git a/tests/common/config_test.py b/tests/common/config_test.py index e1ac0aa7d4..da4fd914a0 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -47,6 +47,7 @@ def test_all_examples_are_valid(self): config_path = os.path.join(example_dir, example_name, filename) try: config = load_config(config_path) + config.checkpoint_root_dir = "./.cache/" config.check_and_update() except Exception as e: print(f"Error loading config {config_path}: {e}") diff --git a/tests/data/controllers/task_parser_test.py b/tests/data/controllers/task_parser_test.py index 542c491f41..af36f8777a 100644 --- a/tests/data/controllers/task_parser_test.py +++ b/tests/data/controllers/task_parser_test.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- """Test cases for data task parser.""" +import os import unittest import agentscope from agentscope.models import DashScopeChatWrapper from loguru import logger -from trinity.common.config import Config +from trinity.common.config import DataPipelineConfig from trinity.data.controllers.task_parser import DataTaskParser @@ -16,7 +17,7 @@ class TestTaskParser(unittest.TestCase): def setUp(self) -> None: print("setup", flush=True) - api_key = "your_dashscope_key" + api_key = os.environ.get("OPENAI_API_KEY", None) agentscope.init( model_configs=[ @@ -43,25 +44,20 @@ def _run_test(self, rft_config, return_none=False): logger.info("None dj config.") else: self.assertIsNotNone(dj_config) - op_weights = {} - for op in dj_config.process: - op_name = list(op.keys())[0] - op_weights[op_name] = op[op_name]["op_weight"] - logger.info(op_weights) def test_instruction1(self): - rft_config = Config() - rft_config.data.dj_process_desc = "Please recommend a data filtering strategy for me." + rft_config = DataPipelineConfig() + rft_config.dj_process_desc = "Please recommend a data filtering strategy for me." self._run_test(rft_config) def test_instruction2(self): - rft_config = Config() - rft_config.data.dj_process_desc = "Do nothing." + rft_config = DataPipelineConfig() + rft_config.dj_process_desc = "Do nothing." self._run_test(rft_config, return_none=True) def test_instruction3(self): - rft_config = Config() - rft_config.data.dj_process_desc = "Remove samples with repeat contents." + rft_config = DataPipelineConfig() + rft_config.dj_process_desc = "Remove samples with repeat contents." self._run_test(rft_config) diff --git a/tests/data/core/dataset_test.py b/tests/data/core/dataset_test.py index be6e765fbd..76758e84d6 100644 --- a/tests/data/core/dataset_test.py +++ b/tests/data/core/dataset_test.py @@ -3,10 +3,7 @@ import os import unittest -from trinity.common.config import DataProcessorConfig, FormatConfig -from trinity.common.rewards import AccuracyReward -from trinity.common.task import TaskSet -from trinity.common.workflows import MathWorkflow, SimpleWorkflow +from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig from trinity.data.core.dataset import RewardSchema, RftDataset from trinity.data.core.formatter import BoxedMathAnswerFormatter, RLHFFormatter @@ -15,28 +12,38 @@ class TestRftDataset(unittest.TestCase): """Test cases for RftDataset""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_pipeline_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="solution", solution_key="solution", ), ) - self.data_config_sample_level_setting = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10_with_rewfn_workflow", - ), + self.data_pipeline_config_sample_level_setting = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10_with_rewfn_workflow", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="solution", @@ -47,13 +54,19 @@ def setUp(self) -> None: ) def test_rft_dataset_init(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") + dataset = RftDataset( + data_pipeline_config=self.data_pipeline_config, reward_schema="default" + ) + dataset.read_from_buffer() self.assertEqual(len(dataset), 10) self.assertIsInstance(dataset.reward_schema, RewardSchema) def test_format_dataset(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") + dataset = RftDataset( + data_pipeline_config=self.data_pipeline_config, reward_schema="default" + ) + dataset.read_from_buffer() original_data = dataset.data # no formatter dataset.format(formatters=[]) @@ -62,56 +75,12 @@ def test_format_dataset(self): # apply formatters dataset.format( formatters=[ - BoxedMathAnswerFormatter(config=self.data_config.format), - RLHFFormatter(config=self.data_config.format), + BoxedMathAnswerFormatter(config=self.data_pipeline_config.format), + RLHFFormatter(config=self.data_pipeline_config.format), ] ) self.assertNotEqual(dataset.data, original_data) - def test_to_taskset(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") - taskset = dataset.to_taskset() - self.assertIsInstance(taskset, TaskSet) - self.assertEqual(len(taskset), 10) - self.assertIsNone(taskset.reward_fn) - self.assertIsNone(taskset.workflow) - self.assertEqual(taskset._index, 0) - - def test_to_taskset_with_global_settings(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") - taskset = dataset.to_taskset( - reward_fn=AccuracyReward, - workflow=SimpleWorkflow, - ) - self.assertIsInstance(taskset, TaskSet) - self.assertEqual(taskset.workflow, SimpleWorkflow) - self.assertEqual(taskset.reward_fn, AccuracyReward) - - def test_to_taskset_with_sample_level_settings(self): - dataset = RftDataset( - data_config=self.data_config_sample_level_setting, reward_schema="default" - ) - taskset = dataset.to_taskset() - self.assertIsInstance(taskset, TaskSet) - for task in taskset.tasks: - self.assertEqual(task.workflow, MathWorkflow) - self.assertEqual(task.reward_fn, AccuracyReward) - - def test_to_taskset_with_both_settings(self): - dataset = RftDataset( - data_config=self.data_config_sample_level_setting, reward_schema="default" - ) - taskset = dataset.to_taskset( - reward_fn=AccuracyReward, - workflow=SimpleWorkflow, - ) - self.assertIsInstance(taskset, TaskSet) - for task in taskset.tasks: - self.assertEqual(task.workflow, MathWorkflow) - self.assertEqual(task.reward_fn, AccuracyReward) - self.assertEqual(taskset.workflow, SimpleWorkflow) - self.assertEqual(taskset.reward_fn, AccuracyReward) - if __name__ == "__main__": unittest.main() diff --git a/tests/data/core/formatter_test.py b/tests/data/core/formatter_test.py index 363c736ed9..dbb73ed971 100644 --- a/tests/data/core/formatter_test.py +++ b/tests/data/core/formatter_test.py @@ -3,7 +3,7 @@ import os import unittest -from trinity.common.config import DataProcessorConfig, FormatConfig +from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig from trinity.data.core.dataset import RftDataset from trinity.data.core.formatter import ( BoxedMathAnswerFormatter, @@ -18,14 +18,19 @@ class TestBoxedMathDataset(unittest.TestCase): """Test cases for RftDataset""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="answer", @@ -43,12 +48,13 @@ def test_init(self): self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ") # test for default configs self.assertEqual(formatter.config.reward_key, "") - self.assertEqual(formatter.config.chosen_key, "") - self.assertEqual(formatter.config.rejected_key, "") + self.assertEqual(formatter.config.chosen_key, "chosen") + self.assertEqual(formatter.config.rejected_key, "rejected") self.assertEqual(formatter.config.label_key, "") def test_transform(self): - dataset = RftDataset(data_config=self.data_config, reward_schema="default") + dataset = RftDataset(data_pipeline_config=self.data_config, reward_schema="default") + dataset.read_from_buffer() formatter = BoxedMathAnswerFormatter(config=self.data_config.format) self.assertNotIn(formatter.config.response_key, dataset.data.column_names) dataset.format(formatter) @@ -59,14 +65,19 @@ class TestRLHFFormatter(unittest.TestCase): """Test cases for RLHFFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", chat_template="User: {}\nAssistant: ", @@ -107,14 +118,19 @@ class TestRewardFormatter(unittest.TestCase): """Test cases for RewardFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", chosen_key="chosen", @@ -164,14 +180,19 @@ class TestSFTFormatter(unittest.TestCase): """Test cases for SFTFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="answer", @@ -217,14 +238,19 @@ class TestComposedFormatter(unittest.TestCase): """Test cases for ComposedFormatter""" def setUp(self) -> None: - self.data_config = DataProcessorConfig( - source_data_path=os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "test_data", - "test_10", - ), + self.data_config = DataPipelineConfig( + input_buffers=[ + StorageConfig( + path=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "test_data", + "test_10", + ), + raw=True, + ) + ], format=FormatConfig( prompt_key="problem", response_key="answer", diff --git a/tests/data/processor/cleaner_test.py b/tests/data/processor/cleaner_test.py index d21a6960c5..ef2aa13d20 100644 --- a/tests/data/processor/cleaner_test.py +++ b/tests/data/processor/cleaner_test.py @@ -15,7 +15,7 @@ def setUp(self) -> None: print("setup", flush=True) self.rft_config = load_config("./tests/test_configs/cleaner_test_rft_cfg.yaml") - print(self.rft_config) + # print(self.rft_config) self.ds_list = [ {"text": "Today is"}, {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, @@ -25,95 +25,67 @@ def setUp(self) -> None: ] def _run_test(self, tgt_list, weight=1, data_dist="gaussian"): - task_parser = DataTaskParser(self.rft_config) + task_parser = DataTaskParser(self.rft_config.data_processor.task_pipeline) dj_config, _, _, _ = task_parser.parse_to_dj_config() + op_weights = {} for op_config in dj_config.process: - _, op_args = list(op_config.items())[0] - op_args["op_weight"] = weight + op_name, _ = list(op_config.items())[0] + op_weights[op_name] = weight cleaner = DataCleaner( dj_config, clean_strategy="iterative", - min_size_ratio=self.rft_config.data.min_size_ratio, + min_size_ratio=self.rft_config.data_processor.task_pipeline.min_size_ratio, data_dist=data_dist, + op_weights=op_weights, ) - dataset = RftDataset(self.rft_config.data) + dataset = RftDataset(self.rft_config.data_processor.task_pipeline) + dataset.read_from_buffer() dataset = cleaner.process([dataset]) - res_list = dataset.to_list() + res_list = dataset.data.select_columns("text").to_list() + print(res_list) self.assertEqual(res_list, tgt_list) self.assertNotIn("clean_email_mapper", cleaner.dj_cfg.process) def test_dj_executor(self): tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, - { - "text": "中文也是一个字算一个长度", - "__dj__stats__": {"text_len": 12}, - }, + {"text": "Today is"}, + {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, + {"text": "a v s e c s f e f g a a a "}, + {"text": "中文也是一个字算一个长度"}, ] - self.rft_config.data.min_size_ratio = None + self.rft_config.data_processor.task_pipeline.min_size_ratio = None self._run_test(tgt_list) def test_iterative_clean(self): tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, + {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, + {"text": "a v s e c s f e f g a a a "}, ] - self.rft_config.data.min_size_ratio = 0.5 + self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5 self._run_test(tgt_list) def test_weight(self): tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, - { - "text": "中文也是一个字算一个长度", - "__dj__stats__": {"text_len": 12}, - }, + {"text": "Today is"}, + {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"}, + {"text": "a v s e c s f e f g a a a "}, ] - self.rft_config.data.min_size_ratio = 0.5 + self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5 self._run_test(tgt_list, weight=0.5) def test_uniform_dist(self): - tgt_list = [ - { - "text": "a v s e c s f e f g a a a ", - "__dj__stats__": {"text_len": 27}, - }, - { - "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►", - "__dj__stats__": {"text_len": 34}, - }, - ] + tgt_list = [] - self.rft_config.data.min_size_ratio = 0.5 + self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5 self._run_test(tgt_list, data_dist="uniform") diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index eaee55b40a..52f961bda4 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -10,7 +10,7 @@ from tests.tools import get_unittest_dataset_config from trinity.buffer.reader.queue_reader import QueueReader from trinity.common.config import InferenceModelConfig, StorageConfig, load_config -from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.constants import StorageType from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel from trinity.common.workflows import Task @@ -105,7 +105,7 @@ def setUp(self): ) = StorageConfig( name="test", storage_type=StorageType.QUEUE, - algorithm_type=AlgorithmType.PPO, + algorithm_type="ppo", path="", ) self.queue = QueueReader( diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 09b6f9ca0d..98180fff48 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -5,13 +5,23 @@ checkpoint_root_dir: '' algorithm: algorithm_type: ppo repeat_times: 1 + policy_loss_fn: ppo + policy_loss_fn_args: + clip_range: 0.2 + advantage_fn: ppo + advantage_fn_args: + gamma: 1.0 + lam: 1.0 + kl_penalty_fn: k3 + kl_loss_fn: k2 + model: model_path: '' max_prompt_tokens: 2048 max_response_tokens: 2048 cluster: # 2 for explorer, 2 for trainer - node_num: 1 - gpu_per_node: 4 + node_num: 2 + gpu_per_node: 2 buffer: total_epochs: 1 batch_size: 4 diff --git a/tests/template/data/sft_for_gsm8k/sft.jsonl b/tests/template/data/sft_for_gsm8k/sft.jsonl new file mode 100644 index 0000000000..a8d6972103 --- /dev/null +++ b/tests/template/data/sft_for_gsm8k/sft.jsonl @@ -0,0 +1,32 @@ +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} +{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "} +{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "} +{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "} +{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "} diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml index b17fc87958..bb5c21612a 100644 --- a/tests/template/verl_config.yaml +++ b/tests/template/verl_config.yaml @@ -12,16 +12,11 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size checkpoint: - contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space optim: lr: 1e-6 lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime @@ -35,10 +30,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -81,14 +72,8 @@ critic: shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 cliprange_value: 0.5 - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 + checkpoint: + contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space trainer: balance_batch: True diff --git a/tests/test_configs/active_iterator_test_cfg.yaml b/tests/test_configs/active_iterator_test_cfg.yaml index 3b105e1f66..3e6008b7cf 100644 --- a/tests/test_configs/active_iterator_test_cfg.yaml +++ b/tests/test_configs/active_iterator_test_cfg.yaml @@ -1,13 +1,18 @@ data_processor: # basic info - source_data_path: 'tests/test_data/test_10/' - load_kwargs: - split: 'train' - format: - prompt_key: 'problem' - response_key: 'solution' - # cleaner related - dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' - clean_strategy: 'iterative' - # db related - db_url: 'postgresql://{username}@localhost:5432/{db_name}' + task_pipeline: + input_buffers: + - name: 'raw_input' + path: 'tests/test_data/test_10/' + storage_type: 'file' + raw: true + output_buffer: + name: 'raw_output' + path: './outputs/task_pipeline_output/processed.jsonl' + storage_type: 'file' + format: + prompt_key: 'problem' + response_key: 'solution' + # cleaner related + dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' + clean_strategy: 'iterative' diff --git a/tests/test_configs/active_iterator_test_dj_cfg.yaml b/tests/test_configs/active_iterator_test_dj_cfg.yaml index f7f848e338..367709968f 100644 --- a/tests/test_configs/active_iterator_test_dj_cfg.yaml +++ b/tests/test_configs/active_iterator_test_dj_cfg.yaml @@ -1,7 +1,5 @@ project_name: 'demo-process' -export_path: './outputs/demo-process/demo-processed.jsonl' - text_keys: 'solution' process: diff --git a/tests/test_configs/cleaner_test_dj_cfg.yaml b/tests/test_configs/cleaner_test_dj_cfg.yaml index 9e2da88d64..cf11488963 100644 --- a/tests/test_configs/cleaner_test_dj_cfg.yaml +++ b/tests/test_configs/cleaner_test_dj_cfg.yaml @@ -3,7 +3,5 @@ project_name: 'demo-process' export_path: './outputs/demo-process/demo-processed.jsonl' process: - - text_length_filter: - min_len: 10 - max_len: 50 + - alphanumeric_filter: - clean_email_mapper: diff --git a/tests/test_configs/cleaner_test_rft_cfg.yaml b/tests/test_configs/cleaner_test_rft_cfg.yaml index 7f8581c0ef..c78e3a1ac8 100644 --- a/tests/test_configs/cleaner_test_rft_cfg.yaml +++ b/tests/test_configs/cleaner_test_rft_cfg.yaml @@ -1,5 +1,7 @@ data_processor: - source_data_path: './tests/test_data/test_cleaner' - load_kwargs: {"split": "train"} - dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml' - clean_strategy: 'iterative' + task_pipeline: + input_buffers: + - path: './tests/test_data/test_cleaner' + raw: true + dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml' + clean_strategy: 'iterative' diff --git a/tests/test_configs/human_annotator_test_rft_cfg.yaml b/tests/test_configs/human_annotator_test_rft_cfg.yaml index 79d8b8108b..b20f015182 100644 --- a/tests/test_configs/human_annotator_test_rft_cfg.yaml +++ b/tests/test_configs/human_annotator_test_rft_cfg.yaml @@ -1,10 +1,10 @@ data_processor: - source_data_path: './tests/test_data/test_human_annotator' - load_kwargs: {"split": "train"} - dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml' - format: - prompt_key: 'prompt' - chosen_key: 'chosen' - rejected_key: 'rejected' - # db related - db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' + task_pipeline: + input_buffers: + - path: './tests/test_data/test_human_annotator' + raw: true + dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml' + format: + prompt_key: 'prompt' + chosen_key: 'chosen' + rejected_key: 'rejected' diff --git a/tests/tools.py b/tests/tools.py index 2e34438d66..3111839a37 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -13,6 +13,7 @@ StorageConfig, load_config, ) +from trinity.common.constants import PromptType def get_template_config() -> Config: @@ -59,6 +60,47 @@ def get_unittest_dataset_config( default_workflow_type="math_workflow", default_reward_fn_type="countdown_reward", ) + elif dataset_name == "gsm8k": + return StorageConfig( + name=dataset_name, + path="openai/gsm8k", + split=split, + subset_name="main", + format=FormatConfig( + prompt_key="question", + response_key="answer", + ), + rollout_args=GenerationConfig( + n=1, + temperature=1.0, + logprobs=0, + ), + default_workflow_type="math_workflow", + default_reward_fn_type="math_reward", + ) + elif dataset_name == "sft_for_gsm8k": + return StorageConfig( + name=dataset_name, + path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"), + split="train", + format=FormatConfig( + prompt_type=PromptType.PLAINTEXT, + prompt_key="prompt", + response_key="response", + ), + ) + elif dataset_name == "dpo": + return StorageConfig( + name=dataset_name, + path="HumanLLMs/Human-Like-DPO-Dataset", + split="train", + format=FormatConfig( + prompt_type=PromptType.PLAINTEXT, + prompt_key="prompt", + chosen_key="chosen", + rejected_key="rejected", + ), + ) else: raise ValueError(f"Unknown dataset name: {dataset_name}") @@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]: raise ValueError(f"Metric '{metric_name}' does not exist.") return list(self._metrics[metric_name].keys()) + def metric_values(self, metric_name: str) -> List: + if not self.metric_exist(metric_name): + raise ValueError(f"Metric '{metric_name}' does not exist.") + return list(self._metrics[metric_name].values()) + def metric_list(self, metric_prefix: str) -> List[str]: return [name for name in self._metrics if name.startswith(metric_prefix)] diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 24e6730a1e..0ec438c2db 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -14,7 +14,7 @@ get_template_config, get_unittest_dataset_config, ) -from trinity.cli.launcher import bench, both +from trinity.cli.launcher import bench, both, train from trinity.common.constants import SyncMethod @@ -67,6 +67,10 @@ def test_trainer(self): actor_metrics = parser.metric_list("actor") self.assertTrue(len(actor_metrics) > 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + actor_kl_metrics = parser.metric_list("actor/kl") + self.assertTrue(len(actor_kl_metrics) > 0) + critic_kl_metrics = parser.metric_list("critic/kl") + self.assertTrue(len(critic_kl_metrics) > 0) response_metrics = parser.metric_list("response_length") self.assertTrue(len(response_metrics) > 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) @@ -86,7 +90,7 @@ def test_trainer(self): ) self.assertTrue(os.path.exists(checkpoint_step_4)) self.assertTrue(os.path.exists(checkpoint_step_8)) - + # TODO: Reinit will fail when using v1 engine, find a way to fix it ray.init(ignore_reinit_error=True) # test bench mode self.config.mode = "bench" @@ -109,3 +113,151 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestStepAheadAsyncRL(BaseTrainerCase): + def test_trainer(self): + """Test the explore step ahead trainer""" + # train 4 step, sync_offset=1, sync_interval=2 + # Explorer: + # | 1 | 2 | 3 |sync| 4 | + # |---|---|---|sync|---| + # Trainer: + # | 1 | 2 |sync| 3 | 4 | + # |---|---|sync|---|---| + self.config.buffer.total_epochs = 1 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.trainer.save_interval = 4 + self.config.synchronizer.sync_interval = 2 + self.config.synchronizer.sync_offset = 1 + self.config.check_and_update() + self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 1 + self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 1 + + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + actor_kl_metrics = parser.metric_list("actor/kl") + self.assertTrue(len(actor_kl_metrics) > 0) + critic_kl_metrics = parser.metric_list("critic/kl") + self.assertTrue(len(critic_kl_metrics) > 0) + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) + ray.shutdown(_exiting_interpreter=True) + # check checkpoint + from trinity.common.models.utils import get_checkpoint_dir_with_step_num + + checkpoint_step_4 = get_checkpoint_dir_with_step_num( + checkpoint_root_path=self.config.checkpoint_job_dir, + trainer_type=self.config.trainer.trainer_type, + step_num=4, + ) + self.assertTrue(os.path.exists(checkpoint_step_4)) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTrainerGSM8K(BaseTrainerCase): + def test_trainer(self): + """Test GSM8K.""" + # test both mode + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.repeat_times = 4 + # self.config.algorithm.repeat_times = 8 # TODO: used for real testing + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.advantage_fn_args = {} + # self.config.buffer.batch_size = 96 # TODO: used for real testing + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.check_and_update() + self.config.trainer.trainer_config.trainer.total_training_steps = 4 + self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 + self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5 + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) + # TODO: used for real testing + # rewards = parser.metric_values("critic/rewards/mean") + # self.assertTrue(0.4 < rewards[0] < 0.55) + # self.assertTrue(0.4 < rewards[1] < 0.55) + # self.assertTrue(0.6 < rewards[2] < 0.7) + # self.assertTrue(0.6 < rewards[3] < 0.7) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTrainerSFTWarmupGSM8K(BaseTrainerCase): + def test_trainer(self): + """Test GSM8K With SFT.""" + # test both mode + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.repeat_times = 4 + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.advantage_fn_args = {} + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.buffer.trainer_input.sft_warmup_steps = 2 + self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config( + "sft_for_gsm8k" + ) + self.config.check_and_update() + self.config.trainer.trainer_config.trainer.total_training_steps = 4 + self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 + self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5 + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT + self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4) # RFT + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) + + def tearDown(self): + # TODO: remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTrainerDPO(BaseTrainerCase): + def test_trainer(self): + """Test DPO.""" + # test both mode + self.config.mode = "train" + self.config.algorithm.algorithm_type = "dpo" + self.config.algorithm.policy_loss_fn = "dpo" + self.config.algorithm.policy_loss_fn_args = {} + # self.config.buffer.batch_size = 32 + self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo") + self.config.check_and_update() + self.config.trainer.trainer_config.trainer.total_training_steps = 4 + self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 + self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7 + train(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/__init__.py b/trinity/__init__.py index ff7c8c4b29..c5b13a8976 100644 --- a/trinity/__init__.py +++ b/trinity/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """Trinity-RFT (Reinforcement Fine-Tuning)""" -__version__ = "0.1.1" +__version__ = "0.2.0.dev0" diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py new file mode 100644 index 0000000000..667aa10d74 --- /dev/null +++ b/trinity/algorithm/__init__.py @@ -0,0 +1,21 @@ +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType +from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn +from trinity.algorithm.kl_fn import KL_FN, KLFn +from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy + +__all__ = [ + "ALGORITHM_TYPE", + "AlgorithmType", + "AdvantageFn", + "ADVANTAGE_FN", + "PolicyLossFn", + "POLICY_LOSS_FN", + "KLFn", + "KL_FN", + "EntropyLossFn", + "ENTROPY_LOSS_FN", + "SampleStrategy", + "SAMPLE_STRATEGY", +] diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py new file mode 100644 index 0000000000..7bcf682e4b --- /dev/null +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -0,0 +1,20 @@ +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn +from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn +from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn +from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( + REINFORCEPLUSPLUSAdvantageFn, +) +from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn +from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn + +__all__ = [ + "ADVANTAGE_FN", + "AdvantageFn", + "PPOAdvantageFn", + "GRPOAdvantageFn", + "REINFORCEPLUSPLUSAdvantageFn", + "REMAXAdvantageFn", + "RLOOAdvantageFn", + "OPMDAdvantageFn", +] diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py new file mode 100644 index 0000000000..21e3668a53 --- /dev/null +++ b/trinity/algorithm/advantage_fn/advantage_fn.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple + +from trinity.utils.registry import Registry + +ADVANTAGE_FN = Registry("advantage_fn") + + +class AdvantageFn(ABC): + @abstractmethod + def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]: + """Calculate advantages from experiences + + Args: + exps (`DataProto`): The input experiences. + kwargs (`Dict`): The step-level parameters for calculating advantages. + + Returns: + `DataProto`: The experiences with advantages. + `Dict`: The metrics for logging. + """ + + @classmethod + @abstractmethod + def default_args(cls) -> Dict: + """ + Returns: + `Dict`: The default init arguments for the advantage function. + """ diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py new file mode 100644 index 0000000000..553af6d065 --- /dev/null +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -0,0 +1,83 @@ +"""GRPO advantage computation + +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +""" + +from collections import defaultdict +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("grpo") +class GRPOAdvantageFn(AdvantageFn): + """GRPO advantage computation""" + + def __init__( + self, + epsilon: float = 1e-6, + ) -> None: + self.epsilon = epsilon + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + scores: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + epsilon = self.epsilon + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + } diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py new file mode 100644 index 0000000000..b27e2c9ab0 --- /dev/null +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -0,0 +1,94 @@ +"""OPMD advantage computation""" + +from collections import defaultdict +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("opmd") +class OPMDAdvantageFn(AdvantageFn): + """OPMD advantage computation""" + + def __init__( + self, + opmd_baseline: str = "mean", + tau: float = 1.0, + ) -> None: + self.opmd_baseline = opmd_baseline + self.tau = tau + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """Modified from compute_grpo_outcome_advantage + + Compute advantage for OPMD, operating only on Outcome reward + (with only one scalar reward for each response). + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + scores: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + # TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation + index = exps.non_tensor_batch["uid"] + opmd_baseline = self.opmd_baseline + tau = self.tau + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2baseline = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2baseline[idx] = torch.tensor(0.0) + # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) + elif len(id2score[idx]) > 1: + if opmd_baseline == "mean": + id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + elif opmd_baseline == "logavgexp": + rewards_tensor = torch.tensor(id2score[idx]) + # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)) + id2baseline[idx] = tau * ( + torch.logsumexp(rewards_tensor / tau, dim=-1) + - torch.log(torch.tensor(len(id2score[idx]))) + ) + else: + raise NotImplementedError + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2baseline[index[i]] + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "opmd_baseline": "mean", + "tau": 1.0, + } diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py new file mode 100644 index 0000000000..31fda4454c --- /dev/null +++ b/trinity/algorithm/advantage_fn/ppo_advantage.py @@ -0,0 +1,86 @@ +"""PPO's GAE advantage computation + +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +""" + +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.utils import masked_whiten + + +@ADVANTAGE_FN.register_module("ppo") +class PPOAdvantageFn(AdvantageFn): + def __init__( + self, + gamma: float = 1.0, + lam: float = 1.0, + ) -> None: + self.gamma = gamma + self.lam = lam + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """ + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + values: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. + gamma: `(float)` + discounted factor used in RL + lam: `(float)` + lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + values = exps.batch["values"] + eos_mask = exps.batch["response_mask"] + gamma = self.gamma + lam = self.lam + + with torch.no_grad(): + lastgaelam = 0 + advantages_reversed = [] + gen_len = token_level_rewards.shape[-1] + + # values = values * eos_mask TODO: may use in multi-turn + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] + + lastgaelam = delta + gamma * lam * lastgaelam + # lastgaelam = torch.where( # TODO: may use in multi-turn + # eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam + # ) + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + + returns = advantages + values + advantages = masked_whiten(advantages, eos_mask) + + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "gamma": 1.0, + "lam": 1.0, + } diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py new file mode 100644 index 0000000000..eb63c3605b --- /dev/null +++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py @@ -0,0 +1,66 @@ +"""REINFORCE++ advantage computation + +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +""" + +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.utils import masked_whiten + + +@ADVANTAGE_FN.register_module("reinforceplusplus") +class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn): + def __init__(self, gamma: float = 1.0) -> None: + self.gamma = gamma + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """ + Compute advantage for REINFORCE++. + This implementation is based on the paper: https://arxiv.org/abs/2501.03262 + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + gamma = self.gamma + + with torch.no_grad(): + returns = torch.zeros_like(token_level_rewards) + running_return = 0 + + for t in reversed(range(token_level_rewards.shape[1])): + running_return = token_level_rewards[:, t] + gamma * running_return + returns[:, t] = running_return + + advantages = masked_whiten(returns, eos_mask) + advantages = advantages * eos_mask + + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "gamma": 1.0, + } diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py new file mode 100644 index 0000000000..07f92d91a0 --- /dev/null +++ b/trinity/algorithm/advantage_fn/remax_advantage.py @@ -0,0 +1,66 @@ +"""REMAX advantage computation + +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +""" + +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("remax") +class REMAXAdvantageFn(AdvantageFn): + def __init__(self) -> None: + pass + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """ + Compute advantage for ReMax, operating only on Outcome reward + (with only one scalar reward for each response). + This implementation is based on the paper: https://arxiv.org/abs/2310.10505 + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + reward_baselines: `(torch.Tensor)` + shape: (bs,) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + reward_baselines = exps.batch["reward_baselines"] + eos_mask = exps.batch["response_mask"] + + response_length = token_level_rewards.shape[-1] + token_level_rewards.sum(dim=-1) + + with torch.no_grad(): + returns = ( + (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + ) + advantages = ( + returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask + ) + + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py new file mode 100644 index 0000000000..fb2680a68b --- /dev/null +++ b/trinity/algorithm/advantage_fn/rloo_advantage.py @@ -0,0 +1,75 @@ +"""RLOO advantage computation + +Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +""" + +from collections import defaultdict +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("rloo") +class RLOOAdvantageFn(AdvantageFn): + def __init__(self) -> None: + pass + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + scores: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + response_num = len(id2score[index[i]]) + if response_num > 1: + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[ + index[i] + ] * response_num / (response_num - 1) + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py new file mode 100644 index 0000000000..54f5c3d296 --- /dev/null +++ b/trinity/algorithm/algorithm.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +"""Algorithm classes.""" + +from abc import ABC, ABCMeta, abstractmethod +from typing import Dict + +from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel +from trinity.common.config import Config +from trinity.common.constants import SyncMethod +from trinity.utils.log import get_logger +from trinity.utils.registry import Registry + +logger = get_logger(__name__) + +ALGORITHM_TYPE = Registry("algorithm") + + +class ConstantMeta(ABCMeta): + def __setattr__(cls, name, value): + if name in cls.__dict__: + raise AttributeError(f"{name} is already defined in {cls.__name__}") + return super().__setattr__(name, value) + + +class AlgorithmType(ABC, metaclass=ConstantMeta): + use_critic: bool + use_reference: bool + use_advantage: bool + can_balance_batch: bool + schema: type + + @classmethod + @abstractmethod + def default_config(cls) -> Dict: + raise NotImplementedError + + @classmethod + def name(cls) -> str: + return cls._name + + @classmethod + def check_config(cls, config: Config) -> None: + pass + + +@ALGORITHM_TYPE.register_module("sft") +class SFTAlgorithm(AlgorithmType): + """SFT Algorithm.""" + + use_critic: bool = False + use_reference: bool = False + use_advantage: bool = False + can_balance_batch: bool = True + schema: type = SFTDataModel + + @classmethod + def default_config(cls) -> Dict: + return { + "sample_strategy": "default", + "policy_loss_fn": "sft", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + } + + +@ALGORITHM_TYPE.register_module("ppo") +class PPOAlgorithm(AlgorithmType): + """PPO Algorithm.""" + + use_critic: bool = True + use_reference: bool = True + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 1, + "sample_strategy": "warmup", + "policy_loss_fn": "ppo", + "advantage_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + +@ALGORITHM_TYPE.register_module("grpo") +class GRPOAlgorithm(AlgorithmType): + """GRPO algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "ppo", + "advantage_fn": "grpo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + +@ALGORITHM_TYPE.register_module("opmd") +class OPMDAlgorithm(AlgorithmType): + """OPMD algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "opmd", + "advantage_fn": "opmd", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + +@ALGORITHM_TYPE.register_module("dpo") +class DPOAlgorithm(AlgorithmType): + """DPO algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = False + can_balance_batch: bool = False + schema: type = DPODataModel + + @classmethod + def default_config(cls) -> Dict: + return { + "sample_strategy": "dpo", + "policy_loss_fn": "dpo", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + + @classmethod + def check_config(cls, config: Config) -> None: + if config.model == "train": + if ( + config.buffer.trainer_input.experience_buffer is None + or not config.buffer.trainer_input.experience_buffer.path + ): + raise ValueError( + "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == dpo`" + ) + elif config.mode in ["both", "explore"]: + raise ValueError(f"DPO does not support `{config.mode}` mode") + + if config.synchronizer.sync_method != SyncMethod.CHECKPOINT: + config.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." + ) + if config.algorithm.repeat_times != 2: + config.algorithm.repeat_times = 2 # Fake repeat times + if config.algorithm.kl_loss_fn in {"none", None}: + config.algorithm.kl_loss_fn = "k2" + logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") + + +@ALGORITHM_TYPE.register_module("mix") +class MIXAlgorithm(AlgorithmType): + """MIX algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 8, + "policy_loss_fn": "mix", + "advantage_fn": "grpo", + "sample_strategy": "mix", + } diff --git a/trinity/algorithm/algorithm_manager.py b/trinity/algorithm/algorithm_manager.py new file mode 100644 index 0000000000..82cef5ebbd --- /dev/null +++ b/trinity/algorithm/algorithm_manager.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +"""AlgorithmManager for switching between SFT and RFT.""" + +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.kl_fn.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN +from trinity.common.config import AlgorithmConfig, Config + + +class AlgorithmManager: + def __init__(self, config: Config): + self.config = config + sft_type = ALGORITHM_TYPE.get("sft") + sft_default_config = sft_type.default_config() + self.sft_algorithm_config = AlgorithmConfig( + algorithm_type="sft", + **sft_default_config, + ) + policy_fn_cls = POLICY_LOSS_FN.get(self.sft_algorithm_config.policy_loss_fn) + self.sft_algorithm_config.policy_loss_fn_args = policy_fn_cls.default_args() + kl_loss_fn_cls = KL_FN.get(self.sft_algorithm_config.kl_loss_fn) + self.sft_algorithm_config.kl_loss_fn_args = kl_loss_fn_cls.default_args() + entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.sft_algorithm_config.entropy_loss_fn) + self.sft_algorithm_config.entropy_loss_fn_args = entropy_loss_fn_cls.default_args() + + def get_current_algorithm_config(self, global_steps: int): + if global_steps <= self.config.buffer.trainer_input.sft_warmup_steps: + return self.sft_algorithm_config + else: + return self.config.algorithm + + def need_save(self, global_steps: int): + return global_steps == self.config.buffer.trainer_input.sft_warmup_steps diff --git a/trinity/algorithm/entropy_loss_fn/__init__.py b/trinity/algorithm/entropy_loss_fn/__init__.py new file mode 100644 index 0000000000..d932b94fde --- /dev/null +++ b/trinity/algorithm/entropy_loss_fn/__init__.py @@ -0,0 +1,9 @@ +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( + ENTROPY_LOSS_FN, + EntropyLossFn, +) + +__all__ = [ + "EntropyLossFn", + "ENTROPY_LOSS_FN", +] diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py new file mode 100644 index 0000000000..d6179a832c --- /dev/null +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.utils import masked_mean +from trinity.utils.registry import Registry + +ENTROPY_LOSS_FN = Registry("entropy_loss_fn") + + +class EntropyLossFn(ABC): + """ + Entropy loss function. + """ + + @abstractmethod + def __call__( + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + """ + Args: + entropy (`torch.Tensor`): The entropy generated by the policy model. + action_mask (`torch.Tensor`): The action mask. + + Returns: + `torch.Tensor`: The calculated entropy loss. + `Dict`: The metrics for logging + """ + + @classmethod + def default_args(cls) -> Dict: + """ + Returns: + `Dict`: The default arguments for the entropy loss function. + """ + return {"entropy_coef": 0.0} + + +@ENTROPY_LOSS_FN.register_module("default") +class DefaultEntropyLossFn(EntropyLossFn): + """ + Basic entropy loss function. + """ + + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef + + def __call__( + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + entropy_loss = masked_mean(entropy, action_mask) + return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} + + +@ENTROPY_LOSS_FN.register_module("none") +class DummyEntropyLossFn(EntropyLossFn): + """ + Dummy entropy loss function. + """ + + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef + + def __call__( + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + return torch.tensor(0.0), {} diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py new file mode 100644 index 0000000000..09c1f988a6 --- /dev/null +++ b/trinity/algorithm/key_mapper.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +"""Key Mapper""" + +from typing import Dict + + +class KeyMapper: + def __init__(self, to_trinity_map: Dict[str, str]): + self.to_trinity_map = to_trinity_map + self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()} + + def to_trinity(self, key: str) -> str: + return self.to_trinity_map.get(key, key) + + def from_trinity(self, key: str) -> str: + return self.from_trinity_map.get(key, key) + + +ALL_MAPPERS = { + "verl": KeyMapper( + { + "log_prob": "logprob", + "old_log_probs": "old_logprob", + "ref_log_prob": "ref_logprob", + "response_mask": "action_mask", + "advantages": "advantages", + } + ), +} diff --git a/trinity/algorithm/kl_fn/__init__.py b/trinity/algorithm/kl_fn/__init__.py new file mode 100644 index 0000000000..875c620442 --- /dev/null +++ b/trinity/algorithm/kl_fn/__init__.py @@ -0,0 +1,3 @@ +from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn + +__all__ = ["KLFn", "KL_FN"] diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py new file mode 100644 index 0000000000..62ed48cd49 --- /dev/null +++ b/trinity/algorithm/kl_fn/kl_fn.py @@ -0,0 +1,165 @@ +"""KL penalty and loss. + +Ref: +https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py +https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/utils.py +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Tuple + +import torch + +from trinity.algorithm.utils import masked_mean +from trinity.utils.registry import Registry + +KL_FN = Registry("kl_fn") + + +class KLFn(ABC): + """ + KL penalty and loss. + """ + + def __init__( + self, + adaptive: bool = False, + kl_coef: float = 0.001, + target_kl: Optional[float] = None, + horizon: Optional[float] = None, + ) -> None: + self.kl_coef = kl_coef + self.adaptive = adaptive + self.target_kl = target_kl + self.horizon = horizon + if adaptive and (target_kl is None or horizon is None): + raise ValueError("Target KL and horizon must be provided for adaptive KL.") + + def update_kl_coef(self, current_kl: float, batch_size: int) -> None: + """Update kl coefficient.""" + if self.adaptive: + target_kl = self.target_kl + proportional_error = torch.clip(current_kl / target_kl - 1, -0.2, 0.2).item() # type: ignore + multiplier = 1 + proportional_error * batch_size / self.horizon + self.kl_coef *= multiplier + + def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]: + """Apply KL penalty to reward. Only support DataProto input for now.""" + responses = experiences.batch["responses"] + response_length = responses.size(1) + token_level_scores = experiences.batch["token_level_scores"] + batch_size = experiences.batch.batch_size[0] + attention_mask = experiences.batch["attention_mask"] + response_mask = experiences.batch["response_mask"] + assert response_mask.shape == attention_mask[:, -response_length:].shape + logprob = experiences.batch["old_log_probs"] + ref_logprob = experiences.batch["ref_log_prob"] + + if "ref_log_prob" in experiences.batch.keys(): + kl = self.calculate_kl(logprob, ref_logprob) + kl = kl * response_mask + kl_coef = self.kl_coef + experiences.batch["token_level_rewards"] = token_level_scores - kl_coef * kl + else: + kl_coef = 0.0 + kl = torch.zeros_like(response_mask, dtype=torch.float32) + experiences.batch["token_level_rewards"] = token_level_scores + + current_kl = masked_mean(kl, mask=response_mask, axis=-1).mean(dim=0).item() + self.update_kl_coef(current_kl=current_kl, batch_size=batch_size) + + metrics = { + "kl": current_kl, + "kl_coef": kl_coef, + } + + return experiences, metrics + + def calculate_kl_loss( + self, + logprob: torch.Tensor, + ref_logprob: torch.Tensor, + response_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict]: + """Compute KL loss.""" + kl = self.calculate_kl(logprob, ref_logprob) + kl_loss = masked_mean(kl, response_mask) + metrics = { + "kl_loss": kl_loss.detach().item(), + "kl_coef": self.kl_coef, + } + return kl_loss * self.kl_coef, metrics + + @abstractmethod + def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor: + """Compute KL divergence between logprob and ref_logprob.""" + + @classmethod + def default_args(cls): + """Get the default initialization arguments.""" + return {"adaptive": False, "kl_coef": 0.001} + + +@KL_FN.register_module("none") +class DummyKLFn(KLFn): + """ + Dummy KL function. + """ + + def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor: + return torch.zeros_like(logprob) + + def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]: + experiences.batch["token_level_rewards"] = experiences.batch["token_level_scores"] + return experiences, {} + + def calculate_kl_loss( + self, + logprob: torch.Tensor, + ref_logprob: torch.Tensor, + response_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict]: + # return a zero tensor + return torch.tensor(0.0), {} + + +@KL_FN.register_module("k1") +class K1Fn(KLFn): + """ + KL K1 function. + """ + + def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor: + return logprob - ref_logprob + + +@KL_FN.register_module("k2") +class K2Fn(KLFn): + """ + KL K2 function. + """ + + def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor: + return (logprob - ref_logprob).square() * 0.5 + + +@KL_FN.register_module("k3") +class K3Fn(KLFn): + """ + KL K3 function. + """ + + def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor: + logr = ref_logprob - logprob + return logr.exp() - 1 - logr + + +@KL_FN.register_module("abs") +class AbsFn(KLFn): + """ + KL Abs function. + """ + + def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor: + return torch.abs(logprob - ref_logprob) diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py new file mode 100644 index 0000000000..705fb2525a --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -0,0 +1,16 @@ +from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn +from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn +from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn + +__all__ = [ + "POLICY_LOSS_FN", + "PolicyLossFn", + "PPOPolicyLossFn", + "OPMDPolicyLossFn", + "DPOLossFn", + "SFTLossFn", + "MIXPolicyLossFn", +] diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py new file mode 100644 index 0000000000..0858cb7002 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py @@ -0,0 +1,67 @@ +"""DPO loss function.""" + +from typing import Dict, Tuple + +import torch +import torch.nn.functional as F + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_sum + + +@POLICY_LOSS_FN.register_module("dpo") +class DPOLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + beta: float = 0.1, + label_smoothing: float = 0.0, + ) -> None: + super().__init__(backend=backend) + self.beta = beta + self.label_smoothing = label_smoothing + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + ref_logprob: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + chosen_logprob = logprob[::2] + rejected_logprob = logprob[1::2] + chosen_mask = action_mask[::2] + rejected_mask = action_mask[1::2] + chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask) + rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask) + + chosen_ref_logprob = ref_logprob[::2] + rejected_ref_logprob = ref_logprob[1::2] + chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask) + rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask) + + chosen_ratios = chosen_logprob_sum - chosen_ref_logprob_sum + rejected_ratios = rejected_logprob_sum - rejected_ref_logprob_sum + logits = chosen_ratios - rejected_ratios + # TODO: support other loss functions + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + loss = losses.mean() + chosen_reward = self.beta * chosen_ratios.detach().mean().item() + rejected_reward = self.beta * rejected_ratios.detach().mean().item() + accuracy_mean = (chosen_ratios.detach() > rejected_ratios.detach()).float().mean().item() + return loss, { + "chosen_reward": chosen_reward, + "rejected_reward": rejected_reward, + "accuracy_mean": accuracy_mean, + "dpo_loss": loss.detach().item(), + } + + @classmethod + def default_args(cls) -> Dict: + return { + "beta": 0.1, + "label_smoothing": 0.0, + } diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py new file mode 100644 index 0000000000..76c89c42d9 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -0,0 +1,129 @@ +"""Mix policy loss function.""" + +from typing import Dict, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn + + +@POLICY_LOSS_FN.register_module("mix") +class MIXPolicyLossFn(PolicyLossFn): + """Implements a mixed policy loss combining GRPO and SFT losses. + + This loss function applies different loss components to data based on whether + it comes from an expert or not, as indicated by `is_expert_mask`. It combines: + - GRPO loss (self.grpo_loss_fn) for non-expert data + - SFT loss (self.sft_loss_fn) for expert data + - Weighting parameter `mu` + + The per-sample weights are normalized using either `experience_per_gpu` or + `gradient_accumulation`, depending on whether dynamic batch sizing is enabled, + to ensure consistent weighting across different batches of the same type experiences. + """ + + def __init__( + self, + backend: str = "verl", + mu: float = 0.1, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_dynamic_bsz: Optional[bool] = None, + repeat_times: int = 1, + ppo_mini_batch_size: int = 1, + ppo_micro_batch_size_per_gpu: int = 1, + ngpus_trainer: int = 1, + read_batch_size_usual: int = 1, + read_batch_size_expert: int = 1, + use_token_level_loss_in_sft: bool = True, + ) -> None: + super().__init__(backend=backend) + self.mu = mu + self.use_dynamic_bsz = use_dynamic_bsz + self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer + self.gradient_accumulation = ( + ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu + ) + self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer + self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + is_expert_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + assert ( + len(is_expert_mask) == logprob.shape[0] + ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + + n_usual_exp = torch.sum(~is_expert_mask).item() + n_expert_exp = torch.sum(is_expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.experience_per_gpu / ( + logprob.shape[0] * self.read_batch_size_usual + ) + per_micro_batch_weight_expert = self.experience_per_gpu / ( + logprob.shape[0] * self.read_batch_size_expert + ) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~is_expert_mask], + old_logprob[~is_expert_mask], + action_mask[~is_expert_mask], + advantages[~is_expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } + else: + grpo_loss = torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[is_expert_mask], + action_mask[is_expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item()}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "mu": 0.1, + "clip_range": 0.2, + } diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py new file mode 100644 index 0000000000..618301b319 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -0,0 +1,31 @@ +"""OPMD policy loss function.""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("opmd") +class OPMDPolicyLossFn(PolicyLossFn): + def __init__(self, backend: str = "verl", tau: float = 1.0) -> None: + super().__init__(backend=backend) + self.tau = tau + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + pg_losses = -advantages * logprob + opmd_loss = masked_mean(pg_losses, action_mask) + opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta) + return opmd_loss, {"opmd_loss": opmd_loss.detach().item()} + + @classmethod + def default_args(cls) -> Dict: + return {"tau": 1.0} diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py new file mode 100644 index 0000000000..aa6025252e --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -0,0 +1,121 @@ +import inspect +from abc import ABC, ABCMeta, abstractmethod +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.key_mapper import ALL_MAPPERS +from trinity.utils.registry import Registry + +POLICY_LOSS_FN = Registry("policy_loss_fn") + + +class PolicyLossFnMeta(ABCMeta): + """Metaclass for policy loss functions that handles parameter name mapping and filtering.""" + + ignore_keys = {"self", "kwargs", "logprob"} # Keys to exclude from parameter selection + + def __new__(cls, name, bases, dct): + """ + Metaclass constructor that automatically generates parameter handling logic. + + For example with `PPOPolicyLossFn` class: + .. code-block:: python + class PPOPolicyLossFn(PolicyLossFn): + ... + def __call__( + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + ... + + This metaclass analyzes the __call__ method's parameters to: + 1. Generate _select_keys containing all non-ignored parameters + 2. Create select_keys property that maps parameters to trainer-specific names + 3. Apply decorator to automatically convert input parameter names using the mapper + """ + signature = inspect.signature(dct["__call__"]) + param_names = [ + key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys + ] + dct["_select_keys"] = param_names + + # Property to return trainer-specific parameter names + def select_keys(self): + """Returns parameter keys mapped to the specific training framework's naming convention.""" + keys = [self.mapper.from_trinity(key) for key in self._select_keys] + return keys + + # Decorator to handle parameter name conversion before calling __call__ + def decorator(func): + def wrapper(self, *args, **kwargs): + """Filters and converts parameter names according to the training framework's convention.""" + new_kwargs = {} + for key, value in kwargs.items(): + key = self.mapper.to_trinity(key) + if key == "logprob" or key in self._select_keys: # remove unused keys + new_kwargs[key] = value + return func(self, *args, **new_kwargs) + + return wrapper + + # Add the property and decorated method to the class + dct["select_keys"] = property(select_keys) + dct["__call__"] = decorator(dct["__call__"]) + return super().__new__(cls, name, bases, dct) + + +class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta): + """ + Abstract base class for policy loss functions. + + This class provides the interface for implementing different policy gradient loss functions + while handling parameter name mapping between different training frameworks. + """ + + def __init__(self, backend: str = "verl"): + """ + Initialize the policy loss function. + + Args: + backend: The training framework/backend to use (e.g., "verl") + """ + self.backend = backend + self.mapper = ALL_MAPPERS[self.backend] + + @abstractmethod + def __call__( + self, + logprob: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + """ + Calculate the policy loss. + + Args: + logprob (`torch.Tensor`): The log probability generated by the policy model. + + Kwargs (optional): + old_logprob (`torch.Tensor`): The log probability generated by the reference model. + action_mask (`torch.Tensor`): The action mask. + advantages (`torch.Tensor`): The advantages. + kwargs (`Dict`): The step-level parameters for calculating the policy loss. + + Returns: + `torch.Tensor`: Policy loss + `Dict`: The metrics for logging. + """ + + @classmethod + @abstractmethod + def default_args(cls) -> Dict: + """ + Get default initialization arguments for this loss function. + + Returns: + `Dict`: The default init arguments for the policy loss function. + """ diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py new file mode 100644 index 0000000000..a4cc0b2d03 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -0,0 +1,65 @@ +"""PPO policy loss function. + +Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py +""" + +from typing import Dict, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("ppo") +class PPOPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + ) -> None: + super().__init__(backend=backend) + if clip_range_low is None: + self.clip_range_low = clip_range + else: + self.clip_range_low = clip_range_low + if clip_range_high is None: + self.clip_range_high = clip_range + else: + self.clip_range_high = clip_range_high + assert self.clip_range_low is not None, "clip_range_low must be specified." + assert self.clip_range_high is not None, "clip_range_high must be specified." + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + negative_approx_kl = logprob - old_logprob + ratio = torch.exp(negative_approx_kl) + ppo_kl = masked_mean(-negative_approx_kl, action_mask) + + pg_losses = -advantages * ratio + pg_losses2 = -advantages * torch.clamp( + ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore + ) + + pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask) + pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask) + metrics = { + "pg_clipfrac": pg_clipfrac.detach().item(), + "ppo_kl": ppo_kl.detach().item(), + "pg_loss": pg_loss.detach().item(), + } + return pg_loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "clip_range": 0.2, + } diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py new file mode 100644 index 0000000000..2c824f1c09 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/sft_loss.py @@ -0,0 +1,33 @@ +"""SFT loss function.""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("sft") +class SFTLossFn(PolicyLossFn): + def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None: + super().__init__(backend=backend) + self.use_token_level_loss = use_token_level_loss + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + if self.use_token_level_loss: + sft_loss = masked_mean(-logprob, action_mask) + else: + sft_loss = masked_mean(-logprob, action_mask, axis=1).mean() + return sft_loss, {"sft_loss": sft_loss.detach().item()} + + @classmethod + def default_args(cls): + return { + "use_token_level_loss": True, + } diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py new file mode 100644 index 0000000000..cd4b9e0d66 --- /dev/null +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -0,0 +1,15 @@ +from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy +from trinity.algorithm.sample_strategy.sample_strategy import ( + SAMPLE_STRATEGY, + DefaultSampleStrategy, + SampleStrategy, + WarmupSampleStrategy, +) + +__all__ = [ + "SAMPLE_STRATEGY", + "SampleStrategy", + "DefaultSampleStrategy", + "WarmupSampleStrategy", + "MixSampleStrategy", +] diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py new file mode 100644 index 0000000000..25811e9190 --- /dev/null +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -0,0 +1,118 @@ +import copy +from math import ceil +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +from verl.trainer.ppo.ray_trainer import DataProto + +from trinity.algorithm.sample_strategy.sample_strategy import ( + SAMPLE_STRATEGY, + SampleStrategy, +) +from trinity.algorithm.sample_strategy.utils import representative_sample +from trinity.buffer import get_buffer_reader +from trinity.common.config import BufferConfig +from trinity.common.experience import Experiences +from trinity.utils.timer import Timer + + +@SAMPLE_STRATEGY.register_module("mix") +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore + ) + + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) + + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config + ) + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) + + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + @classmethod + def default_args(cls) -> Dict: + return { + "expert_data_ratio": 0.5, + } + + +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + "is_expert_mask": is_expert_mask, + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py new file mode 100644 index 0000000000..6e530d32ce --- /dev/null +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + +from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto +from trinity.buffer import get_buffer_reader +from trinity.common.config import BufferConfig +from trinity.common.experience import Experiences +from trinity.utils.registry import Registry +from trinity.utils.timer import Timer + +SAMPLE_STRATEGY = Registry("sample_strategy") + + +class SampleStrategy(ABC): + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None: + self.pad_token_id = buffer_config.pad_token_id + self.trainer_type = trainer_type + + @abstractmethod + def sample(self, step: int) -> Tuple[Any, Dict, List]: + """Sample data from buffer. + + Args: + step (`int`): The step number of current step. + + Returns: + `Any`: The sampled data. + `Dict`: Metrics for logging. + `List`: Representative data for logging. + """ + + # Experimental API + @abstractmethod + def warmup_state(self, step: int) -> Tuple[bool, bool]: + """Check the warmup state of the current step. + + Args: + step (`int`): The step number of current step. + + Returns: + `bool`: Current step is in warmup or not. + `bool`: Warmup is finished on this step or not. + """ + + @classmethod + @abstractmethod + def default_args(cls) -> dict: + """Get the default arguments of the sample strategy.""" + + +@SAMPLE_STRATEGY.register_module("warmup") +class WarmupSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + ) + self.sft_warmup_steps = buffer_config.trainer_input.sft_warmup_steps + if self.sft_warmup_steps > 0 and buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0") + if buffer_config.trainer_input.sft_warmup_dataset is not None: + self.sft_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, buffer_config + ) + else: + self.sft_buffer = None + + def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + if step <= self.sft_warmup_steps: + exp_list = self.sft_buffer.read() + else: + exp_list = self.exp_buffer.read() + repr_samples = representative_sample(exp_list) + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto(exps) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + def warmup_state(self, step: int) -> Tuple[bool, bool]: + return step <= self.sft_warmup_steps, step == self.sft_warmup_steps + + @classmethod + def default_args(cls) -> dict: + return {} + + +@SAMPLE_STRATEGY.register_module("default") +class DefaultSampleStrategy(SampleStrategy): + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + ) + + def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + exp_list = self.exp_buffer.read() + repr_samples = representative_sample(exp_list) + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto(exps) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + def warmup_state(self, step: int) -> Tuple[bool, bool]: + return False, False + + @classmethod + def default_args(cls) -> dict: + return {} + + +@SAMPLE_STRATEGY.register_module("dpo") +class DPOSampleStrategy(WarmupSampleStrategy): + def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + if step <= self.sft_warmup_steps: + exp_list = self.sft_buffer.read() + else: + exp_list = self.exp_buffer.read() + repr_samples = representative_sample(exp_list) + with Timer(metrics, "gather_time"): + exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id) # type: ignore + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto(exps) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py new file mode 100644 index 0000000000..8c443a20b1 --- /dev/null +++ b/trinity/algorithm/sample_strategy/utils.py @@ -0,0 +1,78 @@ +import random +from typing import List + +import numpy as np +import torch +from verl.trainer.ppo.ray_trainer import DataProto + +from trinity.common.experience import Experience, Experiences + + +def to_data_proto(experiences: Experiences) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) + + +def representative_sample(experiences: List[Experience]) -> List[dict]: + if experiences[0].reward is None: + sample = random.choice(experiences) + return [ + { + "prompt": sample.prompt_text, + "response": sample.response_text, + } + ] + samples = [] + min_reward_sample = None + max_reward_sample = None + for exp in experiences: + if exp.reward is None: + continue + if min_reward_sample is None or exp.reward < min_reward_sample.reward: + min_reward_sample = exp + if max_reward_sample is None or exp.reward > max_reward_sample.reward: + max_reward_sample = exp + if min_reward_sample is not None: + samples.append( + { + "prompt": min_reward_sample.prompt_text, + "response": min_reward_sample.response_text, + "reward": min_reward_sample.reward, + } + ) + if max_reward_sample is not None: + samples.append( + { + "prompt": max_reward_sample.prompt_text, + "response": max_reward_sample.response_text, + "reward": max_reward_sample.reward, + } + ) + return samples diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py new file mode 100644 index 0000000000..8660a6376c --- /dev/null +++ b/trinity/algorithm/utils.py @@ -0,0 +1,62 @@ +"""Common utils for algorithm module. + +Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py +""" + +import torch + + +def masked_sum(values, mask, axis=None): + """Compute mean of tensor with a masked values.""" + return (values * mask).sum(axis=axis) + + +def masked_mean(values, mask, axis=None): + """Compute mean of tensor with a masked values.""" + return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8) + + +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError("At least one element in the mask has to be 1.") + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + if mask_sum == 1: + raise ValueError("The sum of the mask is one, which can cause a division by zero.") + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """ + Whiten `values` by normalizing with mean and variance computed over `mask`. + + Args: + values (torch.Tensor): Input tensor. + mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats. + shift_mean (bool): If True (default), output is zero-mean; + if False, the original mean is re-added after scaling. + + Returns: + torch.Tensor: Whitened tensor of same shape as `values`. + """ + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict: + if dst_metrics is None: + dst_metrics = {} + for k, v in src_metrics.items(): + dst_metrics[f"{prefix}/{k}"] = v + return dst_metrics diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 69f40a8d70..060ed05b9e 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -41,9 +41,11 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig elif storage_config.storage_type == StorageType.FILE: from trinity.buffer.reader.file_reader import FILE_READERS - file_read_type = storage_config.algorithm_type - if file_read_type is not None: - file_read_type = file_read_type.value + algorithm_type = storage_config.algorithm_type + if storage_config.raw: + file_read_type = "raw" + elif algorithm_type is not None: + file_read_type = algorithm_type else: file_read_type = "rollout" return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 63de366db6..b7cf06b2b5 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -142,6 +142,8 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: raise ValueError( f"File path must end with '.json' or '.jsonl', got {storage_config.path}" ) + path_dir = os.path.dirname(storage_config.path) + os.makedirs(path_dir, exist_ok=True) self.file = open(storage_config.path, "a", encoding="utf-8") self.encoder = _Encoder(ensure_ascii=False) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 5eec9c4464..43d83f2cf4 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -7,9 +7,10 @@ from datasets import Dataset, load_dataset from tqdm import tqdm +from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm from trinity.buffer.buffer_reader import BufferReader from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType +from trinity.common.constants import PromptType, ReadStrategy, TaskType from trinity.common.experience import Experience from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.workflows import WORKFLOWS, Task @@ -65,7 +66,7 @@ def read_batch(self, batch_size: int) -> List: return batch -@FILE_READERS.register_module(AlgorithmType.SFT.value) +@FILE_READERS.register_module(SFTAlgorithm.name()) class SFTDataReader(BufferReader): """Reader for SFT file data.""" @@ -78,9 +79,9 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.response_key = meta.format.response_key self.read_batch_size = config.read_batch_size self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split) + load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), + max_epoch=meta.total_epochs, ) # TODO: support resume - self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True) self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) def read( @@ -144,7 +145,7 @@ def read( return exp_list -@FILE_READERS.register_module(AlgorithmType.DPO.value) +@FILE_READERS.register_module(DPOAlgorithm.name()) class DPODataReader(BufferReader): def __init__(self, meta: StorageConfig, config: BufferConfig): self.split = meta.split @@ -155,9 +156,9 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.rejected_key = meta.format.rejected_key self.read_batch_size = config.read_batch_size self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split) + load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), + max_epoch=meta.total_epochs, ) # TODO: support resume - self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True) self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) def _get_assistant_message(self, item) -> dict: @@ -227,7 +228,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.epoch = 0 datasets.disable_caching() self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split), + load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True), max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1, offset=self.meta.index, ) @@ -272,3 +273,23 @@ def read( ) tasks.append(task) return tasks + + +@FILE_READERS.register_module("raw") +class RawDataReader(BufferReader): + def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]): + self.returned = False + self.dataset = load_dataset( + meta.path, name=meta.subset_name, split=meta.split, trust_remote_code=True + ) + + def __len__(self): + return len(self.dataset) + + def read( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ) -> List: + if self.returned: + raise StopIteration + self.returned = True + return self.dataset.to_list() diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index db2e4ca137..21289c7768 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -5,9 +5,7 @@ from sqlalchemy import Column, Float, Integer, LargeBinary, String from sqlalchemy.ext.declarative import declarative_base -from trinity.common.constants import AlgorithmType from trinity.common.experience import Experience -from trinity.common.models.utils import tokenize_and_mask_messages_hf Base = declarative_base() @@ -85,6 +83,8 @@ def from_messages( chat_template: Optional[str] = None, ) -> "SFTDataModel": """Convert a list of messages into a single instance of SFT data.""" + from trinity.common.models.utils import tokenize_and_mask_messages_hf + token_ids, action_mask = tokenize_and_mask_messages_hf( tokenizer=tokenizer, messages=messages, @@ -125,22 +125,15 @@ def to_experience(self) -> Experience: return exp -SCHEMA_MAPPING = { - None: TaskModel, - AlgorithmType.SFT: SFTDataModel, - AlgorithmType.PPO: ExperienceModel, - AlgorithmType.GRPO: ExperienceModel, - AlgorithmType.OPMD: ExperienceModel, - AlgorithmType.DPO: DPODataModel, -} - - -def create_dynamic_table(algorithm_type: Union[AlgorithmType | None], table_name: str) -> Any: +def create_dynamic_table(algorithm_type: Union[str | None], table_name: str) -> Any: """Create a dynamic table based on the provided algorithm type and table name.""" - if algorithm_type not in SCHEMA_MAPPING: - raise ValueError(f"Unknown schema: {algorithm_type}") + if algorithm_type is None: + base_class = TaskModel + else: + from trinity.algorithm.algorithm import ALGORITHM_TYPE - base_class = SCHEMA_MAPPING[algorithm_type] + algorithm = ALGORITHM_TYPE.get(algorithm_type) + base_class = algorithm.schema table_attrs = { "__tablename__": table_name, diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index e251792f7c..8864dc9b82 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -14,8 +14,6 @@ class SQLWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL # we only support write RFT algorithm buffer for now - # TODO: support other algorithms - assert meta.algorithm_type.is_rft(), "Only RFT buffer is supported for writing." self.wrap_in_ray = meta.wrap_in_ray self.db_wrapper = DBWrapper.get_wrapper(meta, config) diff --git a/trinity/cli/client.py b/trinity/cli/client.py index 311de1b9d8..cc3318b570 100644 --- a/trinity/cli/client.py +++ b/trinity/cli/client.py @@ -31,12 +31,12 @@ def request(url, **kwargs): if __name__ == "__main__": # --- only for local testing - LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow" + LOCAL_DATA_PROCESSOR_SERVER_URL = "http://127.0.0.1:5005/data_processor" LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_rft" # --- only for local testing res = request( - url=LOCAL_DATA_WORKFLOW_SERVER_URL, + url=LOCAL_DATA_PROCESSOR_SERVER_URL, configPath="examples/grpo_gsm8k/gsm8k.yaml", ) if res: diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index d3156ebd6f..3ea4f0486f 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -2,13 +2,14 @@ import argparse import os import sys +import traceback from pathlib import Path from pprint import pprint import ray -from trinity.common.config import Config, load_config -from trinity.common.constants import AlgorithmType +from trinity.common.config import Config, DataPipelineConfig, load_config +from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger @@ -19,56 +20,41 @@ def bench(config: Config) -> None: """Evaluate model.""" - explorer = Explorer.remote(config) + explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config) try: ray.get(explorer.prepare.remote()) ray.get(explorer.benchmark.remote()) logger.info("Benchmark finished.") ray.get(explorer.shutdown.remote()) - except Exception as e: - logger.error(f"Benchmark failed: {e}") - raise e + except Exception: + error_msg = traceback.format_exc() + logger.error(f"Benchmark failed:\n{error_msg}") def explore(config: Config) -> None: """Run explorer.""" - explorer = Explorer.remote(config) try: + explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config) ray.get(explorer.prepare.remote()) ray.get(explorer.sync_weight.remote()) ray.get(explorer.explore.remote()) - logger.info("Explore finished.") ray.get(explorer.shutdown.remote()) - except Exception as e: - logger.error(f"Explore failed: {e}") - raise e + except Exception: + error_msg = traceback.format_exc() + logger.error(f"Explorer failed:\n{error_msg}") def train(config: Config) -> None: """Run trainer.""" - - trainer = Trainer.remote(config) - ray.get(trainer.prepare.remote()) - - if config.buffer.trainer_input.sft_warmup_steps > 0: - while True: - train_continue, train_step_num = ray.get( - trainer.train_one_period.remote(AlgorithmType.SFT) - ) - if train_step_num <= config.buffer.trainer_input.sft_warmup_steps: - logger.info(f"SFT warmup step {train_step_num} finished.") - if not train_continue: - logger.info("SFT warmup finished.") - break - - algo_type = config.algorithm.algorithm_type try: - ray.get(trainer.train.remote(algo_type)) - logger.info("Train finished.") + trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config) + ray.get(trainer.prepare.remote()) + ray.get(trainer.sync_weight.remote()) + ray.get(trainer.train.remote()) ray.get(trainer.shutdown.remote()) - except Exception as e: - logger.error(f"Train failed {e}.") - raise e + except Exception: + error_msg = traceback.format_exc() + logger.error(f"Trainer failed:\n{error_msg}") def both(config: Config) -> None: @@ -81,76 +67,58 @@ def both(config: Config) -> None: the latest step. The specific number of experiences may vary for different algorithms and tasks. """ - explorer = Explorer.remote(config) - trainer = Trainer.remote(config) + explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config) + trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config) ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()]) - logger.info("Setup explorer and trainer finished.") ray.get( [ explorer.prepare.remote(), trainer.prepare.remote(), ] ) - # sync weight before training start - ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - - if config.buffer.trainer_input.sft_warmup_steps > 0: - while True: - train_continue, train_step_num = ray.get( - trainer.train_one_period.remote(AlgorithmType.SFT) - ) - if train_step_num <= config.buffer.trainer_input.sft_warmup_steps: - logger.info(f"SFT warmup step {train_step_num} finished.") - if not train_continue: - logger.info("SFT warmup finished.") - break - ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - - algo_type = config.algorithm.algorithm_type - while True: - try: - ref_explore = explorer.explore_one_period.remote() - ref_train = trainer.train_one_period.remote(algo_type) - explore_continue, explore_step_num = ray.get(ref_explore) - train_continue, train_step_num = ray.get(ref_train) - if not explore_continue: - # If explore finished, the trainer may not have enough experiences to continue, - # which will cause the trainer be blocked. So we stop the training process - # immediately. - # TODO: use a more elegant way to stop the training process. - logger.info("Explorer finished, stopping...") - break - if not train_continue: - logger.info("Trainer finished, stopping...") - break - ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - logger.info("Model weight synchronized.") - except Exception as e: - logger.error(e) - logger.error("Training stopped due to exception.") - raise e - if explore_step_num % config.explorer.eval_interval == 0: - try: - ray.get(explorer.eval.remote()) - logger.info("Evaluation finished.") - except Exception as e: - logger.error(e) - logger.error("Evaluation failed.") - raise e - ray.get(explorer.flush_log.remote(step=explore_step_num)) - ray.get(trainer.flush_log.remote(step=train_step_num)) - - ray.get(explorer.shutdown.remote()) - ray.get(trainer.shutdown.remote()) - - -def activate_data_module(data_workflow_url: str, config_path: str): + ray.get( + [ + explorer.sync_weight.remote(), + trainer.sync_weight.remote(), + ] + ) + ready_ref, wait_ref = ray.wait( + [ + explorer.explore.remote(), + trainer.train.remote(), + ], + num_returns=1, + ) + + ready = ray.get(ready_ref[0]) + if ready == TRAINER_NAME: + logger.info( + "===========================================================\n" + "> Launcher detected that the `Trainer` process has finished.\n" + "> Stopping the explorer process immediately.\n" + "===========================================================" + ) + ray.wait(wait_ref, timeout=5) + elif ready == EXPLORER_NAME: + logger.info( + "============================================================\n" + "> Launcher detected that the `Explorer` process has finished.\n" + f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n" + "> You can force stop the Trainer process by pressing Ctrl+C.\n" + "============================================================" + ) + ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout) + explorer.shutdown.remote() + trainer.shutdown.remote() + + +def activate_data_module(data_processor_url: str, config_path: str): """Check whether to activate data module and preprocess datasets.""" from trinity.cli.client import request - logger.info("Activating data module...") + logger.info(f"Activating data module of {data_processor_url}...") res = request( - url=data_workflow_url, + url=data_processor_url, configPath=config_path, ) if res["return_code"] != 0: @@ -158,17 +126,71 @@ def activate_data_module(data_workflow_url: str, config_path: str): return +def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str): + """ + Check if the data pipeline is valid. The config should: + 1. Non-empty input buffer + 2. Different input/output buffers + + :param data_pipeline_config: the input data pipeline to be validated. + :param pipeline_type: the type of pipeline, should be one of ["task", "experience"] + """ + input_buffers = data_pipeline_config.input_buffers + output_buffer = data_pipeline_config.output_buffer + # common checks + # check if the input buffer list is empty + if len(input_buffers) == 0: + logger.warning("Empty input buffers in the data pipeline. Won't activate it.") + return False + # check if the input and output buffers are different + input_buffer_names = [buffer.name for buffer in input_buffers] + if output_buffer.name in input_buffer_names: + logger.warning("Output buffer exists in input buffers. Won't activate it.") + return False + if pipeline_type == "task": + # task pipeline specific + # "raw" field should be True for task pipeline because the data source must be raw data files + for buffer in input_buffers: + if not buffer.raw: + logger.warning( + 'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.' + ) + return False + elif pipeline_type == "experience": + # experience pipeline specific + raise NotImplementedError("experience_pipeline is not implemented yet.") + else: + logger.warning( + f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].' + ) + return False + return True + + def run(config_path: str, dlc: bool = False, plugin_dir: str = None): load_plugins(plugin_dir) config = load_config(config_path) config.check_and_update() pprint(config) - # try to activate data module + # try to activate task pipeline for raw data data_processor_config = config.data_processor - if data_processor_config.data_workflow_url and ( - data_processor_config.dj_config_path or data_processor_config.dj_process_desc + if ( + data_processor_config.data_processor_url + and data_processor_config.task_pipeline + and validate_data_pipeline(data_processor_config.task_pipeline, "task") + ): + activate_data_module( + f"{data_processor_config.data_processor_url}/task_pipeline", config_path + ) + # try to activate experience pipeline for experiences + if ( + data_processor_config.data_processor_url + and data_processor_config.experience_pipeline + and validate_data_pipeline(data_processor_config.experience_pipeline, "experience") ): - activate_data_module(data_processor_config.data_workflow_url, config_path) + activate_data_module( + f"{data_processor_config.data_processor_url}/experience_pipeline", config_path + ) ray_namespace = f"{config.project}-{config.name}" if dlc: from trinity.utils.dlc_utils import setup_ray_cluster diff --git a/trinity/common/config.py b/trinity/common/config.py index 42c54c442a..f4480da311 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -7,7 +7,6 @@ from omegaconf import OmegaConf from trinity.common.constants import ( - AlgorithmType, PromptType, ReadStrategy, StorageType, @@ -71,6 +70,9 @@ class StorageConfig: storage_type: StorageType = StorageType.FILE path: Optional[str] = None + # only available for StorageType.FILE. When requiring data processing on raw data, set the raw to True. + raw: bool = False + # used for StorageType.FILE split: str = "train" subset_name: Optional[str] = None @@ -90,7 +92,7 @@ class StorageConfig: workflow_args: dict = field(default_factory=dict) # ! DO NOT SET, automatically set from algorithm.algorithm_type - algorithm_type: Optional[AlgorithmType] = None + algorithm_type: Optional[str] = None # ! DO NOT SET, automatically set from buffer.total_epochs total_epochs: int = 1 # automatically set @@ -100,16 +102,17 @@ class StorageConfig: @dataclass -class DataProcessorConfig: - """Data-Juicer config""" +class DataPipelineConfig: + """Config for data pipeline.""" - data_workflow_url: Optional[str] = None + # I/O buffer + input_buffers: List[StorageConfig] = field(default_factory=list) + output_buffer: StorageConfig = field(default_factory=StorageConfig) - source_data_path: str = "" + # data format format: FormatConfig = field(default_factory=FormatConfig) # data active iterator related - load_kwargs: Dict[str, Any] = field(default_factory=dict) dj_config_path: Optional[str] = None # The path to Data-Juicer config file. dj_process_desc: Optional[ str @@ -122,10 +125,18 @@ class DataProcessorConfig: priority_weights: Optional[Dict[str, float]] = None data_dist: Optional[str] = "gaussian" # one of ["gaussian", "uniform"] - # dataset database related - db_url: str = "" - max_retry_times: int = 3 - max_retry_interval: int = 1 + +@dataclass +class DataProcessorConfig: + """Data-Juicer config""" + + data_processor_url: Optional[str] = None + + # support two types of data pipelines for now + # 1. For task. Data preprocessing from raw dataset to the task set + task_pipeline: Optional[DataPipelineConfig] = None + # 2. For experience. Data processing for rollouts + experience_pipeline: Optional[DataPipelineConfig] = None @dataclass @@ -176,12 +187,36 @@ class InferenceModelConfig: class AlgorithmConfig: """Config for algorithm.""" - algorithm_type: AlgorithmType = AlgorithmType.PPO + algorithm_type: str = "ppo" # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 - gamma: Optional[float] = None - lam: Optional[float] = None - # TODO: add more algorithm params here + + sample_strategy: Optional[str] = None + sample_strategy_args: Optional[dict] = None + + advantage_fn: Optional[str] = None # "ppo" + # If not set, use AdvantageFn.default_args() + advantage_fn_args: Optional[dict] = None + + kl_penalty_fn: Optional[str] = None # "none" # set to "none" to disable kl penalty in reward + # If not set, use kl_penalty_fn.default_args() + kl_penalty_fn_args: Optional[dict] = None + + policy_loss_fn: Optional[str] = None # "ppo" + # If not set, use PolicyLossFn.default_args() + policy_loss_fn_args: Optional[dict] = None + + kl_loss_fn: Optional[str] = None # "k2" # set to "none" to disable kl loss + # If not set, use kl_loss_fn.default_args() + kl_loss_fn_args: Optional[dict] = None + + entropy_loss_fn: Optional[str] = None # "default" + # If not set, use entropy_loss_fn.default_args() + entropy_loss_fn_args: Optional[dict] = None + + # used for SFT warmup + # TODO: move this to SFT warmup + use_token_level_loss: bool = True @dataclass @@ -270,11 +305,7 @@ class TrainerConfig: enable_preview: bool = True # enable rollout preview in wandb # trainer configs - actor_use_kl_loss: Optional[bool] = None - actor_kl_loss_coef: Optional[float] = None - actor_entropy_coef: Optional[float] = None actor_grad_clip: Optional[float] = None - actor_clip_ratio: Optional[float] = None # TODO: extract more train-related params from underlying trainer engine # Only one needs to be set for `trainer_config` and `trainer_config_path` @@ -284,7 +315,7 @@ class TrainerConfig: @dataclass class MonitorConfig: - # TODO: support multiple monitors (List[MonitorType]) + # TODO: support multiple monitors (List[str]) monitor_type: str = "tensorboard" # the default args for monitor monitor_args: Dict = field(default_factory=dict) @@ -300,8 +331,10 @@ class SynchronizerConfig: sync_method: SyncMethod = SyncMethod.NCCL # sync weights every `sync_interval` steps sync_interval: int = 1 + # allow explorer to run `sync_offset` steps before sync + sync_offset: int = 0 # waiting for `sync_timeout` seconds before timeout in `nccl` method - sync_timeout: int = 1200 + sync_timeout: int = 1800 # wait for the lastest checkpoint to be ready # TODO: to be used wait_for_checkpoint: bool = False @@ -342,32 +375,25 @@ def _check_deprecated(self) -> None: def _check_interval(self) -> None: assert self.synchronizer.sync_interval > 0 - # check eval_interval - if ( - self.mode != "bench" - and self.algorithm.algorithm_type != AlgorithmType.DPO - and self.explorer.eval_interval % self.synchronizer.sync_interval != 0 - ): - self.explorer.eval_interval = ( - max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1) - ) * self.synchronizer.sync_interval - logger.warning( - f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." - ) - - # check save_interval - if ( - self.mode != "bench" - and self.algorithm.algorithm_type != AlgorithmType.DPO - and self.synchronizer.sync_method == SyncMethod.CHECKPOINT - ): - if self.trainer.save_interval != self.synchronizer.sync_interval: + if self.mode != "bench" and self.algorithm.algorithm_type != "dpo": # TODO + # check eval_interval + if self.explorer.eval_interval % self.synchronizer.sync_interval != 0: + self.explorer.eval_interval = ( + max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1) + ) * self.synchronizer.sync_interval logger.warning( - f"When `algorithm.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, " - f"`trainer.save_interval` will be set to " - f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." + f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." ) - self.trainer.save_interval = self.synchronizer.sync_interval + + # check save_interval + if self.synchronizer.sync_method == SyncMethod.CHECKPOINT: + if self.trainer.save_interval != self.synchronizer.sync_interval: + logger.warning( + f"When `algorithm.algorithm_type` != `dpo` and `synchronizer.sync_method` == `checkpoint`, " + f"`trainer.save_interval` will be set to " + f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." + ) + self.trainer.save_interval = self.synchronizer.sync_interval def _check_buffer(self) -> None: # noqa: C901 # TODO: split this function into different buffer read/writer @@ -437,14 +463,7 @@ def _check_buffer(self) -> None: # noqa: C901 ) self.buffer.trainer_input.experience_buffer.storage_type = StorageType.QUEUE elif self.mode == "train": # TODO: to be check - if self.algorithm.algorithm_type.is_dpo(): - if ( - self.buffer.trainer_input.experience_buffer is None - or not self.buffer.trainer_input.experience_buffer.path - ): - raise ValueError( - "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == AlgorithmType.DPO`" - ) + pass if self.buffer.trainer_input.experience_buffer is not None: self.buffer.trainer_input.experience_buffer.algorithm_type = ( self.algorithm.algorithm_type @@ -462,10 +481,10 @@ def _check_buffer(self) -> None: # noqa: C901 and self.buffer.trainer_input.sft_warmup_dataset is None ): raise ValueError( - "buffer.trainer_input.sft_warmup_dataset is required when buffer.trainer_input.sft_warmup_steps > 0" + "`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0" ) if self.buffer.trainer_input.sft_warmup_dataset is not None: - self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT + self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO # set read_batch_size / pad_token_id / tokenizer_path self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times @@ -490,15 +509,78 @@ def _check_buffer(self) -> None: # noqa: C901 f"your checkpoint directory: {self.checkpoint_job_dir}" ) + def _check_algorithm(self) -> None: + from trinity.algorithm import ( + ADVANTAGE_FN, + ENTROPY_LOSS_FN, + KL_FN, + POLICY_LOSS_FN, + SAMPLE_STRATEGY, + ) + from trinity.algorithm.algorithm import ALGORITHM_TYPE + + algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type) + algorithm.check_config(self) + default_config = { + "sample_strategy": "warmup", + "policy_loss_fn": "ppo", + "advantage_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } + default_config.update(algorithm.default_config()) + for key, value in default_config.items(): + if getattr(self.algorithm, key, None) is None: + setattr(self.algorithm, key, value) + + # TODO: simplify the following code + sample_strategy_cls = SAMPLE_STRATEGY.get(self.algorithm.sample_strategy) + if sample_strategy_cls is None: + raise ValueError(f"Invalid sample_strategy: {self.algorithm.sample_strategy}") + if self.algorithm.sample_strategy_args is None: + self.algorithm.sample_strategy_args = sample_strategy_cls.default_args() + + policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn) + if policy_fn_cls is None: + raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}") + if self.algorithm.policy_loss_fn_args is None: + self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args() + + advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn) + if advantage_fn_cls is None: + raise ValueError(f"Invalid advantage_fn: {self.algorithm.advantage_fn}") + if self.algorithm.advantage_fn_args is None: + self.algorithm.advantage_fn_args = advantage_fn_cls.default_args() + + kl_loss_fn_cls = KL_FN.get(self.algorithm.kl_loss_fn) + if kl_loss_fn_cls is None: + raise ValueError(f"Invalid kl_loss_fn: {self.algorithm.kl_loss_fn}") + if self.algorithm.kl_loss_fn_args is None: + self.algorithm.kl_loss_fn_args = kl_loss_fn_cls.default_args() + + kl_penalty_fn_cls = KL_FN.get(self.algorithm.kl_penalty_fn) + if kl_penalty_fn_cls is None: + raise ValueError(f"Invalid kl_penalty_fn: {self.algorithm.kl_penalty_fn}") + if self.algorithm.kl_penalty_fn_args is None: + self.algorithm.kl_penalty_fn_args = kl_penalty_fn_cls.default_args() + + entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.algorithm.entropy_loss_fn) + if entropy_loss_fn_cls is None: + raise ValueError(f"Invalid entropy_loss_fn: {self.algorithm.entropy_loss_fn}") + if self.algorithm.entropy_loss_fn_args is None: + self.algorithm.entropy_loss_fn_args = entropy_loss_fn_cls.default_args() + def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" self._check_deprecated() + # check algorithm + self._check_algorithm() + # check mode if self.mode not in ["explore", "train", "both", "bench"]: raise ValueError(f"Invalid mode: {self.mode}") - if self.algorithm.algorithm_type == AlgorithmType.DPO and self.mode == "both": - raise ValueError("DPO does not support `both` mode") # prepare for the checkpoint directory if not os.path.isabs(self.checkpoint_root_dir): @@ -537,17 +619,6 @@ def check_and_update(self) -> None: # noqa: C901 logger.warning( f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." ) - if ( - self.algorithm.algorithm_type == AlgorithmType.DPO - and self.synchronizer.sync_method != SyncMethod.CHECKPOINT - ): - self.synchronizer.sync_method = SyncMethod.CHECKPOINT - logger.warning( - "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." - ) - if self.algorithm.algorithm_type == AlgorithmType.DPO and self.algorithm.repeat_times != 2: - self.algorithm.repeat_times = 2 - logger.warning("DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2.") self._check_interval() diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 5eaa3b58ff..9a428131fe 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -8,6 +8,9 @@ # names +EXPLORER_NAME = "explorer" +TRAINER_NAME = "trainer" + ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync" @@ -65,34 +68,6 @@ class StorageType(CaseInsensitiveEnum): FILE = "file" -class AlgorithmType(CaseInsensitiveEnum): - """Algorithm Type.""" - - SFT = "sft" - PPO = "ppo" - GRPO = "grpo" - OPMD = "opmd" - PAIRWISE_OPMD = "pairwise_opmd" - DPO = "dpo" - - def is_rft(self) -> bool: - """Check if the algorithm is RFT.""" - return self in [ - AlgorithmType.PPO, - AlgorithmType.GRPO, - AlgorithmType.OPMD, - AlgorithmType.PAIRWISE_OPMD, - ] - - def is_sft(self) -> bool: - """Check if the algorithm is SFT.""" - return self == AlgorithmType.SFT - - def is_dpo(self) -> bool: - """Check if the algorithm is DPO.""" - return self == AlgorithmType.DPO - - class MonitorType(CaseInsensitiveEnum): """Monitor Type.""" @@ -120,3 +95,11 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta): NCCL = "nccl" CHECKPOINT = "checkpoint" + + +class RunningStatus(Enum): + """Running status of explorer and trainer.""" + + RUNNING = "running" + WAITING_SYNC = "waiting_sync" + STOPPED = "stopped" diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 8b80a71bd6..fd5670b390 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -103,6 +103,7 @@ def create_inference_models( num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, + placement_group_capture_child_tasks=True, placement_group_bundle_index=bundles_for_engine[0], ), ) @@ -121,6 +122,7 @@ def create_inference_models( bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) model_config.enable_openai_api = True model_config.engine_type = "vllm_async" + model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) engines.append( ray.remote(vLLMAysncRolloutModel) .options( @@ -128,6 +130,7 @@ def create_inference_models( num_gpus=0 if model_config.tensor_parallel_size > 1 else 1, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, + placement_group_capture_child_tasks=True, placement_group_bundle_index=bundles_for_engine[0], ), ) diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index a8751e7240..5cc770e64f 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -156,7 +156,6 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None iteration = f.read().strip() return os.path.join(checkpoint_path, f"global_step_{iteration}") else: - logger.error(f"No iteration file found in {checkpoint_path}") raise FileNotFoundError(f"No iteration file found in {checkpoint_path}") else: # load specific iteration checkpoint diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 27faa4c44a..8a8a089afa 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -100,6 +100,7 @@ def __init__( self.action_mask_method = tokenize_and_mask_messages_default else: self.action_mask_method = tokenize_and_mask_messages_hf + self.state_dict_meta = None self.ckp_version = 0 # TODO: resume the value from the checkpoint self.api_server_host = None self.api_server_port = None @@ -264,10 +265,11 @@ async def _collective_rpc( method, timeout, args, kwargs ) - async def sync_model(self, update_weight_args_list) -> bool: + async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool: """Sync model weights to vLLM.""" - for args in update_weight_args_list: - await self._collective_rpc("update_weight", args=args) + if update_weight_args_list is not None: + await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) + await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.ckp_version += 1 return True @@ -282,6 +284,7 @@ async def init_process_group( backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, + state_dict_meta: dict = None, ): return await self._collective_rpc( "init_process_group", @@ -294,12 +297,10 @@ async def init_process_group( backend, timeout, update_with_checkpoint, + state_dict_meta, ), ) - async def update_weight(self, name, dtype, shape, empty_cache=False): - return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) - async def run_api_server(self): """Run the OpenAI API server in a Ray actor. diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 9459cd7511..878fe0bd9c 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -8,7 +8,7 @@ import os import re import threading -from typing import List +from typing import List, Optional, Tuple import torch import vllm @@ -85,6 +85,7 @@ def __init__(self, config: InferenceModelConfig): else: self.action_mask_method = tokenize_and_mask_messages_hf self.lock = threading.Lock() + self.state_dict_meta = None self.ckp_version = 0 # TODO: resume the value from the checkpoint def init_process_group( @@ -97,6 +98,7 @@ def init_process_group( backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, + state_dict_meta: dict = None, ): return self.llm.collective_rpc( "init_process_group", @@ -109,12 +111,10 @@ def init_process_group( backend, timeout, update_with_checkpoint, + state_dict_meta, ), ) - def update_weight(self, name, dtype, shape, empty_cache=False): - return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) - def reset_prefix_cache(self): self.llm.llm_engine.reset_prefix_cache() @@ -274,11 +274,11 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience: def has_api_server(self) -> bool: return False - def sync_model(self, update_weight_args_list) -> bool: + def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool: """Sync model weights to vLLM.""" - with self.lock: - for args in update_weight_args_list: - self.llm.collective_rpc("update_weight", args=args) + if update_weight_args_list is not None: + self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) + self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.ckp_version += 1 return True diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 4a32628a96..883e470381 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -4,6 +4,7 @@ import torch import torch.distributed +from trinity.common.constants import EXPLORER_NAME from trinity.utils.distributed import init_process_group, is_ipv6_address from trinity.utils.log import get_logger @@ -21,22 +22,21 @@ def init_process_group( backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, + state_dict_meta: list = None, ): """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" + self.set_state_dict_meta(state_dict_meta) self._update_with_checkpoint = update_with_checkpoint - if self._update_with_checkpoint: - logger.info( - f"init_process_group (checkpoint): address={master_address}:{master_port}, rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" - ) - self._weight_update_rank = torch.distributed.get_rank() + rank_offset - else: - logger.info( - f"init_process_group (nccl): rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" - ) - self._weight_update_rank = torch.distributed.get_rank() + rank_offset - + self._weight_update_rank = torch.distributed.get_rank() + rank_offset + logger.info( + f"vLLM starting init_process_group ({'checkpoint' if self._update_with_checkpoint else 'nccl'}):\n" + f" > address={master_address}:{master_port}\n" + f" > rank={torch.distributed.get_rank()}\n" + f" > rank_offset={rank_offset}\n" + f" > world_size={world_size}" + ) if is_ipv6_address(master_address): # using tcp://ipv6:port will lead to ValueError init_method = f"tcp://[{master_address}]:{master_port}" @@ -51,24 +51,28 @@ def init_process_group( rank=self._weight_update_rank, group_name=group_name, ) - logger.info( - f"init_process_group: master_address={master_address}, master_port={master_port}, " - f"rank={self._weight_update_rank}, world_size={world_size}, group_name={group_name}" - ) + logger.info("vLLM init_process_group finished.") self._explorer_actor = None - def update_weight(self, name, dtype, shape, empty_cache=False): - """Broadcast weight to all vllm workers from source rank 0 (actor model)""" - if self._weight_update_rank == 0: - if self._explorer_actor is None: - self._explorer_actor = ray.get_actor(name="explorer") - weight = ray.get(self._explorer_actor.get_weight.remote(name)) - weight = weight.to(self.device) - else: - weight = torch.empty(shape, dtype=dtype, device="cuda") + def set_state_dict_meta(self, state_dict_meta): + self._state_dict_meta = state_dict_meta - torch.distributed.broadcast(weight, 0, group=self._model_update_group) - weight = weight.type(self.model_config.dtype) - - self.model_runner.model.load_weights(weights=[(name, weight)]) - del weight + def update_weight(self): + """Broadcast weight to all vllm workers from source rank 0 (actor model)""" + assert self._state_dict_meta is not None + if self._explorer_actor is None: + self._explorer_actor = ray.get_actor(name=EXPLORER_NAME) + for name, dtype_str, shape in self._state_dict_meta: + if self._weight_update_rank == 0: + weight = ray.get(self._explorer_actor.get_weight.remote(name)) + weight = weight.to(self.device) + else: + dtype = getattr(torch, dtype_str.split(".")[-1]) + weight = torch.empty(shape, dtype=dtype, device=self.device) + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + weight = weight.type(self.model_config.dtype) + self.model_runner.model.load_weights(weights=[(name, weight)]) + del weight + torch.distributed.barrier() + torch.cuda.synchronize() + torch.cuda.empty_cache() diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index e5d0d9d55f..1ec0653503 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -5,8 +5,6 @@ from omegaconf import OmegaConf from trinity.common.config import BufferConfig, Config, SynchronizerConfig -from trinity.common.constants import AlgorithmType -from trinity.trainer.verl.ray_trainer import AdvantageEstimator from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -34,8 +32,7 @@ class Optim: min_lr_ratio: Optional[float] = 0.0 warmup_style: str = "constant" total_training_steps: int = -1 - beta1: float = 0.9 - beta2: float = 0.999 + betas: List[float] = field(default_factory=lambda: [0.9, 0.999]) @dataclass @@ -68,21 +65,19 @@ class Actor: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} ) grad_clip: float = 1.0 - clip_ratio: float = 0.2 - entropy_coeff: float = 0.001 - use_kl_loss: bool = False - kl_loss_coef: float = 0.001 - kl_loss_type: str = "low_var_kl" ppo_epochs: int = 1 shuffle: bool = False ulysses_sequence_parallel_size: int = 1 checkpoint: Checkpoint = field(default_factory=Checkpoint) optim: Optim = field(default_factory=Optim) fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) - algorithm_type: AlgorithmType = AlgorithmType.PPO - tau: float = 0.001 # strength of regularization w.r.t. old / ref policy - opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd - use_uid: bool = False # True / False, applicable to pairwise_opmd + # do not set + loss_agg_mode: str = "token-mean" + clip_ratio: float = 0.2 + entropy_coeff: float = 0.001 + use_kl_loss: bool = False + kl_loss_coef: float = 0.001 + kl_loss_type: str = "low_var_kl" @dataclass @@ -95,10 +90,25 @@ class Ref: ulysses_sequence_parallel_size: int = 1 +@dataclass +class _ValKwargs: + do_sample: bool = False + + +@dataclass +class _MultiTurn: + enable: bool = False + + @dataclass class Rollout: + # do not set + val_kwargs: _ValKwargs = field(default_factory=_ValKwargs) + multi_turn: _MultiTurn = field(default_factory=_MultiTurn) temperature: float = 1.0 n: int = 1 # > 1 for grpo + log_prob_micro_batch_size: Optional[int] = None + log_prob_micro_batch_size_per_gpu: int = 1 @dataclass @@ -142,6 +152,7 @@ class Critic: cliprange_value: float = 0.0 checkpoint: Checkpoint = field(default_factory=Checkpoint) rollout_n: int = 1 + loss_agg_mode: str = "token-mean" @dataclass @@ -182,6 +193,9 @@ class KL_Ctrl: @dataclass class Algorithm: + # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl, + # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args + # if they are really needed (e.g., for GAE advantage/returns computation) gamma: float = 1.0 lam: float = 1.0 adv_estimator: str = "gae" @@ -299,37 +313,21 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.critic.ppo_mini_batch_size = config.buffer.batch_size self.critic.rollout_n = self.actor_rollout_ref.rollout.n - if config.trainer.actor_use_kl_loss is not None: - self.actor_rollout_ref.actor.use_kl_loss = config.trainer.actor_use_kl_loss - if config.trainer.actor_kl_loss_coef is not None: - self.actor_rollout_ref.actor.kl_loss_coef = config.trainer.actor_kl_loss_coef - if config.trainer.actor_entropy_coef is not None: - self.actor_rollout_ref.actor.entropy_coeff = config.trainer.actor_entropy_coef if config.trainer.actor_grad_clip is not None: self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip - if config.trainer.actor_clip_ratio is not None: - self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio # Algorithm related config - if config.algorithm.gamma is not None: - self.algorithm.gamma = config.algorithm.gamma - if config.algorithm.lam is not None: - self.algorithm.lam = config.algorithm.lam - self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type - if config.algorithm.algorithm_type == AlgorithmType.PPO: - logger.info("Using GAE `adv_estimator` for PPO") - self.algorithm.adv_estimator = AdvantageEstimator.GAE.value - elif config.algorithm.algorithm_type == AlgorithmType.GRPO: - logger.info("Using GRPO `adv_estimator` for GRPO") - self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value - - if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO - if not self.actor_rollout_ref.actor.use_kl_loss: - self.actor_rollout_ref.actor.use_kl_loss = True - logger.warning("DPO must use KL loss.") + self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none" + self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none" + # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to + # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper). + # Need to double check whether this is indeed the case, + # and see if adv_estimator can be removed completely. + + if config.algorithm.algorithm_type == "dpo": # for DPO logger.warning("DPO micro batch size is doubled for computing loss.") - self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 # type: ignore - self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 # type: ignore + self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 + self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 if self.actor_rollout_ref.rollout.n != 2: self.actor_rollout_ref.rollout.n = 2 # TODO: check other fields diff --git a/trinity/data/controllers/active_iterator.py b/trinity/data/controllers/active_iterator.py index 40da73384b..963a1015a9 100644 --- a/trinity/data/controllers/active_iterator.py +++ b/trinity/data/controllers/active_iterator.py @@ -1,14 +1,14 @@ import os import traceback +from numbers import Number from typing import Any, Dict, List import ray -from trinity.common.config import Config +from trinity.common.config import BufferConfig, DataPipelineConfig from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS from trinity.data.controllers.task_parser import DataTaskParser from trinity.data.core.dataset import RftDataset -from trinity.data.core.dataset_db import RftDatasetDB from trinity.data.processors.cleaner import DataCleaner from trinity.data.processors.human_annotator import DataHumanAnnotator from trinity.data.processors.synthesizer import DataSynthesizer @@ -21,42 +21,39 @@ class DataActiveIterator: def __init__( self, - config: Config, + config: DataPipelineConfig, + buffer_config: BufferConfig, ): self.config = config - self.data_config = config.data - if ( - self.data_config.agent_model_name is not None - and self.data_config.agent_model_config is not None - ): + self.buffer_config = buffer_config + if self.config.agent_model_name is not None and self.config.agent_model_config is not None: # get the api key api_key = os.environ.get("OPENAI_API_KEY") # initialize the agent import agentscope from agentscope.models import DashScopeChatWrapper - agentscope.init(model_configs=[self.data_config.agent_model_config]) + agentscope.init(model_configs=[self.config.agent_model_config]) self.llm_agent = DashScopeChatWrapper( config_name="_", - model_name=self.data_config.agent_model_name, + model_name=self.config.agent_model_name, api_key=api_key, stream=False, ) else: self.llm_agent = None self.task_parser = DataTaskParser(config, self.llm_agent) - self.dsdb = RftDatasetDB(self.data_config) # Priority weights # larger positive values means larger scores --> higher priority # smaller negative values means lower scores --> higher priority - self.priority_weights = self.data_config.priority_weights or { + self.priority_weights = self.config.priority_weights or { "difficulty": -0.7, "diversity": 0.8, "usage_frequency": -0.5, "quality": 1.0, } - self.min_priority_score = self.data_config.min_priority_score + self.min_priority_score = self.config.min_priority_score # Statistics tracking self.state = {"iterations": 0, "samples_selected": 0, "avg_priority_score": 0.0} @@ -67,17 +64,17 @@ def __init__( # 2. input_keys: [prompt_key, response_key] if they are available # 3. field_names: [prompt_key, response_key] if they are available self.updated_op_args = { - "text_key": self.data_config.format.prompt_key, + "text_key": self.config.format.prompt_key, "input_keys": [ - self.data_config.format.prompt_key, + self.config.format.prompt_key, ], "field_names": [ - self.data_config.format.prompt_key, + self.config.format.prompt_key, ], } - if self.data_config.format.response_key != "": - self.updated_op_args["input_keys"].append(self.data_config.format.response_key) - self.updated_op_args["field_names"].append(self.data_config.format.response_key) + if self.config.format.response_key != "": + self.updated_op_args["input_keys"].append(self.config.format.response_key) + self.updated_op_args["field_names"].append(self.config.format.response_key) # flake8: noqa: C901 def run(self): @@ -94,9 +91,9 @@ def run(self): traceback.print_exc() return 1, "config parsing failed." - # step 2. load dataset + # step 2. load data from the input buffers try: - dataset = RftDataset(self.data_config) + dataset = RftDataset(self.config, self.buffer_config) except Exception: traceback.print_exc() return 2, "RftDataset loading failed." @@ -106,9 +103,9 @@ def run(self): if hit_cleaner: cleaner = DataCleaner( dj_config, - clean_strategy=self.data_config.clean_strategy, - min_size_ratio=self.data_config.min_size_ratio, - data_dist=self.data_config.data_dist, + clean_strategy=self.config.clean_strategy, + min_size_ratio=self.config.min_size_ratio, + data_dist=self.config.data_dist, ) if hit_synthesizer: synthesizer = DataSynthesizer( @@ -122,43 +119,61 @@ def run(self): traceback.print_exc() return 3, "DataCleaner loading failed." - # step 4. apply processors to calculate scores of different dimensions - try: - res_dataset = dataset - if hit_cleaner: - res_dataset = cleaner.process([res_dataset]) - if hit_synthesizer: - res_dataset = synthesizer.process([res_dataset]) - if hit_human_annotator: - res_dataset = human_annotator.process([res_dataset]) - except Exception: - traceback.print_exc() - return 4, "DataProcessors processing failed." - - # step 5. calculate the average and final scores, including priority - try: - if hit_cleaner: - scored_dataset = self._group_scores(res_dataset) - scored_dataset = self._compute_priority_scores(scored_dataset) - else: - scored_dataset = res_dataset - except Exception: - traceback.print_exc() - return 5, "Grouping and computing priority score failed." - - # step 6. track lineage if they are changed - try: - res_dataset = scored_dataset - except Exception: - traceback.print_exc() - return 6, "Tracking lineage failed." - - # step 7. export the result to the database - try: - self.dsdb.add_entries(res_dataset) - except Exception: - traceback.print_exc() - return 7, "Exporting result to database failed." + while True: + # step 4. load data from the input buffers for the next batch + try: + dataset.read_from_buffer() + except StopIteration: + break + except Exception: + traceback.print_exc() + return 4, "RftDataset loading from buffers failed." + + # step 5. apply processors to calculate scores of different dimensions + try: + res_dataset = dataset + if hit_cleaner: + res_dataset = cleaner.process([res_dataset]) + if hit_synthesizer: + res_dataset = synthesizer.process([res_dataset]) + if hit_human_annotator: + res_dataset = human_annotator.process([res_dataset]) + except Exception: + traceback.print_exc() + return 5, "DataProcessors processing failed." + + # step 6. calculate the average and final scores, including priority + try: + if hit_cleaner: + scored_dataset = self._group_scores(res_dataset) + scored_dataset = self._compute_priority_scores(scored_dataset) + else: + scored_dataset = res_dataset + except Exception: + traceback.print_exc() + return 6, "Grouping and computing priority score failed." + + # step 7. track lineage if they are changed + try: + res_dataset = scored_dataset + except Exception: + traceback.print_exc() + return 7, "Tracking lineage failed." + + # step 8 + try: + if "priority" in res_dataset.data.features: + res_dataset.sort_by("priority", reverse=True) + except Exception: + traceback.print_exc() + return 8, "Sorting results by priority failed." + + # step 9. sort and export the result to the output buffer + try: + res_dataset.write_to_buffer() + except Exception: + traceback.print_exc() + return 9, "Exporting result to output buffer failed." return 0, "success" @@ -171,7 +186,8 @@ def _group_scores(self, dataset: RftDataset) -> RftDataset: all_stats = [ sample[Fields.stats][stats] for sample in dataset.data if Fields.stats in sample ] - stats_min_max[stats] = [min(all_stats), max(all_stats)] + if len(all_stats) > 0 and isinstance(all_stats[0], Number): + stats_min_max[stats] = [min(all_stats), max(all_stats)] def _group_single(sample): stats = sample[Fields.stats] @@ -240,7 +256,7 @@ def _compute_combined_score( difficulty = stats.get("difficulty_score", 0.5) score += self.priority_weights["difficulty"] * difficulty - sample["priority"] = [score] + sample["priority"] = [score] if isinstance(sample[Fields.stats], list) else score return sample def _compute_diversity_score(self) -> float: @@ -252,10 +268,6 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset: dataset.data = dataset.data.map(self._compute_combined_score) return dataset - def _select_top_k(self, dataset: RftDataset, k: int) -> List: - """Select top-k samples based on utility scores""" - return dataset.data.sort("priority", reverse=True).take(k).to_list() - @ray.method(num_returns=1) def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]: """Select a batch of samples for training""" @@ -267,7 +279,8 @@ def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, A dataset.data = dataset.data.filter(lambda s: s["priority"] >= self.min_priority_score) # Select top-k samples - selected_samples = self._select_top_k(dataset, batch_size) + dataset.sort_by("priority", reverse=True, top_k=batch_size) + selected_samples = dataset.data.to_list() # Update state self._update_state(selected_samples, dataset.data["priority"]) diff --git a/trinity/data/controllers/task_parser.py b/trinity/data/controllers/task_parser.py index 23b169ab2d..2e30dace63 100644 --- a/trinity/data/controllers/task_parser.py +++ b/trinity/data/controllers/task_parser.py @@ -7,7 +7,7 @@ from jsonargparse import Namespace from loguru import logger -from trinity.common.config import Config +from trinity.common.config import DataPipelineConfig from trinity.data.core.dataset import RftDataset from .default_ops import ( @@ -128,7 +128,7 @@ class DataTaskParser: def __init__( self, - rft_config: Config, + data_pipeline_config: DataPipelineConfig, llm_agent: DashScopeChatWrapper = None, dataset: RftDataset = None, validate_config: bool = True, @@ -136,12 +136,12 @@ def __init__( """ Initialization method. - :param rft_config: All configs. + :param data_pipeline_config: All configs of specified data pipeline. :param llm_agent: The LLM agent for natural language parsing. :param dataset: The dataset to be processed. :param validate_config: If execute the config validation check. """ - self.config = rft_config.data + self.config = data_pipeline_config self.llm_agent = llm_agent self.validate_config = validate_config # TODO: refer dataset to support natural language parsing. @@ -164,15 +164,21 @@ def parse_to_dj_config(self, extra_op_args=None): return dj_config, hit_cleaner, hit_synthesizer, hit_human_annotator def _check_types_of_processors(self, dj_config): + if dj_config is None: + return False, False, False hit_cleaner, hit_synthesizer, hit_human_annotator = False, False, False - for op in dj_config.process: + process_list = dj_config.get("process", []) + for op in process_list: op_name = list(op.keys())[0] - if op_name in DEFAULT_CLEANER: - hit_cleaner = True - elif op_name in DEFAULT_SYNTHESIZER: + if op_name in DEFAULT_SYNTHESIZER: hit_synthesizer = True elif op_name in DEFAULT_HUMAN_ANNOTATOR: hit_human_annotator = True + else: + for dimension in DEFAULT_CLEANER: + if op_name in DEFAULT_CLEANER[dimension]: + hit_cleaner = True + break return hit_cleaner, hit_synthesizer, hit_human_annotator def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> Namespace: @@ -185,20 +191,10 @@ def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> N print(op) return dj_config - def _add_extra_args(self, dj_config: Namespace, op_weights: Dict = {}) -> Namespace: - """Add extra argument for RFT project""" - for op in dj_config.process: - op_name = list(op.keys())[0] - if "op_weight" not in op[op_name]: - op[op_name]["op_weight"] = op_weights[op_name] if op_name in op_weights else 1 - op[op_name]["op_weight"] = max(0, op[op_name]["op_weight"]) - return dj_config - def _direct_mapping(self) -> Namespace: """Direct mapping from RFT config to DJ config""" dj_config = prepare_side_configs(self.config.dj_config_path) dj_config = get_init_configs(dj_config) - dj_config = self._add_extra_args(dj_config) return dj_config def _agent_based_parsing(self, extra_op_args=None, try_num=3) -> Namespace: @@ -251,13 +247,11 @@ def _parse_llm_response(self, response: ModelResponse, extra_op_args=None): other_op_args = DEFAULT_OP_ARGS dj_process = [] - op_weights = {} def json_to_dj_config(parsed_json): for dim in set(parsed_json.keys()) & set(cleaners.keys()): for op_name in set(parsed_json[dim].keys()) & set(cleaners[dim].keys()): dj_process.append({op_name: {}}) - op_weights[op_name] = float(parsed_json[dim][op_name]) json_match = re.search(r"```json\n(.*?)\n```", response.text, re.DOTALL) if json_match: @@ -284,20 +278,5 @@ def json_to_dj_config(parsed_json): op[op_name][key] = val dj_config = Namespace(process=dj_process) dj_config = get_init_configs(dj_config) - dj_config = self._add_extra_args(dj_config, op_weights) - - if self.validate_config and not self._validate_config(dj_config): - return None return dj_config - - def _validate_config(self, config: Namespace) -> bool: - """Validate generated DJ config""" - try: - for op in config.process: - op_name = list(op.keys())[0] - weight = float(op[op_name]["op_weight"]) - assert 0 <= weight and weight <= 1 - except Exception: - return False - return True diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py index 3e4af0fe12..93be832cc7 100644 --- a/trinity/data/core/dataset.py +++ b/trinity/data/core/dataset.py @@ -3,13 +3,10 @@ from typing import Any, Dict, List, Optional, Union import networkx as nx -from data_juicer.core.data.dj_dataset import Dataset -from datasets import load_dataset +from datasets import Dataset, concatenate_datasets -from trinity.common.config import DataProcessorConfig -from trinity.common.rewards import REWARD_FUNCTIONS -from trinity.common.task import TaskSet -from trinity.common.workflows import WORKFLOWS +from trinity.buffer import get_buffer_reader, get_buffer_writer +from trinity.common.config import BufferConfig, DataPipelineConfig, StorageConfig from trinity.data.core.formatter import BaseDataFormatter @@ -31,25 +28,27 @@ class RftDataset: 4. Basic statistics and metrics computation Args: - config (Dict): Configuration dict including DJ config + data_pipeline_config (DataPipelineConfig): Configuration including DJ config reward_schema (Union[str, Dict]): Schema definition for reward fields track_lineage (bool): Whether to track data lineage """ def __init__( self, - data_config: DataProcessorConfig, + data_pipeline_config: DataPipelineConfig, + buffer_config: BufferConfig = None, reward_schema: Union[str, Dict] = "default", track_lineage: bool = True, ): - self.config = data_config - source_data_path = data_config.source_data_path - if not source_data_path: - raise ValueError("source_data_path is not specified in DJ config") - load_kwargs = data_config.load_kwargs - self.data = load_dataset(source_data_path, trust_remote_code=True, **load_kwargs) - - self.format = data_config.format + self.config = data_pipeline_config + self.buffer_config = buffer_config + input_buffer_configs = self.config.input_buffers + if len(input_buffer_configs) == 0: + raise ValueError("input_buffers is empty in data pipeline config") + self.buffers = [] + for input_buffer_config in input_buffer_configs: + self.buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config)) + self.data = Dataset.from_list([]) self.reward_schema = self._init_reward_schema(reward_schema) self.stats: Dict[str, Any] = {} @@ -65,15 +64,28 @@ def format( for formatter in formatters: self.data = formatter(self.data, num_proc) - def to_taskset(self, **kwargs) -> TaskSet: - default_workflow_cls = WORKFLOWS.get(self.config.default_workflow_type) - default_reward_fn_cls = REWARD_FUNCTIONS.get(self.config.default_reward_fn_type) - return TaskSet( - dataset=self.data, - config=self.config, - workflow=default_workflow_cls, - reward_fn=default_reward_fn_cls, - ) + def sort_by(self, key: str, reverse: bool = False, top_k: int = -1): + if top_k == -1: + top_k = len(self.data) + self.data = self.data.sort(key, reverse=reverse).take(top_k) + + def read_from_buffer(self): + datasets = [] + for buffer in self.buffers: + datasets.append(Dataset.from_list(buffer.read())) + self.data = concatenate_datasets(datasets) + + def write_to_buffer( + self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None + ): + if output_storage_config is None: + output_storage_config = self.config.output_buffer + if buffer_config is None: + buffer_config = self.buffer_config + output_buffer = get_buffer_writer(output_storage_config, buffer_config) + output_buffer.write(self.data.to_list()) + output_buffer.finish() + self.data = Dataset.from_list([]) def to_parquet(self, path: str): self.data.to_parquet(path) diff --git a/trinity/data/core/dataset_db.py b/trinity/data/core/dataset_db.py deleted file mode 100644 index f47b138995..0000000000 --- a/trinity/data/core/dataset_db.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import List - -from sqlalchemy import asc, create_engine, desc -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool - -from trinity.buffer.utils import retry_session -from trinity.common.config import DataProcessorConfig -from trinity.common.schema import Base, RftDatasetModel -from trinity.data.core.dataset import RftDataset -from trinity.utils.log import get_logger - -logger = get_logger(__name__) - - -def rft_dataset_to_model(dataset: RftDataset) -> List[RftDatasetModel]: - # hit keys of schema - hit_schema_keys = [] - hit_dataset_keys = [] - # get hit keys & vals - # - for content keys, we need to map it with content_key_mapping and try to - # find them in the dataset - # - for other keys, we just need to check if they are in the dataset - data = dataset.data - features = data.features - content_key_mapping = dataset.format.__dict__ - schema_keys = {key for key in RftDatasetModel.__dict__.keys() if not key.startswith("_")} - for schema_key in schema_keys: - key = schema_key - if f"{schema_key}_key" in content_key_mapping: - key = content_key_mapping[f"{schema_key}_key"] - if key in features: - hit_schema_keys.append(schema_key) - hit_dataset_keys.append(key) - # construct entries - entries = [] - for sample in data: - valid_data = { - schema_key: sample[key] for schema_key, key in zip(hit_schema_keys, hit_dataset_keys) - } - entries.append(RftDatasetModel(**valid_data)) - return entries - - -class RftDatasetDB: - def __init__(self, config: DataProcessorConfig) -> None: - self.db_url = config.db_url - self.engine = create_engine(self.db_url, poolclass=NullPool) - self.config = config - try: - Base.metadata.create_all(self.engine, checkfirst=True) - except OperationalError: - logger.warning("Failed to create database, assuming it already exists.") - self.session = sessionmaker(bind=self.engine) - - def add_entries(self, dataset: RftDataset): - with retry_session( - self, self.config.max_retry_times, self.config.max_retry_interval - ) as session: - session.add_all(rft_dataset_to_model(dataset)) - - def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = False): - # get num_entries entries from the database - if order_by is not None and hasattr(RftDatasetModel, order_by): - order_by_key = getattr(RftDatasetModel, order_by) - order_by_key = asc(order_by_key) if ascending else desc(order_by_key) - else: - order_by_key = None - with retry_session( - self, self.config.max_retry_times, self.config.max_retry_interval - ) as session: - entries = ( - session.query(RftDatasetModel) - .order_by(order_by_key) - .limit(num_entries) - .with_for_update() - .all() - ) - - for entry in entries: - entry.consumed_cnt += 1 - samples = [entry.to_dict() for entry in entries] - return samples diff --git a/trinity/data/processors/cleaner.py b/trinity/data/processors/cleaner.py index b031e528e1..10979990b1 100644 --- a/trinity/data/processors/cleaner.py +++ b/trinity/data/processors/cleaner.py @@ -36,6 +36,7 @@ def __init__( clean_strategy: str = "iterative", min_size_ratio: PositiveFloat = None, data_dist: str = "gaussian", + op_weights: dict = None, **kwargs, ): """ @@ -54,6 +55,7 @@ def __init__( self.min_size_ratio = min_size_ratio self.data_dist = data_dist self.op_name_to_stats_key = {} + self.op_weights = op_weights def keep_cleaner_op_cfg(self, dj_cfg): """Only consider cleaner op in data-juicer configs.""" @@ -112,7 +114,7 @@ def update_op_threshold( update_record = {} for process in exe_cfg.process: op_name, args = list(process.items())[0] - op_weight = args["op_weight"] + op_weight = self.op_weights.get(op_name, 1) update_record[op_name] = {} temp_args = copy.deepcopy(args) @@ -164,7 +166,7 @@ def process( else: logger.info("Executing Data-Juicer analyzer...") analyzer = Analyzer(self.dj_cfg) - analyzer.run(dataset) + analyzer.run(dataset, skip_export=True) df = analyzer.overall_result mean_series = df[df.index == "mean"] stats_key_to_mean = mean_series.iloc[0, :].to_dict() diff --git a/trinity/data/readme.md b/trinity/data/readme.md index 3294819f43..4b5c828ee6 100644 --- a/trinity/data/readme.md +++ b/trinity/data/readme.md @@ -88,14 +88,14 @@ synth_data = synthesizer.process(clean_data) - Then you need to prepare the `data_processor` section in the config file (e.g. [test_cfg.yaml](tests/test_configs/active_iterator_test_cfg.yaml)) - For the `dj_config_path` argument in it, you can either specify a data-juicer config file path (e.g. [test_dj_cfg.yaml](tests/test_configs/active_iterator_test_dj_cfg.yaml)), or write the demand in `dj_process_desc` argument in natural language and our agent will help you to organize the data-juicer config. - Finally you can send requests to the data server to start an active iterator to process datasets in many ways: - - Request with `curl`: `curl "http://127.0.0.1:5000/data_workflow?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"` + - Request with `curl`: `curl "http://127.0.0.1:5005/data_processor/task_pipeline?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"` - Request using our simple client: ```python from trinity.cli.client import request res = request( - url="http://127.0.0.1:5005/data_workflow", + url="http://127.0.0.1:5005/data_processor/task_pipeline", configPath="tests/test_configs/active_iterator_test_cfg.yaml" ) diff --git a/trinity/data/server.py b/trinity/data/server.py index 08ca5ebfea..e1f57ba81b 100644 --- a/trinity/data/server.py +++ b/trinity/data/server.py @@ -1,20 +1,39 @@ import fire from flask import Flask, jsonify, request +from markupsafe import escape app = Flask(__name__) -APP_NAME = "data_workflow" +APP_NAME = "data_processor" -@app.route(f"/{APP_NAME}", methods=["GET"]) -def data_workflow(): +@app.route(f"/{APP_NAME}/", methods=["GET"]) +def data_processor(pipeline_type): from trinity.common.config import load_config from trinity.data.controllers.active_iterator import DataActiveIterator config_path = request.args.get("configPath") + pipeline_type = escape(pipeline_type) config = load_config(config_path) - iterator = DataActiveIterator(config) + pipeline_config = getattr(config.data_processor, pipeline_type) + if pipeline_config is None: + return jsonify( + { + "return_code": -1, + "message": f"Error: {pipeline_type} is not supported or the corresponding config is empty", + } + ) + + if pipeline_config.dj_config_path is None and pipeline_config.dj_process_desc is None: + return jsonify( + { + "return_code": -1, + "message": "Error: Both dj_config_path and dj_process_desc in the pipeline config are None.", + } + ) + + iterator = DataActiveIterator(pipeline_config, config.buffer) ret, msg = iterator.run() return jsonify({"return_code": ret, "message": msg}) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 5f05973d4d..31ade5f84b 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -1,17 +1,25 @@ # -*- coding: utf-8 -*- """The explorer module""" +from __future__ import annotations + +import asyncio import os import time from collections import defaultdict from typing import List, Optional, Tuple -import ray import torch +from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.buffer import get_buffer_writer from trinity.buffer.buffer import get_buffer_reader from trinity.common.config import Config -from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod +from trinity.common.constants import ( + EXPLORER_NAME, + ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + RunningStatus, + SyncMethod, +) from trinity.common.models import create_inference_models from trinity.common.models.utils import ( get_checkpoint_dir_with_step_num, @@ -23,7 +31,6 @@ from trinity.utils.monitor import MONITOR -@ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1}) class Explorer: """Responsible for exploring the taskset.""" @@ -31,8 +38,9 @@ def __init__(self, config: Config): self.logger = get_logger(__name__) self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() - self.step_num = explorer_meta.get("latest_iteration", 0) + self.explore_step_num = explorer_meta.get("latest_iteration", 0) self.config = config + self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) if self.config.mode != "bench": self.experience_buffer = get_buffer_writer( @@ -47,7 +55,7 @@ def __init__(self, config: Config): self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, name=self.config.name, - role="explorer", + role=EXPLORER_NAME, config=config, ) self.batch_size = config.buffer.batch_size @@ -66,10 +74,10 @@ def __init__(self, config: Config): self.state_dict = {} else: # nccl mode self.state_dict_meta = [] + self.status = RunningStatus.RUNNING self.logger.info("Finished initializing Explorer.") - @ray.method(concurrency_group="setup_weight_sync_group") - def setup_weight_sync_group( + async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): # In checkpoint mode, we use explorer to store the model weights which has no rank @@ -83,6 +91,7 @@ def setup_weight_sync_group( f"world_size={world_size}, rank_offset={base_offset}" ) self.state_dict_meta = state_dict_meta + # TODO: save state_dict in models refs = [ model.init_process_group.remote( master_address=master_address, @@ -93,10 +102,11 @@ def setup_weight_sync_group( group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, timeout=self.config.synchronizer.sync_timeout, update_with_checkpoint=self.use_checkpoint_weights_update, + state_dict_meta=state_dict_meta, ) for i, model in enumerate(self.models) ] - ray.get(refs) + await asyncio.gather(*refs) def _init_runner_pool(self) -> RunnerPool: if self.config.explorer.rollout_model.engine_type != "vllm_async": @@ -113,16 +123,22 @@ def _init_runner_pool(self) -> RunnerPool: self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners") return RunnerPool(self.config, self.models, self.auxiliary_models) - def _update_model_weight(self, state_dict: dict) -> None: + async def _update_model_weight(self, state_dict: dict) -> None: # TODO: update model weight self.state_dict = state_dict - update_weight_args_list = [] - for name, param in state_dict.items(): - update_weight_args_list.append((name, param.dtype, param.shape)) - ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models]) + if self.state_dict_meta is None: + update_weight_args_list = [] + for name, param in state_dict.items(): + update_weight_args_list.append((name, str(param.dtype), tuple(param.shape))) + self.state_dict_meta = update_weight_args_list + else: + update_weight_args_list = None + await asyncio.gather( + *[model.sync_model.remote(update_weight_args_list) for model in self.models] + ) self.state_dict.clear() - def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: + async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: # TODO: support more checkpoint types try: checkpoint_dir = get_checkpoint_dir_with_step_num( @@ -133,94 +149,71 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: if checkpoint_dir == self.old_checkpoint: return model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor")) - self._update_model_weight(model_weights) + await self._update_model_weight(model_weights) self.old_checkpoint = checkpoint_dir except Exception as e: - self.logger.error(f"Error when loading state_dict: {e}") + self.logger.warning(f"Fail to load checkpoint: {e}") - def _nccl_weights_update(self): - ray.get([model.sync_model.remote(self.state_dict_meta) for model in self.models]) + async def _nccl_weights_update(self): + assert self.state_dict_meta is not None + await asyncio.gather(*[model.sync_model.remote() for model in self.models]) - def prepare(self) -> None: + async def prepare(self) -> None: """Preparation before running.""" if self.use_checkpoint_weights_update: - master_address, master_port = ray.get(self.models[0].get_available_address.remote()) - self.setup_weight_sync_group(master_address, master_port) + master_address, master_port = await self.models[0].get_available_address.remote() + await self.setup_weight_sync_group(master_address, master_port) - @ray.method(concurrency_group="get_weight") - def get_weight(self, name: str) -> torch.Tensor: + async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" return self.state_dict[name] - def explore(self) -> None: - """Explore the entire dataset.""" + async def explore(self) -> str: while True: - explore_status, explore_iter = self.explore_one_period() - if not explore_status: + try: + explore_contionue = self.explore_step() + if not explore_contionue: + break + if self.need_sync(): + self.wait_for_workflow_done() + await self.sync_weight() + if self.explore_step_num % self.config.explorer.eval_interval == 0: + self.wait_for_workflow_done() + self.eval() + except Exception as e: + self.logger.error(f"Error in Explorer: {e}") break - self.sync_weight() - if explore_iter % self.config.explorer.eval_interval == 0: - self.eval() - self.logger.info("Evaluation finished.") - self.logger.info("Explorer finished.") - - def explore_one_period(self) -> Tuple[bool, int]: - """Explore for one period. - - Different from `explore()` which consumes all tasks in the task set, - `explore_one_period()` only consume `sync_interval * batch_size` - number of tasks. - Returns: - explore_status: whether there are more tasks to explore. - explore_step_num: the number of explore steps - """ - st = time.time() - all_metrics = defaultdict(list) + self.logger.info("--------------------\n> Explorer finished.\n--------------------") + return EXPLORER_NAME - # submit tasks of this step + def explore_step(self) -> bool: + algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1) + # skip warmup + if algo_config.algorithm_type == "sft": + self.explore_step_num += 1 + return True try: - tasks = [] - for _ in range(self.config.synchronizer.sync_interval): - tasks.extend(self.taskset.read()) - self.runner_pool.run_tasks(tasks) # type: ignore + tasks = self.taskset.read() except StopIteration: + self.logger.warning("No more tasks to explore. Stop exploring.") + self.cache.save_explorer( + current_step=self.explore_step_num, + current_task_index=self.explore_step_num * self.config.buffer.batch_size, + ) + self.status = RunningStatus.STOPPED + self.wait_for_workflow_done() self.experience_buffer.finish() - self.logger.warning("No more tasks in the task set. Stop exploring.") - return False, self.step_num - - # wait for all tasks of this step to finish - while self.runner_pool.has_next(): - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - try: - # submit another task to replace the failed task - self.runner_pool.run_tasks(self.taskset.read()) - except StopIteration: - self.logger.warning("No more tasks in the task set. Stop exploring.") - return False, self.step_num - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) - - # calculate metrics - log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore - log_metrics["rollout/step_time"] = time.time() - st - self.step_num += self.config.synchronizer.sync_interval - self.monitor.log(log_metrics, step=self.step_num) - - # save explore checkpoint - self.cache.save_explorer( - current_step=self.step_num, - current_task_index=self.step_num * self.config.buffer.batch_size, - # TODO: remove current_task_index - ) + return False + self.runner_pool.run_tasks(tasks) + self.explore_step_num += 1 + return True - self.logger.info(f"Explore step {self.step_num} finished.") - return True, self.step_num + def need_sync(self) -> bool: + if self.explore_step_num <= self.config.synchronizer.sync_offset: + return False + return ( + self.explore_step_num - self.config.synchronizer.sync_offset + ) % self.config.synchronizer.sync_interval == 0 def eval(self) -> Tuple[bool, int]: """Evaluation on all evaluation data samples.""" @@ -229,7 +222,7 @@ def eval(self) -> Tuple[bool, int]: eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer)) if len(eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") - return True, self.step_num + return True, self.explore_step_num self.logger.info("Evaluation started.") all_st = time.time() log_metrics = {} @@ -261,14 +254,15 @@ def wait(): log_metrics.update(metrics) log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st log_metrics["eval/total_time"] = time.time() - all_st - self.monitor.log(log_metrics, step=self.step_num) # type: ignore - return True, self.step_num + self.monitor.log(log_metrics, step=self.explore_step_num) # type: ignore + self.logger.info("Evaluation finished.") + return True, self.explore_step_num - def benchmark(self) -> bool: + async def benchmark(self) -> bool: """Benchmark the model checkpoints.""" # benchmark on the latest checkpoint if self.config.explorer.eval_on_latest_checkpoint: - self._checkpoint_weights_update() + await self._checkpoint_weights_update() self.eval() return True @@ -282,18 +276,56 @@ def benchmark(self) -> bool: ] ) for step_num in all_ckp_steps: - self.step_num = step_num - self._checkpoint_weights_update(step_num=step_num) + self.explore_step_num = step_num + await self._checkpoint_weights_update(step_num=step_num) self.eval() return True - def sync_weight(self) -> None: + def wait_for_workflow_done(self) -> None: + """Wait for workflow to finish.""" + all_metrics = defaultdict(list) + # wait for all tasks of this step to finish + while self.runner_pool.has_next(): + status_list = self.runner_pool.get_next_unorder() + if not isinstance(status_list, list): + status_list = [status_list] + for status in status_list: + if not status.ok: + self.logger.error(f"Error when running task: {status.message}") + # submit another task to replace the failed task + try: + tasks = self.taskset.read(batch_size=1) + except StopIteration: + self.logger.warning("No more tasks in taskset. Stop retrying.") + return + self.runner_pool.run_tasks(tasks) + else: + for metric_name, metric_value in status.metric.items(): + all_metrics[metric_name].append(metric_value) + # calculate metrics + log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore + self.monitor.log(log_metrics, step=self.explore_step_num) + self.logger.info(f"Explore step {self.explore_step_num} finished.") + + async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights + self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.") + self.status = RunningStatus.WAITING_SYNC if self.use_checkpoint_weights_update: - self._checkpoint_weights_update() + await self._checkpoint_weights_update() else: # nccl weights update - self._nccl_weights_update() + await self._nccl_weights_update() + # save explore checkpoint + self.cache.save_explorer( + current_step=self.explore_step_num, + current_task_index=self.explore_step_num * self.config.buffer.batch_size, + ) + self.status = RunningStatus.RUNNING + self.logger.info(f"Explorer sync at step {self.explore_step_num} finished") + + async def running_status(self) -> RunningStatus: + return self.status def flush_log(self, step: int) -> None: """Flush the log of the current step.""" diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 80b8992b3b..de4305a9cc 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -7,10 +7,25 @@ import streamlit as st import yaml -from trinity.common.constants import AlgorithmType, StorageType +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.kl_fn.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY +from trinity.common.constants import StorageType from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.trainer_config_manager import use_critic +register_map = { + "sample_strategy": SAMPLE_STRATEGY, + "policy_loss_fn": POLICY_LOSS_FN, + "advantage_fn": ADVANTAGE_FN, + "kl_loss_fn": KL_FN, + "kl_penalty_fn": KL_FN, + "entropy_loss_fn": ENTROPY_LOSS_FN, +} + class ConfigManager: def __init__(self): @@ -47,55 +62,48 @@ def maintain_session_state(self): for key in CONFIG_GENERATORS.default_config: st.session_state[key] = st.session_state[key] - eval_dataset_keys = [ + def maintain_list_state(prefix, key_list): + last_idx, del_num = 0, 0 + for idx in range(st.session_state[f"_{prefix}_num"]): + if st.session_state.get(f"{prefix}_{idx}_del_flag", False): + del_num += 1 + continue + for key in key_list: + full_key = f"{prefix}_{idx}_{key}" + last_full_key = f"{prefix}_{last_idx}_{key}" + st.session_state[last_full_key] = st.session_state[full_key] + last_idx += 1 + st.session_state[f"_{prefix}_num"] -= del_num + + self.eval_dataset_keys = [ "name", "path", - "subset_name", "split", + "subset_name", "prompt_key", "response_key", "temperature", "logprobs", "n", ] - last_idx, del_num = 0, 0 - for idx in range(st.session_state["_eval_tasksets_num"]): - if st.session_state.get(f"eval_taskset_{idx}_del_flag", False): - del_num += 1 - continue - for key in eval_dataset_keys: - full_key = f"eval_taskset_{idx}_{key}" - last_full_key = f"eval_taskset_{last_idx}_{key}" - st.session_state[last_full_key] = st.session_state[full_key] - last_idx += 1 - st.session_state["_eval_tasksets_num"] -= del_num - - auxiliary_model_keys = [ + maintain_list_state("eval_tasksets", self.eval_dataset_keys) + + self.inference_model_keys = [ "model_path", "engine_type", "engine_num", "tensor_parallel_size", - "gpu_memory_utilization", - "dtype", - "seed", "use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill", + "gpu_memory_utilization", + "dtype", + "seed", "enable_thinking", "enable_openai_api", ] - last_idx, del_num = 0, 0 - for idx in range(st.session_state["_auxiliary_models_num"]): - if st.session_state.get(f"auxiliary_model_{idx}_del_flag", False): - del_num += 1 - continue - for key in auxiliary_model_keys: - full_key = f"auxiliary_model_{idx}_{key}" - last_full_key = f"auxiliary_model_{last_idx}_{key}" - st.session_state[last_full_key] = st.session_state[full_key] - last_idx += 1 - st.session_state["_auxiliary_models_num"] -= del_num + maintain_list_state("auxiliary_models", self.inference_model_keys) def get_configs(self, *config_names: str, columns_spec: List[int] = None): CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec) @@ -108,7 +116,7 @@ def beginner_mode(self): self.get_configs("checkpoint_root_dir") - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": self.get_configs("taskset_path") else: self.get_configs("experience_buffer_path") @@ -126,7 +134,7 @@ def beginner_mode(self): self.get_configs("sync_interval", "eval_interval", "save_interval") - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": self.get_configs("taskset_args") else: self.get_configs("dpo_dataset_kwargs") @@ -136,9 +144,6 @@ def beginner_mode(self): self.get_configs("default_workflow_type", "default_reward_fn_type") - self.get_configs("actor_use_kl_loss") - self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - self.get_configs( "actor_ppo_micro_batch_size_per_gpu", "actor_lr", @@ -165,7 +170,7 @@ def _expert_buffer_part(self): self.get_configs("system_prompt") self.get_configs("reply_prefix") - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": with st.expander("Taskset Configs", expanded=True): self.get_configs("taskset_path") self.get_configs("taskset_args") @@ -182,7 +187,7 @@ def _expert_buffer_part(self): self.get_configs("sft_warmup_dataset_path") self.get_configs("sft_warmup_dataset_args") - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": with st.expander("Experiences Buffer Configs", expanded=True): self.get_configs("storage_type") self.get_configs("experience_buffer_path") @@ -213,8 +218,30 @@ def _expert_explorer_part(self): self.get_configs("auxiliary_models") def _expert_trainer_part(self): - self.get_configs("algorithm_type", "gamma", "lam") - self.get_configs("repeat_times", "save_interval") + self.get_configs("algorithm_type", "repeat_times", "save_interval") + self.get_configs("sample_strategy", "advantage_fn", "entropy_loss_fn") + self.get_configs("policy_loss_fn", "kl_penalty_fn", "kl_loss_fn") + + with st.expander("Advanced Algorithm Config"): + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + default_config = algorithm.default_config() + config_key_list = [] + for key in default_config.keys(): + value = st.session_state[key] + if key == "repeat_times": + continue + default_args = register_map[key].get(value).default_args() + for sub_key in default_args.keys(): + full_key = sub_key + "_in_" + key + config_key_list.append(full_key) + + idx = 0 + while idx < len(config_key_list): + delta = 3 if len(config_key_list) - idx != 4 else 2 + key_list = config_key_list[idx : idx + delta] + idx += delta + self.get_configs(*key_list) + self.get_configs("enable_preview") if st.session_state["trainer_type"] == "verl": @@ -238,12 +265,6 @@ def _expert_verl_training_part(self): self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep") - def _expert_verl_algorithm_part(self): - st.subheader("RL Algorithm Config") - self.get_configs("norm_adv_by_std_in_grpo", "use_kl_in_reward") - self.get_configs("kl_penalty", "kl_ctrl_type", "kl_ctrl_coef") - self.get_configs("horizon", "target_kl") - def _expert_verl_actor_part(self): st.subheader("Actor Model Config") self.get_configs( @@ -254,12 +275,7 @@ def _expert_verl_actor_part(self): self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio") - self.get_configs("actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef") - - self.get_configs("actor_use_kl_loss", "actor_use_uid") - self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - - self.get_configs("actor_tau", "actor_opmd_baseline") + self.get_configs("actor_grad_clip") self.get_configs("actor_checkpoint") @@ -277,7 +293,6 @@ def _expert_verl_critic_part(self): def _expert_verl_trainer_part(self): name2func = { "RL Training Config": self._expert_verl_training_part, - "RL Algorithm Config": self._expert_verl_algorithm_part, "Actor and Ref Config": self._expert_verl_actor_part, } if use_critic(): @@ -359,9 +374,6 @@ def _generate_verl_config(self): ), }, "fsdp_config": copy.deepcopy(fsdp_config), - "tau": st.session_state["actor_tau"], - "opmd_baseline": st.session_state["actor_opmd_baseline"], - "use_uid": st.session_state["actor_use_uid"], }, "ref": { "fsdp_config": copy.deepcopy(fsdp_config), @@ -375,14 +387,7 @@ def _generate_verl_config(self): ], }, }, - "custom_reward_function": {"path": None, "name": "compute_score"}, - "algorithm": { - "kl_penalty": st.session_state["kl_penalty"], - "kl_ctrl": { - "type": st.session_state["kl_ctrl_type"], - "kl_coef": st.session_state["kl_ctrl_coef"], - }, - }, + "critic": {}, "trainer": { "balance_batch": balance_batch, "resume_mode": st.session_state["resume_mode"], @@ -436,11 +441,35 @@ def _generate_verl_config(self): "cliprange_value": st.session_state["critic_cliprange_value"], "checkpoint": {"contents": st.session_state["critic_checkpoint"]}, } + else: + del trainer_config["critic"] return trainer_config + def _gen_algorithm_config(self): + algorithm_config = { + "algorithm_type": st.session_state["algorithm_type"], + } + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + default_config = algorithm.default_config() + current_config = {} + for key in default_config.keys(): + current_config[key] = value = st.session_state[key] + if key == "repeat_times": + continue + default_args = register_map[key].get(value).default_args() + args = {} + for sub_key in default_args.keys(): + full_key = sub_key + "_in_" + key + args[sub_key] = st.session_state.get(full_key, default_args[sub_key]) + if default_args != args: + current_config[key + "_args"] = args + if default_config != current_config: + algorithm_config.update(current_config) + return algorithm_config + def _gen_buffer_config(self): experience_buffer_path = st.session_state["experience_buffer_path"].strip() - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": if ( not experience_buffer_path and st.session_state["storage_type"] == StorageType.SQL.value @@ -456,6 +485,7 @@ def _gen_buffer_config(self): buffer_config = { "batch_size": st.session_state["train_batch_size"], "total_epochs": st.session_state["total_epochs"], + "explorer_input": {}, "trainer_input": { "experience_buffer": { "name": "experience_buffer", @@ -497,13 +527,25 @@ def _gen_buffer_config(self): { "name": st.session_state[f"eval_taskset_{idx}_name"], "path": st.session_state[f"eval_taskset_{idx}_path"], - "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], "split": st.session_state[f"eval_taskset_{idx}_split"], - "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], - "response_key": st.session_state[f"eval_taskset_{idx}_response_key"], + "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], + "format": { + "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], + "response_key": st.session_state[ + f"eval_taskset_{idx}_response_key" + ], + }, + "rollout_args": { + "temperature": st.session_state[f"eval_taskset_{idx}_temperature"], + "logprobs": st.session_state[f"eval_taskset_{idx}_logprobs"], + "n": st.session_state[f"eval_taskset_{idx}_n"], + }, } ) - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + else: + del buffer_config["explorer_input"] + + if st.session_state["algorithm_type"] == "dpo": experience_buffer = buffer_config["trainer_input"]["experience_buffer"] experience_buffer["split"] = st.session_state["dpo_dataset_train_split"] experience_buffer["format"] = { @@ -534,26 +576,23 @@ def _gen_explorer_config(self): "max_timeout": st.session_state["max_timeout"], "max_retry_times": st.session_state["explorer_max_retry_times"], "rollout_model": { - "engine_type": st.session_state["engine_type"], - "engine_num": st.session_state["engine_num"], - "tensor_parallel_size": st.session_state["tensor_parallel_size"], - "use_v1": st.session_state["use_v1"], - "enforce_eager": st.session_state["enforce_eager"], - "enable_prefix_caching": st.session_state["enable_prefix_caching"], - "enable_chunked_prefill": st.session_state["enable_chunked_prefill"], - "gpu_memory_utilization": st.session_state["gpu_memory_utilization"], - "dtype": st.session_state["dtype"], - "seed": st.session_state["seed"], + key: st.session_state[key] + for key in self.inference_model_keys + if key != "model_path" # "max_prompt_tokens": None, # TODO # "max_response_tokens": None, # TODO # "chat_template": None, # TODO: add chat template - "enable_thinking": st.session_state["enable_thinking"], - "enable_openai_api": st.session_state["enable_openai_api"], }, "auxiliary_models": [], "eval_interval": st.session_state["eval_interval"], "eval_on_latest_checkpoint": st.session_state["eval_on_latest_checkpoint"], } + for i in range(st.session_state["_auxiliary_models_num"]): + auxiliary_model_config = { + key: st.session_state[f"auxiliary_model_{i}_{key}"] + for key in self.inference_model_keys + } + explorer_config["auxiliary_models"].append(auxiliary_model_config) return explorer_config def generate_config(self): @@ -585,12 +624,7 @@ def generate_config(self): "project": st.session_state["project"], "name": st.session_state["exp_name"], "checkpoint_root_dir": st.session_state["checkpoint_root_dir"], - "algorithm": { - "algorithm_type": st.session_state["algorithm_type"], - "repeat_times": st.session_state["repeat_times"], - "gamma": st.session_state["gamma"], - "lam": st.session_state["lam"], - }, + "algorithm": self._gen_algorithm_config(), "data_processor": {}, # TODO: Add data processor config "model": { "model_path": st.session_state["model_path"], @@ -607,11 +641,7 @@ def generate_config(self): "trainer_type": st.session_state["trainer_type"], "save_interval": st.session_state["save_interval"], "enable_preview": st.session_state["enable_preview"], - "actor_use_kl_loss": st.session_state["actor_use_kl_loss"], - "actor_kl_loss_coef": st.session_state["actor_kl_loss_coef"], - "actor_entropy_coef": st.session_state["actor_entropy_coef"], "actor_grad_clip": st.session_state["actor_grad_clip"], - "actor_clip_ratio": st.session_state["actor_clip_ratio"], "trainer_config": trainer_config, }, "monitor": { diff --git a/trinity/manager/config_registry/__init__.py b/trinity/manager/config_registry/__init__.py index e62c565fb4..3896582755 100644 --- a/trinity/manager/config_registry/__init__.py +++ b/trinity/manager/config_registry/__init__.py @@ -1,3 +1,4 @@ +import trinity.manager.config_registry.algorithm_config_manager as algorithm_config_manager import trinity.manager.config_registry.buffer_config_manager as buffer_config_manager import trinity.manager.config_registry.explorer_config_manager as explorer_config_manager import trinity.manager.config_registry.model_config_manager as model_config_manager @@ -6,6 +7,7 @@ __all__ = [ "CONFIG_GENERATORS", + "algorithm_config_manager", "buffer_config_manager", "explorer_config_manager", "model_config_manager", diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py new file mode 100644 index 0000000000..c9694dec25 --- /dev/null +++ b/trinity/manager/config_registry/algorithm_config_manager.py @@ -0,0 +1,371 @@ +import streamlit as st + +from trinity.algorithm.advantage_fn import ( + ADVANTAGE_FN, + GRPOAdvantageFn, + OPMDAdvantageFn, + PPOAdvantageFn, +) +from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( + ENTROPY_LOSS_FN, + EntropyLossFn, +) +from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn +from trinity.algorithm.policy_loss_fn import ( + POLICY_LOSS_FN, + DPOLossFn, + MIXPolicyLossFn, + OPMDPolicyLossFn, + PPOPolicyLossFn, + SFTLossFn, +) +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num + + +@CONFIG_GENERATORS.register_config( + default_value="ppo", + other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()}, +) +def set_algorithm_type(**kwargs): + def on_change(): + if st.session_state["algorithm_type"] == "dpo": + st.session_state["mode"] = "train" + else: + st.session_state["mode"] = "both" + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + default_config = algorithm.default_config() + st.session_state["_current_default_config"] = default_config + for key, value in default_config.items(): + st.session_state[key] = value + set_trainer_gpu_num() + + candidates = list(ALGORITHM_TYPE.modules.keys()) + st.selectbox( + "Algorithm Type", + candidates, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["repeat_times"], + visible=lambda: "repeat_times" in st.session_state["_current_default_config"], + other_configs={ + "_grouped_adv_repeat_times": 2, + "_not_grouped_adv_repeat_times": 1, + }, +) +def set_repeat_times(**kwargs): # TODO + key = kwargs.get("key") + grouped_adv_algorithms = [ + "grpo", + "opmd", # TODO: may add rloo + ] + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + min_repeat_times = 2 + st.session_state[key] = st.session_state["_grouped_adv_repeat_times"] + else: + min_repeat_times = 1 + st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"] + + def on_change(): + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + st.session_state["_grouped_adv_repeat_times"] = st.session_state[key] + else: + st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key] + + st.number_input( + "Repeat Times", + min_value=min_repeat_times, + help="`repeat_times` is used to set how many experiences each task can generate, " + "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + on_change=on_change, + **kwargs, + ) + + +# Sample_strategy Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["sample_strategy"], + visible=lambda: "sample_strategy" in st.session_state["_current_default_config"], +) +def set_sample_strategy(**kwargs): + candidates = list(SAMPLE_STRATEGY.modules.keys()) + st.selectbox( + "Sample Strategy", + candidates, + help="The sample strategy used to obtain experiences.", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=MixSampleStrategy.default_args()["expert_data_ratio"], + visible=lambda: st.session_state["sample_strategy"] == "mix", +) +def set_expert_data_ratio_in_sample_strategy(**kwargs): + st.number_input( + "Expert Data Ratio", + min_value=0.0, + max_value=1.0, + value=0.5, + help="The ratio of expert data to be used in the training.", + **kwargs, + ) + + +# Advantage Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["advantage_fn"], + visible=lambda: "advantage_fn" in st.session_state["_current_default_config"], +) +def set_advantage_fn(**kwargs): + candidates = list(ADVANTAGE_FN.modules.keys()) + st.selectbox( + "Advantage Function", + candidates, + help="The advantage function used to compute advantages.", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAdvantageFn.default_args()["gamma"], + visible=lambda: st.session_state["advantage_fn"] in {"ppo", "reinforceplusplus"}, +) +def set_gamma_in_advantage_fn(**kwargs): + st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAdvantageFn.default_args()["lam"], + visible=lambda: st.session_state["advantage_fn"] == "ppo", +) +def set_lam_in_advantage_fn(**kwargs): + st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=GRPOAdvantageFn.default_args()["epsilon"], + visible=lambda: st.session_state["advantage_fn"] == "grpo", +) +def set_epsilon_in_advantage_fn(**kwargs): # TODO: update help message + st.number_input( + r"GRPO Epsilon", + help=r""" +```python +scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) +``` +""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=OPMDAdvantageFn.default_args()["opmd_baseline"], + visible=lambda: st.session_state["advantage_fn"] == "opmd", +) +def set_opmd_baseline_in_advantage_fn(**kwargs): + st.selectbox( + "OPMD Baseline", + ["mean", "logavgexp"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=OPMDAdvantageFn.default_args()["tau"], + visible=lambda: st.session_state["advantage_fn"] == "opmd" + and st.session_state["opmd_baseline_in_advantage_fn"] == "logavgexp", +) +def set_tau_in_advantage_fn(**kwargs): + st.number_input("Tau for OPMD Adv.", min_value=0.0, format="%.1e", **kwargs) + + +# KL Loss Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["kl_loss_fn"], + visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"], +) +def set_kl_loss_fn(**kwargs): + candidates = list(KL_FN.modules.keys()) + st.selectbox( + "KL Loss Type", + candidates, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=KLFn.default_args()["kl_coef"], + visible=lambda: st.session_state["kl_loss_fn"] != "none", +) +def set_kl_coef_in_kl_loss_fn(**kwargs): + st.number_input( + r"KL Loss Coef :blue-badge[$\beta$]", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +# KL Penalty Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["kl_penalty_fn"], + visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"], +) +def set_kl_penalty_fn(**kwargs): + candidates = list(KL_FN.modules.keys()) + st.selectbox( + "KL Penalty Type", + candidates, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=KLFn.default_args()["adaptive"], + visible=lambda: st.session_state["kl_penalty_fn"] != "none", +) +def set_adaptive_in_kl_penalty_fn(**kwargs): + st.checkbox( + "Adaptive KL Penalty", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=KLFn.default_args()["kl_coef"], + visible=lambda: st.session_state["kl_penalty_fn"] != "none", +) +def set_kl_coef_in_kl_penalty_fn(**kwargs): + st.number_input( + r"KL Penalty Coef", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +# TODO: target_kl and horizon + +# Policy Loss Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["policy_loss_fn"], + visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"], +) +def set_policy_loss_fn(**kwargs): + candidates = list(POLICY_LOSS_FN.modules.keys()) + st.selectbox( + "Policy Loss Fn", + candidates, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOPolicyLossFn.default_args()["clip_range"], + visible=lambda: st.session_state["policy_loss_fn"] in {"ppo", "mix"}, +) +def set_clip_range_in_policy_loss_fn(**kwargs): + st.number_input( + "Clip Range", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=SFTLossFn.default_args()["use_token_level_loss"], + visible=lambda: st.session_state["policy_loss_fn"] == "sft", +) +def set_use_token_level_loss_in_policy_loss_fn(**kwargs): + st.checkbox( + "Use Token Level Loss", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=DPOLossFn.default_args()["beta"], + visible=lambda: st.session_state["policy_loss_fn"] == "dpo", +) +def set_beta_in_policy_loss_fn(**kwargs): + st.number_input( + "Beta for DPO", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=DPOLossFn.default_args()["label_smoothing"], + visible=lambda: st.session_state["policy_loss_fn"] == "dpo", +) +def set_label_smoothing_in_policy_loss_fn(**kwargs): + st.number_input( + "Label Smoothing", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=OPMDPolicyLossFn.default_args()["tau"], + visible=lambda: st.session_state["policy_loss_fn"] == "opmd", +) +def set_tau_in_policy_loss_fn(**kwargs): + st.number_input("Tau for OPMD Loss", min_value=0.0, format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=MIXPolicyLossFn.default_args()["mu"], + visible=lambda: st.session_state["policy_loss_fn"] == "mix", +) +def set_mu_in_policy_loss_fn(**kwargs): + st.number_input("Mu for Mix Policy Loss", min_value=0.0, **kwargs) + + +# Entropy Loss Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["entropy_loss_fn"], + visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"], +) +def set_entropy_loss_fn(**kwargs): + candidates = list(ENTROPY_LOSS_FN.modules.keys()) + st.selectbox("Entropy Loss Function", candidates, **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=EntropyLossFn.default_args()["entropy_coef"], + visible=lambda: st.session_state["entropy_loss_fn"] != "none", +) +def set_entropy_coef_in_entropy_loss_fn(**kwargs): + st.number_input( + "Entropy Coeff", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py index 044f982e94..f704d0ecd2 100644 --- a/trinity/manager/config_registry/buffer_config_manager.py +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -1,6 +1,6 @@ import streamlit as st -from trinity.common.constants import AlgorithmType, PromptType, StorageType +from trinity.common.constants import PromptType, StorageType from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS from trinity.common.workflows.workflow import WORKFLOWS from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS @@ -264,7 +264,7 @@ def set_reply_prefix(**kwargs): ) def set_storage_type(**kwargs): key = kwargs.get("key") - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] == "dpo": st.session_state[key] = st.session_state["_dpo_storage_type"] storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] else: @@ -272,7 +272,7 @@ def set_storage_type(**kwargs): storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value] def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] == "dpo": st.session_state["_dpo_storage_type"] = st.session_state[key] else: st.session_state["_not_dpo_storage_type"] = st.session_state[key] @@ -294,7 +294,7 @@ def on_change(): ) def set_experience_buffer_path(**kwargs): # TODO key = kwargs.get("key") - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] == "dpo": if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]: st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"] st.session_state[key] = st.session_state["_dpo_experience_buffer_path"] @@ -314,7 +314,7 @@ def set_experience_buffer_path(**kwargs): # TODO if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`.""" def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] == "dpo": st.session_state["_dpo_experience_buffer_path"] = st.session_state[key] else: st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key] @@ -324,7 +324,7 @@ def on_change(): @CONFIG_GENERATORS.register_check() def check_experience_buffer_path(unfinished_fields: set, key: str): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] == "dpo": if not st.session_state[key].strip(): unfinished_fields.add(key) st.warning("Please input DPO dataset path.") diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py index 9393187f60..12e8034a30 100644 --- a/trinity/manager/config_registry/explorer_config_manager.py +++ b/trinity/manager/config_registry/explorer_config_manager.py @@ -1,6 +1,6 @@ import streamlit as st -from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.common.constants import SyncMethod from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num @@ -255,7 +255,7 @@ def check_auxiliary_models(unfinished_fields: set, key: str): ) def set_sync_method(**kwargs): key = kwargs.get("key") - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] == "dpo": st.session_state[key] = SyncMethod.CHECKPOINT.value disabled = True else: @@ -263,7 +263,7 @@ def set_sync_method(**kwargs): disabled = False def on_change(): - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": st.session_state["_not_dpo_sync_method"] = st.session_state[key] st.selectbox( diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py index 837bf27679..f9014e58a1 100644 --- a/trinity/manager/config_registry/model_config_manager.py +++ b/trinity/manager/config_registry/model_config_manager.py @@ -2,10 +2,9 @@ import streamlit as st -from trinity.common.constants import AlgorithmType, MonitorType +from trinity.common.constants import MonitorType from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.trainer_config_manager import use_critic -from trinity.trainer.verl.ray_trainer import AdvantageEstimator def set_total_gpu_num(): @@ -64,91 +63,6 @@ def set_monitor_type(**kwargs): ) -# Algorithm Configs - - -@CONFIG_GENERATORS.register_config( - default_value=AlgorithmType.PPO.value, - other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value}, -) -def set_algorithm_type(**kwargs): - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value - elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["mode"] = "train" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - else: # TODO: add more algorithms - pass - set_trainer_gpu_num() - - st.selectbox( - "Algorithm Type", - [ - AlgorithmType.PPO.value, - AlgorithmType.GRPO.value, - AlgorithmType.DPO.value, - AlgorithmType.OPMD.value, - ], - key="algorithm_type", - on_change=on_change, - ) - - -@CONFIG_GENERATORS.register_config( - default_value=1, - visible=lambda: st.session_state["mode"] == "both", - other_configs={ - "_grouped_adv_repeat_times": 2, - "_not_grouped_adv_repeat_times": 1, - }, -) -def set_repeat_times(**kwargs): # TODO - key = kwargs.get("key") - grouped_adv_algorithms = [ - AlgorithmType.GRPO.value, - AlgorithmType.OPMD.value, # TODO: may add rloo - ] - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - min_repeat_times = 2 - st.session_state[key] = st.session_state["_grouped_adv_repeat_times"] - else: - min_repeat_times = 1 - st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"] - - def on_change(): - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - st.session_state["_grouped_adv_repeat_times"] = st.session_state[key] - else: - st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key] - - st.number_input( - "Repeat Times", - min_value=min_repeat_times, - help="`repeat_times` is used to set how many experiences each task can generate, " - "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", - on_change=on_change, - **kwargs, - ) - - -@CONFIG_GENERATORS.register_config(default_value=1.0) -def set_gamma(**kwargs): - st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) - - -@CONFIG_GENERATORS.register_config(default_value=1.0) -def set_lam(**kwargs): - st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) - - # Model Configs diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py index d0f5d26897..9b3e5f3ea9 100644 --- a/trinity/manager/config_registry/trainer_config_manager.py +++ b/trinity/manager/config_registry/trainer_config_manager.py @@ -1,12 +1,13 @@ import streamlit as st -from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.common.constants import SyncMethod from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS -from trinity.trainer.verl.ray_trainer import AdvantageEstimator def use_critic(): - return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + return algorithm.use_critic @CONFIG_GENERATORS.register_config(default_value="verl") @@ -18,7 +19,7 @@ def set_trainer_type(**kwargs): def set_save_interval(**kwargs): key = kwargs.get("key") if ( - st.session_state["algorithm_type"] == AlgorithmType.DPO.value + st.session_state["algorithm_type"] == "dpo" or st.session_state["sync_method"] == SyncMethod.NCCL.value ): st.session_state[key] = st.session_state["_nccl_save_interval"] @@ -29,7 +30,7 @@ def set_save_interval(**kwargs): def on_change(): if ( - st.session_state["algorithm_type"] == AlgorithmType.DPO.value + st.session_state["algorithm_type"] == "dpo" or st.session_state["sync_method"] == SyncMethod.NCCL.value ): st.session_state["_nccl_save_interval"] = st.session_state[key] @@ -49,54 +50,6 @@ def set_enable_preview(**kwargs): st.checkbox("Enable Preview", **kwargs) -def _actor_use_kl_loss_visible(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["actor_use_kl_loss"] = True - return False - return True - - -@CONFIG_GENERATORS.register_config( - default_value=True, - visible=_actor_use_kl_loss_visible, - other_configs={"_not_dpo_actor_use_kl_loss": True}, -) -def set_actor_use_kl_loss(**kwargs): - key = kwargs.get("key") - st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"] - - def on_change(): - st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key] - - st.checkbox("Use KL Loss", on_change=on_change, **kwargs) - - -@CONFIG_GENERATORS.register_config( - default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] -) -def set_actor_kl_loss_coef(**kwargs): - st.number_input( - r"KL Loss Coef :blue-badge[$\beta$]", - min_value=0.0, - max_value=1.0, - format="%.1e", - **kwargs, - ) - - -@CONFIG_GENERATORS.register_config( - default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] -) -def set_actor_entropy_coef(**kwargs): - st.number_input( - "Entropy Coeff", - min_value=0.0, - max_value=1.0, - format="%.1e", - **kwargs, - ) - - @CONFIG_GENERATORS.register_config(default_value=1.0) def set_actor_grad_clip(**kwargs): st.number_input( @@ -108,16 +61,6 @@ def set_actor_grad_clip(**kwargs): ) -@CONFIG_GENERATORS.register_config(default_value=0.2) -def set_actor_clip_ratio(**kwargs): - st.number_input( - r"Clip Ratio :blue-badge[$\epsilon$]", - min_value=0.0, - max_value=1.0, - **kwargs, - ) - - # veRL Trainer Configs @@ -322,31 +265,6 @@ def set_actor_lr_warmup_steps_ratio(**kwargs): ) -@CONFIG_GENERATORS.register_config( - default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd" -) -def set_actor_tau(**kwargs): - st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs) - - -@CONFIG_GENERATORS.register_config( - default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd" -) -def set_actor_opmd_baseline(**kwargs): - st.selectbox( - "OPMD Baseline", - ["mean", "logavgexp"], - **kwargs, - ) - - -@CONFIG_GENERATORS.register_config( - default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd" -) -def set_actor_use_uid(**kwargs): - st.checkbox("Use UID for OPMD", **kwargs) - - @CONFIG_GENERATORS.register_config(default_value="low_var_kl") def set_actor_kl_loss_type(**kwargs): st.selectbox( diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index 3c148cbe12..baaf1242c3 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -47,7 +47,13 @@ def load_explorer(self) -> dict: try: with open(self.explorer_meta_path, "r", encoding="utf-8") as f: explorer_meta = json.load(f) - logger.info(f"Find existing explorer meta: {explorer_meta}") + logger.info( + "----------------------------------\n" + "Found existing explorer checkpoint:\n" + f" > {explorer_meta}\n" + "Continue exploring from this point.\n" + "----------------------------------" + ) return explorer_meta except Exception as e: logger.error(f"Failed to load explore meta file: {e}") @@ -62,7 +68,13 @@ def load_trainer(self) -> dict: try: with open(self.trainer_meta_path, "r", encoding="utf-8") as f: trainer_meta = json.load(f) - logger.info(f"Find existing trainer meta: {trainer_meta}") + logger.info( + "----------------------------------\n" + "Found existing trainer checkpoint:\n" + f" > {trainer_meta}\n" + "Continue training from this point.\n" + "----------------------------------" + ) return trainer_meta except Exception as e: logger.warning(f"Failed to load trainer meta file: {e}") diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 36d23e7628..216c916c69 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -1,119 +1,74 @@ # -*- coding: utf-8 -*- """ Trainer Class -This file is modified from verl.trainer.main_ppo.py -And is a reproduction code of Jiayi-Pan/TinyZero. - -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +from __future__ import annotations + import os from abc import ABC, abstractmethod -from typing import Tuple import ray -from trinity.buffer import get_buffer_reader from trinity.common.config import Config -from trinity.common.constants import AlgorithmType, SyncMethod -from trinity.common.experience import Experiences +from trinity.common.constants import ( + EXPLORER_NAME, + TRAINER_NAME, + RunningStatus, + SyncMethod, +) from trinity.utils.log import get_logger -@ray.remote(name="trainer") class Trainer: """Consume the experience and train the model.""" def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) - self.train_buffer = get_buffer_reader( - self.config.buffer.trainer_input.experience_buffer, # type: ignore - self.config.buffer, - ) - self.sft_warmup_buffer = ( - get_buffer_reader( - self.config.buffer.trainer_input.sft_warmup_dataset, # type: ignore - self.config.buffer, - ) - if self.config.buffer.trainer_input.sft_warmup_steps > 0 - else None - ) self.engine = get_trainer_wrapper(config) + self.explorer_ref = None def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() - def train(self, algo_type: AlgorithmType = AlgorithmType.PPO): + def train(self) -> str: """Train the model.""" while True: - train_status, _ = self.train_step(algo_type) - if not train_status: + try: + train_continue = self.train_step() + if not train_continue: + break + if self.need_sync(): + self.sync_weight() + except Exception as e: + self.logger.error(f"Error in Trainer: {e}") break + self.logger.info("--------------------\n> Trainer finished.\n--------------------") + return TRAINER_NAME - def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: - """Train for one period. Each period contains `sync_interval` steps. - Returns: - train_status: Whether to continue training. - train_step_num: The number of training steps""" - for _ in range(self.config.synchronizer.sync_interval): - train_status, train_step_num = self.train_step(algo_type) - if not train_status: - return False, train_step_num - self.logger.info(f"Train step {train_step_num} finished.") - return True, train_step_num - - def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: + def train_step(self) -> bool: """Train one step. - Args: - algo_type (AlgorithmType): The type of data to be used for training. - Defaults to AlgorithmType.PPO. - Returns: bool: Whether to continue training. """ - self.engine.set_mode(algo_type) - if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy: - strategy = self.config.buffer.trainer_input.read_experience_strategy - else: - strategy = None - try: - if algo_type.is_sft(): - exps = self.sft_warmup_buffer.read() - else: - exps = self.train_buffer.read(strategy=strategy) - except StopIteration: - self.logger.warning("No more data to train. Stop training.") - return False, 0 # TODO: get the actual step number - - if algo_type.is_sft(): - return self.engine.train_sft_step( - Experiences.gather_experiences( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - ) - elif algo_type.is_rft(): - return self.engine.train_rft_step( - Experiences.gather_experiences( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - ) - elif algo_type.is_dpo(): - return self.engine.train_dpo_step( - Experiences.gather_dpo_experiences( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - ) - else: - raise ValueError(f"Unsupported algorithm type: {algo_type}") + return self.engine.train_step() + + def need_sync(self) -> bool: + """Whether to sync the model weight.""" + return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 def sync_weight(self) -> None: """Sync the model weight.""" if self.config.synchronizer.sync_method == SyncMethod.NCCL: + if self.explorer_ref is None: + self.explorer_ref = ray.get_actor(EXPLORER_NAME) + explorer_status = ray.get(self.explorer_ref.running_status.remote()) + if explorer_status == RunningStatus.STOPPED: + self.logger.warning("Explorer has already stopped. Skipping sync weight.") + return + self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.") self.engine.sync_weight() def flush_log(self, step: int) -> None: @@ -122,7 +77,7 @@ def flush_log(self, step: int) -> None: def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint - step_num = self.engine.global_steps - 1 + step_num = self.engine.train_step_num path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}") if not os.path.isdir(path) or len(os.listdir(path)) == 0: self.engine.save_checkpoint() @@ -136,17 +91,14 @@ class TrainEngineWrapper(ABC): def prepare(self) -> None: """Do some preparation before training started.""" + @property @abstractmethod - def train_rft_step(self, experiences) -> Tuple[bool, int]: - """Train on the RFT data.""" + def train_step_num(self) -> int: + """Get the current training step number.""" @abstractmethod - def train_sft_step(self, experiences) -> Tuple[bool, int]: - """Train on the SFT data.""" - - @abstractmethod - def train_dpo_step(self, experiences) -> Tuple[bool, int]: - """Train on the DPO data.""" + def train_step(self) -> bool: + """Training.""" @abstractmethod def save_checkpoint(self) -> None: @@ -156,10 +108,6 @@ def save_checkpoint(self) -> None: def sync_weight(self) -> None: """Sync the model weight.""" - @abstractmethod - def set_mode(self, algo_type: AlgorithmType) -> None: - """Set training mode.""" - @abstractmethod def shutdown(self) -> None: """Shutdown the engine.""" diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py deleted file mode 100644 index 20cffc9962..0000000000 --- a/trinity/trainer/verl/core_algos.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Modified from core_algos.py -""" - -from abc import ABC, abstractmethod -from collections import defaultdict - -import numpy as np -import torch -import torch.nn.functional as F -import verl.utils.torch_functional as verl_F - -from trinity.common.constants import AlgorithmType - - -class KLController(ABC): - @abstractmethod - def update(self, current_kl, n_steps): - """update value""" - - -class AdaptiveKLController(KLController): - """ - Adaptive KL controller described in the paper: - https://arxiv.org/pdf/1909.08593.pdf - """ - - def __init__(self, init_kl_coef, target_kl, horizon): - self.value = init_kl_coef - self.target = target_kl - self.horizon = horizon - - def update(self, current_kl, n_steps): - target = self.target - proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult - - -class FixedKLController(KLController): - """Fixed KL controller.""" - - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current_kl, n_steps): - pass - - -def get_kl_controller(kl_config): - if kl_config.type == "fixed": - return FixedKLController(kl_coef=kl_config.kl_coef) - elif kl_config.type == "adaptive": - assert kl_config.horizon > 0, f"horizon must be larger than 0. Got {kl_config.horizon}" - return AdaptiveKLController( - init_kl_coef=kl_config.kl_coef, - target_kl=kl_config.target_kl, - horizon=kl_config.horizon, - ) - else: - raise ValueError("Unknown kl_ctrl type") - - -def compute_opmd_outcome_advantage( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - index: torch.Tensor, - opmd_baseline: str = "mean", - tau: float = 1.0, -): - """Modified from compute_grpo_outcome_advantage - - Compute advantage for OPMD, operating only on Outcome reward - (with only one scalar reward for each response). - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2baseline = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2baseline[idx] = torch.tensor(0.0) - # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) - elif len(id2score[idx]) > 1: - if opmd_baseline == "mean": - id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) - elif opmd_baseline == "logavgexp": - rewards_tensor = torch.tensor(id2score[idx]) - # NOTE: we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)). - # Hopefully the logsumexp calculation is numerically stable (as claimed by PyTorch's doc) - # in cases where tau is small... - id2baseline[idx] = tau * ( - torch.logsumexp(rewards_tensor / tau, dim=-1) - - torch.log(torch.tensor(len(id2score[idx]))) - ) - else: - raise NotImplementedError - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = scores[i] - id2baseline[index[i]] - scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return scores, scores - - -def compute_gae_advantage_return( - token_level_rewards: torch.Tensor, - values: torch.Tensor, - eos_mask: torch.Tensor, - gamma: torch.Tensor, - lam: torch.Tensor, -): - """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - values: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. - gamma: `(float)` - discounted factor used in RL - lam: `(float)` - lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - - """ - with torch.no_grad(): - lastgaelam = 0 - advantages_reversed = [] - gen_len = token_level_rewards.shape[-1] - - # values = values * eos_mask TODO: may use in multi-turn - for t in reversed(range(gen_len)): - nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 - delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] - - lastgaelam = delta + gamma * lam * lastgaelam - # lastgaelam = torch.where( # TODO: may use in multi-turn - # eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam - # ) - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) - - returns = advantages + values - advantages = verl_F.masked_whiten(advantages, eos_mask) - return advantages, returns - - -# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. -def compute_grpo_outcome_advantage( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - index: torch.Tensor, - epsilon: float = 1e-6, -): - """ - Compute advantage for GRPO, operating only on Outcome reward - (with only one scalar reward for each response). - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - id2std = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) - scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return scores, scores - - -def compute_rloo_outcome_advantage( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - index: torch.Tensor, - epsilon: float = 1e-6, -): - """ - Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - response_num = len(id2score[index[i]]) - if response_num > 1: - scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[ - index[i] - ] * response_num / (response_num - 1) - scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return scores, scores - - -def compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor -): - """ - Compute advantage for REINFORCE++. - This implementation is based on the paper: https://arxiv.org/abs/2501.03262 - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - - with torch.no_grad(): - returns = torch.zeros_like(token_level_rewards) - running_return = 0 - - for t in reversed(range(token_level_rewards.shape[1])): - running_return = token_level_rewards[:, t] + gamma * running_return - returns[:, t] = running_return - - advantages = verl_F.masked_whiten(returns, eos_mask) - advantages = advantages * eos_mask - - return advantages, returns - - -def compute_remax_outcome_advantage( - token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor -): - """ - Compute advantage for ReMax, operating only on Outcome reward - This implementation is based on the paper: https://arxiv.org/abs/2310.10505 - - (with only one scalar reward for each response). - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - reward_baselines: `(torch.Tensor)` - shape: (bs,) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - token_level_rewards.sum(dim=-1) - - with torch.no_grad(): - returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return advantages, returns - - -def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): - kl = old_log_prob - ref_log_prob - return token_level_scores - kl * kl_ratio - - -def compute_policy_loss(old_log_prob, log_prob, eos_mask, **kwargs): - """Compute policy loss for PPO / OPMD / pairwise OPMD""" - - algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO) - - if algorithm_type == AlgorithmType.OPMD: - advantages = kwargs.get("advantages") - tau = kwargs.get("tau") - return compute_policy_loss_opmd( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - eos_mask=eos_mask, - tau=tau, - ) - - elif algorithm_type == AlgorithmType.PAIRWISE_OPMD: - token_level_scores = kwargs.get("token_level_scores") - index = kwargs.get("index") - tau = kwargs.get("tau") - return compute_policy_loss_pairwise_opmd( - old_log_prob=old_log_prob, - log_prob=log_prob, - token_level_scores=token_level_scores, - eos_mask=eos_mask, - index=index, - tau=tau, - ) - - elif algorithm_type.is_rft(): - advantages = kwargs.get("advantages") - cliprange = kwargs.get("cliprange") - return compute_policy_loss_ppo( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - eos_mask=eos_mask, - cliprange=cliprange, - ) - - else: - raise NotImplementedError(f"Get invalid algorithm_type '{algorithm_type}'.") - - -def compute_policy_loss_dpo( - log_prob, ref_log_prob, eos_mask, loss_type="sigmoid", beta=0.1, label_smoothing=0.0 -): - """Compute policy loss for DPO (Direct Preference Optimization) - - Ref: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L918 - - Args: - log_prob: `(torch.Tensor)` - The log probabilities of the chosen responses from the policy model. - ref_log_prob: `(torch.Tensor)` - The log probabilities of the chosen responses from the reference model. - loss_type: `(str)` - Default: "sigmoid" - The type of loss function to use. - beta: `(float)` - Default: 0.1 - A temperature parameter that controls the sharpness of the preference signal. - Higher values make the loss more sensitive to small differences in log probabilities. - label_smoothing: `(float)` - Default: 0.0 - A parameter to encode uncertainty about the labels. Adds a small amount of smoothing to the loss - to avoid overconfident predictions. - - Returns: - dpo_loss: `a scalar torch.Tensor` - chosen_diff: `(torch.Tensor)` - rejected_diff: `(torch.Tensor)` - """ - # log_prob: chosen, rejected, chosen, rejected, ... - chosen_log_prob, rejected_log_prob = log_prob[::2], log_prob[1::2] - chosen_mask, rejected_mask = eos_mask[::2], eos_mask[1::2] - chosen_log_prob_sum = (chosen_log_prob * chosen_mask).sum(-1) - rejected_log_prob_sum = (rejected_log_prob * rejected_mask).sum(-1) - - if ref_log_prob is None: - raise NotImplementedError("DPO requires valid ref_log_prob") - chosen_ref_log_prob, rejected_ref_log_prob = ref_log_prob[::2], ref_log_prob[1::2] - chosen_ref_log_prob_sum = (chosen_ref_log_prob * chosen_mask).sum(-1) - rejected_ref_log_prob_sum = (rejected_ref_log_prob * rejected_mask).sum(-1) - - # compute logits - chosen_ratios = chosen_log_prob_sum - chosen_ref_log_prob_sum - rejected_ratios = rejected_log_prob_sum - rejected_ref_log_prob_sum - logits = chosen_ratios - rejected_ratios - - if loss_type == "sigmoid": - losses = ( - -F.logsigmoid(beta * logits) * (1 - label_smoothing) - - F.logsigmoid(-beta * logits) * label_smoothing - ) - loss = losses.mean() - - else: - raise NotImplementedError(f"loss_type {loss_type} is not supported in DPO") - - chosen_reward = beta * chosen_ratios.detach() - rejected_reward = beta * rejected_ratios.detach() - return loss, chosen_reward, rejected_reward - - -def compute_policy_loss_pairwise_opmd( - old_log_prob, log_prob, token_level_scores, eos_mask, index, tau -): - """Compute policy loss for pairwise_opmd - - NOTE: NOT TESTED YET - - TODO: allow using old_log_prob; for now we just discard it. - - NOTE: use token_level_scores rather than token_level_rewards, because we're not sure yet - whether this algorithm is compatible with kl penalty as negative reward - - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - token_level_scores: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - index: `(torch.Tensor)` or None (when use_uid is False) - tau: `float` - - Returns: - opmd_loss: `a scalar torch.Tensor` - pairwise_opmd loss - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - ppo_kl: (float) ... (TODO, confirm that this is only used for logging stats) - - """ - - # dummy computation - log_prob_diff = log_prob - log_prob - pg_clipfrac = verl_F.masked_mean(torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask) - ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask) - - # loss for pairwise_opmd - scores = token_level_scores.sum(dim=-1) - action_level_log_prob = (log_prob * eos_mask).sum(dim=-1) - diffs = scores - tau * (action_level_log_prob - action_level_log_prob.detach()) - - if index is None: - normalizer = eos_mask.sum() * max(1.0, tau) - opmd_loss = (diffs - diffs.mean()).square().sum() / normalizer - else: - opmd_loss = None - unique_index = list(set(index.tolist())) - for idx in unique_index: - subdiff = diffs[index == idx] - if subdiff.shape[0] == 1: - continue - # subloss = len(subdiff) * subdiff.square().sum() - subdiff.sum().square() - subloss = (subdiff - subdiff.mean()).square().sum() - if opmd_loss is None: - opmd_loss = subloss - else: - opmd_loss = opmd_loss + subloss - normalizer = eos_mask.sum() * max(1.0, tau) - opmd_loss = opmd_loss / normalizer - - # NOTE: return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss - return opmd_loss, pg_clipfrac, ppo_kl - - -def compute_policy_loss_opmd(old_log_prob, log_prob, advantages, eos_mask, tau): - """The OPMD counterpart of verl's original compute_policy_loss (now renamed as compute_policy_loss_ppo) - - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - tau: `float` - - Returns: - opmd_loss: `a scalar torch.Tensor` - opmd loss - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - ppo_kl: (float) ... (TODO, confirm that this is only used for logging stats) - - """ - log_prob_diff = log_prob - old_log_prob - pg_clipfrac = verl_F.masked_mean( - torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask - ) # meaningless - ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask) - - # --- version 0: kimi-opmd --- - - # # the original quadratic loss in OPMD can be reformulated as follows - # pg_losses = -advantages * log_prob - # pg_loss = verl_F.masked_sum(pg_losses, eos_mask) - - # reg_losses = (log_prob_diff * eos_mask).sum(dim=-1).square() - # reg_loss = reg_losses.sum() - - # opmd_loss = (pg_loss + 0.5 * tau * reg_loss) / eos_mask.sum() - # # NOTE: this implementation uses batch-wise normalization; - # # would it be beneficial to use trajectory-wise or group-wise normalization? - - # opmd_loss = opmd_loss / max(1.0, tau) # for stability when tau is large - - # --- version 1: min-opmd (minimalistic, but theoretically grounded) --- - - pg_losses = -advantages * log_prob - opmd_loss = verl_F.masked_mean(pg_losses, eos_mask) - opmd_loss = opmd_loss / (1.0 + tau) # for regularization (w.r.t. current pi_theta) - - # NOTE: return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss - return opmd_loss, pg_clipfrac, ppo_kl - - -def compute_policy_loss_ppo(old_log_prob, log_prob, advantages, eos_mask, cliprange): - """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 - - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - cliprange: (float) - The clip range used in PPO. See https://arxiv.org/abs/1707.06347 - - Returns: - pg_loss: `a scalar torch.Tensor` - policy gradient loss computed via PPO - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - - """ - negative_approx_kl = log_prob - old_log_prob - ratio = torch.exp(negative_approx_kl) - ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask) - - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) - - pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask) - pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask) - return pg_loss, pg_clipfrac, ppo_kl - - -def compute_policy_loss_sft(log_prob, eos_mask): - """Simple way to compute SFT loss, unified with PG loss - - Args: - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - sft_loss: `a scalar torch.Tensor` - pg_clipfrac: dummy value, merely for compatibility - ppo_kl: dummy value, merely for compatibility - - """ - log_prob_diff = log_prob - log_prob.detach() - pg_clipfrac = verl_F.masked_mean(torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask) - ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask) - - sft_loss = verl_F.masked_mean(-log_prob, eos_mask) - - # Return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss - return sft_loss, pg_clipfrac, ppo_kl - - -def compute_entropy_loss(logits, eos_mask): - """Compute Categorical entropy loss - - Args: - logits: `(torch.Tensor)` - shape: (bs, response_length, vocab_size) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - entropy: a scalar torch.Tensor - - """ - # compute entropy - entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) - entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask) - return entropy_loss - - -def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value): - """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 - - Args: - vpreds (`torch.FloatTensor`): - Predicted values of the value head, shape (`batch_size`, `response_length`) - values (`torch.FloatTensor`): - Old values of value head, shape (`batch_size`, `response_length`) - returns: (`torch.FloatTensor`): - Ground truth returns, shape (`batch_size`, `response_length`) - - Returns: - vf_loss: a scalar (`torch.FloatTensor`): - value function loss - vf_clipfrac: a float - The ratio of vf being clipped - - """ - vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) - vf_losses1 = (vpreds - returns) ** 2 - vf_losses2 = (vpredclipped - returns) ** 2 - vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask) - vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask) - return vf_loss, vf_clipfrac - - -def kl_penalty( - logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty -) -> torch.FloatTensor: - """Compute KL divergence given logprob and ref_logprob. - Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 - - Args: - logprob: - ref_logprob: - - Returns: - - """ - if kl_penalty == "kl": - return logprob - ref_logprob - - if kl_penalty == "abs": - return (logprob - ref_logprob).abs() - - if kl_penalty == "mse": - return 0.5 * (logprob - ref_logprob).square() - - # J. Schulman. Approximating kl divergence, 2020. - # # URL http://joschu.net/blog/kl-approx.html. - if kl_penalty == "low_var_kl": - kl = ref_logprob - logprob - ratio = torch.exp(kl) - kld = (ratio - kl - 1).contiguous() - return torch.clamp(kld, min=-10, max=10) - - if kl_penalty == "full": - # so, here logprob and ref_logprob should contain the logits for every token in vocabulary - raise NotImplementedError - - raise NotImplementedError diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 246cd1f21c..e7eb34ea17 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -1,4 +1,6 @@ # Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,232 +14,57 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Modified from dp_actor.py +Single Process Actor. +Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/actor/dp_actor.py """ import itertools -from typing import Tuple +import logging +import os import torch -import verl.utils.torch_functional as verl_F -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl import DataProto +from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_torch_device from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.utils.torch_functional import logprobs_from_logits, masked_mean -from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs -from verl.workers.actor import BasePPOActor +from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor -from trinity.common.constants import AlgorithmType -from trinity.trainer.verl import core_algos +from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn +from trinity.algorithm.kl_fn.kl_fn import DummyKLFn +from trinity.algorithm.utils import prefix_metrics +from trinity.common.config import AlgorithmConfig __all__ = ["DataParallelPPOActor"] +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -class DataParallelPPOActor(BasePPOActor): + +class DataParallelPPOActor(DPActor): def __init__( - self, - config, - actor_module: nn.Module, - actor_optimizer: torch.optim.Optimizer = None, + self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None ): """When optimizer is None, it is Reference Policy""" - super().__init__(config) - self.actor_module = actor_module - self.actor_optimizer = actor_optimizer - self.use_remove_padding = self.config.get("use_remove_padding", False) - print(f"Actor use_remove_padding={self.use_remove_padding}") - self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size - self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 - - self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) - self.algorithm_type = AlgorithmType.PPO - - def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO): - self.algorithm_type = algorithm_type - - def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Returns: - entropy: # (bs, response_len) - log_probs: # (bs, response_len) - """ - response_length = micro_batch["responses"].size(-1) - multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch: - for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 - ) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis( - rearrange(position_ids, "c b s ... -> (b s) c ..."), indices - ) - .transpose(0, 1) - .unsqueeze(1) - ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll( - input_ids_rmpad, shifts=-1, dims=1 - ) # (1, total_nnz) - - # pad and slice the inputs if sp > 1 - if self.use_ulysses_sp: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size - ) - - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( - 0 - ) # ((total_nnz / sp) + pad) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.actor_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False, - ) # prevent model thinks we are generating - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - if self.algorithm_type.is_sft(): # SFT - loss_fct = nn.CrossEntropyLoss(reduction="none") - loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) - if self.use_ulysses_sp: - loss = gather_outpus_and_unpad( - loss, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - response_mask = attention_mask[:, -response_length:].bool() - # pad back to (bsz, seqlen) - full_loss = pad_input( - hidden_states=loss.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ).squeeze(-1) - full_loss = torch.where( - response_mask, full_loss[:, -response_length - 1 : -1], 0.0 - ) - full_loss = full_loss.sum(-1) / response_mask.sum(-1) - full_loss = full_loss.mean() - return full_loss - - logits_rmpad.div_(temperature) - - # compute entropy - entropy_rmpad = self.compute_entropy_from_logits( - logits_rmpad - ) # ((total_nnz / sp) + pad) - - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) - - # gather log_prob if sp > 1 - if self.use_ulysses_sp: - # gather and unpad for the ulysses sp - log_probs = gather_outpus_and_unpad( - log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - entropy_rmpad = gather_outpus_and_unpad( - entropy_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - # pad back to (bsz, seqlen) - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - - # only return response part: - entropy = full_entropy.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - - else: # not using rmpad and no ulysses sp - output = self.actor_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - ) # prevent model thinks we are generating - logits = output.logits - if self.algorithm_type.is_sft(): - loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100) - response_mask = attention_mask[:, -response_length:].bool() - response_labels = torch.where( - response_mask, input_ids[:, -response_length:], -100 - ) - response_logits = logits[:, -response_length - 1 : -1, :] - loss = loss_fct( - response_logits.reshape(-1, response_logits.shape[-1]), - response_labels.reshape(-1), - ) - loss = loss.view(response_labels.shape) - loss = loss.sum(-1) / response_mask.sum(-1) - loss = loss.mean() - return loss - logits.div_(temperature) - logits = logits[ - :, -response_length - 1 : -1, : - ] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch["responses"]) - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - - return entropy, log_probs - - def _optimizer_step(self): - assert self.config.grad_clip is not None - - if isinstance(self.actor_module, FSDP): - grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - self.actor_module.parameters(), max_norm=self.config.grad_clip - ) - self.actor_optimizer.step() - return grad_norm - - def compute_log_prob(self, data: DataProto) -> torch.Tensor: + super().__init__(config, actor_module, actor_optimizer) + + self.policy_loss_fn = None + self.kl_loss_fn = None + self.entropy_loss_fn = None + + def set_algorithm(self, algorithm_config: AlgorithmConfig): + self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( + backend="verl", **algorithm_config.policy_loss_fn_args + ) + self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args) + self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( + **algorithm_config.entropy_loss_fn_args + ) + + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: """Compute the log probability of the responses given input_ids, attention_mask and position_ids Args: @@ -261,7 +88,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: micro_batch_size = data.meta_info["micro_batch_size"] temperature = data.meta_info[ "temperature" - ] # temperature must be in the data.meta_info to avoid slient error + ] # temperature must be in the data.meta_info to avoid silent error use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] @@ -284,98 +111,69 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: micro_batches = batch.split(micro_batch_size) log_probs_lst = [] + entropy_lst = [] for micro_batch in micro_batches: if isinstance(micro_batch, DataProto): micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) + entropy, log_probs = self._forward_micro_batch( + micro_batch, temperature=temperature, calculate_entropy=calculate_entropy + ) log_probs_lst.append(log_probs) - log_probs = torch.concat(log_probs_lst, dim=0) + if calculate_entropy: + entropy_lst.append(entropy) + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) log_probs = log_probs[revert_indices] + if calculate_entropy: + entropys = entropys[revert_indices] # type: ignore - return log_probs + return log_probs, entropys - def update_policy(self, data: DataProto): # noqa: C901 + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() temperature = data.meta_info[ "temperature" - ] # temperature must be in the data.meta_info to avoid slient error - - algorithm_type: AlgorithmType = self.config.get("algorithm_type", AlgorithmType.PPO) - if self.algorithm_type.is_rft(): - select_keys = [ - "responses", - "input_ids", - "attention_mask", - "position_ids", - "old_log_probs", - "advantages", - "response_mask", - ] - if self.config.use_kl_loss: - select_keys.append("ref_log_prob") - - if algorithm_type == AlgorithmType.PAIRWISE_OPMD: - select_keys.append("token_level_scores") - elif self.algorithm_type.is_dpo(): - select_keys = [ - "attention_mask", - "input_ids", - "position_ids", - "response_mask", - "responses", - "ref_log_prob", - ] - else: # sft - select_keys = [ - "attention_mask", - "input_ids", - "position_ids", - "response_mask", - "responses", - ] - use_uid = self.config.get("use_uid", False) - + ] # temperature must be in the data.meta_info to avoid silent error + select_keys = [ + "input_ids", + "position_ids", + "attention_mask", + "responses", + "response_mask", + ] + select_keys.extend(self.policy_loss_fn.select_keys) + if not isinstance(self.kl_loss_fn, DummyKLFn): + select_keys.append("ref_log_prob") + select_keys = list(set(select_keys)) batch = data.select(batch_keys=select_keys).batch has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs or ((algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid): - # TODO: for now, we treat algorithm_type == AlgorithmType.PAIRWISE_OPMD in the same way that - # has_multi_modal_inputs was treated originally (to handle non_tensor_select_keys); - # need to double check if this is the best approach. + if has_multi_modal_inputs: num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = [] - if has_multi_modal_inputs: - non_tensor_select_keys.append("multi_modal_inputs") - if (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid: - non_tensor_select_keys.append("uid") + non_tensor_select_keys = ["multi_modal_inputs"] dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) else: dataloader = batch.split(self.config.ppo_mini_batch_size) - # TODO: for pairwise_opmd and use_uid, is it necessary to somehow sort samples within batch by uid, - # to ensure that there are samples with the same uid within each micro-batch - # (at which level pairwise loss is computed)? - # (In comparison, advantage is computed at the level of batch, same for opmd, grpo, etc.) - metrics = {} for epoch in range(self.config.ppo_epochs): for batch_idx, data in enumerate(dataloader): # split batch into micro_batches mini_batch = data - if has_multi_modal_inputs or ( - (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid - ): + if has_multi_modal_inputs: self.gradient_accumulation = ( self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu ) @@ -402,103 +200,64 @@ def update_policy(self, data: DataProto): # noqa: C901 self.actor_optimizer.zero_grad() for data in micro_batches: + micro_batch_metrics = {} + # Support all hardwares if isinstance(data, DataProto): data = { - **data.batch.to(torch.cuda.current_device()), + **data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch, } else: data = data.to( - torch.cuda.current_device() + get_torch_device().current_device() ) # actor device is cpu when using offload + responses = data["responses"] + response_length = responses.size(1) + attention_mask = data["attention_mask"] + response_mask = data["response_mask"] + assert response_mask.shape == attention_mask[:, -response_length:].shape + + # all return: (bsz, response_length) + calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn + entropy, log_prob = self._forward_micro_batch( + micro_batch=data, + temperature=temperature, + calculate_entropy=calculate_entropy, + ) - # TODO: it is better to unify the returns of several modes (sft, dpo) - if self.algorithm_type.is_sft(): - policy_loss = self._forward_micro_batch( - micro_batch=data, temperature=temperature - ) - - elif self.algorithm_type.is_dpo(): - response_mask = data["response_mask"] - - _, log_prob = self._forward_micro_batch( - micro_batch=data, temperature=temperature - ) - if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] - else: - ref_log_prob = None - - ( - policy_loss, - chosen_reward, - rejected_reward, - ) = core_algos.compute_policy_loss_dpo( - log_prob=log_prob, - ref_log_prob=ref_log_prob, - eos_mask=response_mask, - beta=self.config.kl_loss_coef, - # label_smoothing=self.config.label_smoothing # TODO: add configs for dpo - ) - - else: # rft - responses = data["responses"] - response_length = responses.size(1) - attention_mask = data["attention_mask"] - # response_mask = attention_mask[:, -response_length:] - response_mask = data["response_mask"] - assert response_mask.shape == attention_mask[:, -response_length:].shape - old_log_prob = data["old_log_probs"] - advantages = data["advantages"] - - clip_ratio = self.config.clip_ratio - entropy_coeff = self.config.entropy_coeff - - tau = self.config.get("tau", 1.0) - token_level_scores = None - index = None - if algorithm_type == AlgorithmType.PAIRWISE_OPMD: - token_level_scores = data["token_level_scores"] - if use_uid: - index = data["uid"] - - # all return: (bsz, response_length) - entropy, log_prob = self._forward_micro_batch( - micro_batch=data, temperature=temperature - ) - - pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - eos_mask=response_mask, - algorithm_type=algorithm_type, - advantages=advantages, - cliprange=clip_ratio, - # for opmd / pairwise_opmd - tau=tau, - token_level_scores=token_level_scores, - index=index, - ) - # compute entropy loss from entropy - entropy_loss = verl_F.masked_mean(entropy, response_mask) - - # compute policy loss - policy_loss = pg_loss - entropy_loss * entropy_coeff - - if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] - # compute kl loss - kld = core_algos.kl_penalty( - logprob=log_prob, - ref_logprob=ref_log_prob, - kl_penalty=self.config.kl_loss_type, - ) - kl_loss = masked_mean(kld, response_mask) - - policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - metrics["actor/kl_loss"] = kl_loss.detach().item() - metrics["actor/kl_coef"] = self.config.kl_loss_coef + pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore + logprob=log_prob, **data + ) + prefix_metrics( + src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics + ) + + # compute entropy loss from entropy + entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore + entropy=entropy, + action_mask=response_mask, + ) + prefix_metrics( + src_metrics=entropy_loss_metrics, + prefix="actor", + dst_metrics=micro_batch_metrics, + ) + + # compute policy loss + policy_loss = pg_loss - entropy_loss + + kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss( + logprob=log_prob, + ref_logprob=data.get("ref_log_prob", None), + response_mask=response_mask, + ) + prefix_metrics( + src_metrics=kl_loss_metrics, + prefix="actor", + dst_metrics=micro_batch_metrics, + ) + policy_loss = policy_loss + kl_loss if self.config.use_dynamic_bsz: # relative to the dynamic bsz @@ -507,32 +266,10 @@ def update_policy(self, data: DataProto): # noqa: C901 loss = policy_loss / self.gradient_accumulation loss.backward() - if self.algorithm_type.is_rft(): - data = { - "actor/entropy_loss": entropy_loss.detach().item(), - "actor/pg_loss": pg_loss.detach().item(), - "actor/pg_clipfrac": pg_clipfrac.detach().item(), - "actor/ppo_kl": ppo_kl.detach().item(), - } - elif self.algorithm_type.is_dpo(): - data = { - "dpo/loss": policy_loss.detach().item(), - "dpo/loss_mean": loss.detach().item(), - "dpo/chosen_reward": chosen_reward.detach().mean().item(), - "dpo/rejected_reward": rejected_reward.detach().mean().item(), - "dpo/accuracy_mean": (chosen_reward > rejected_reward) - .float() - .mean() - .item(), - } - else: - data = { - "sft/loss": loss.detach().item(), - } - append_to_dict(metrics, data) + append_to_dict(metrics, micro_batch_metrics) grad_norm = self._optimizer_step() data = {"actor/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + append_to_dict(metrics, data) self.actor_optimizer.zero_grad() return metrics diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 26b640e871..cbc88902a0 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -12,73 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -The main entry point to run the PPO algorithm +The main entry point to run the PPO algorithm. +Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/fsdp_workers.py """ +import json import logging import os import warnings +from dataclasses import asdict import psutil import torch import torch.distributed -import verl.utils.torch_functional as verl_F +import torch.distributed as dist +import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. from codetiming import Timer from omegaconf import DictConfig, open_dict +from peft import LoraConfig, TaskType, get_peft_model +from safetensors.torch import save_file from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import FlatParameter from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_torch_device, is_cuda_available from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + fsdp2_load_full_state_dict, + fsdp_version, get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, + layered_summon_lora_params, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, ) from verl.utils.import_utils import import_external_libs -from verl.utils.model import compute_position_id_with_mask +from verl.utils.py_functional import convert_to_regular_types +from verl.workers.fsdp_workers import ( + create_device_mesh, + device_name, + get_sharding_strategy, +) from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager -from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.common.config import AlgorithmConfig +from trinity.common.constants import ( + EXPLORER_NAME, + ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + SyncMethod, +) from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) - - -def create_device_mesh(world_size, fsdp_size): - if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - else: - device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] - ) - return device_mesh - - -def get_sharding_strategy(device_mesh): - from torch.distributed.fsdp import ShardingStrategy - - if device_mesh.ndim == 1: - sharding_strategy = ShardingStrategy.FULL_SHARD - elif device_mesh.ndim == 2: - sharding_strategy = ShardingStrategy.HYBRID_SHARD - else: - raise NotImplementedError( - f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2" - ) - return sharding_strategy +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) class ActorRolloutRefWorker(Worker): @@ -93,7 +94,13 @@ def __init__(self, config: DictConfig, role: str): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl", + rank=rank, + world_size=world_size, + ) # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -110,18 +117,19 @@ def __init__(self, config: DictConfig, role: str): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self._lora_rank = self.config.model.get("lora_rank", 0) + self._is_lora = self._lora_rank > 0 self.role = role assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] - self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] self._is_offload_param = False @@ -152,6 +160,8 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_micro_batch_size_per_gpu = ( self.config.actor.ppo_micro_batch_size ) + + if self.config.actor.ppo_micro_batch_size_per_gpu is not None: assert ( self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu @@ -163,14 +173,6 @@ def __init__(self, config: DictConfig, role: str): > 0 ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" - # normalize rollout config - if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: - self.config.rollout.log_prob_micro_batch_size //= ( - self.device_mesh.size() // self.ulysses_sequence_parallel_size - ) - self.config.rollout.log_prob_micro_batch_size_per_gpu = ( - self.config.rollout.log_prob_micro_batch_size - ) # normalize ref config if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: self.config.ref.log_prob_micro_batch_size //= ( @@ -180,22 +182,22 @@ def __init__(self, config: DictConfig, role: str): self.config.ref.log_prob_micro_batch_size ) - def _build_model_optimizer( + def _build_model_optimizer( # noqa: C901 self, model_path, fsdp_config, optim_config, override_model_config, use_remove_padding=False, + use_fused_kernels=False, enable_gradient_checkpointing=False, trust_remote_code=False, use_liger=False, role="actor", + enable_activation_offload=False, ): from torch import optim - from torch.distributed.fsdp import CPUOffload - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import MixedPrecision + from torch.distributed.fsdp import CPUOffload, MixedPrecision from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -210,8 +212,8 @@ def _build_model_optimizer( assert role in ["actor", "ref"] - log_gpu_memory_usage("Before init from HF AutoModel", logger=logger) - local_path = copy_to_local(model_path) + log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) + local_path = model_path # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly @@ -226,9 +228,13 @@ def _build_model_optimizer( # override model kwargs actor_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code + local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" ) + # patch for kimi-vl + if getattr(actor_model_config, "model_type", None) == "kimi_vl": + actor_model_config.text_config.topk_method = "greedy" + self.generation_config = get_generation_config( local_path, trust_remote_code=trust_remote_code ) @@ -259,17 +265,9 @@ def _build_model_optimizer( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, - attn_implementation="flash_attention_2", trust_remote_code=trust_remote_code, ) - if use_remove_padding or self.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch( - model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size - ) - # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import ( @@ -278,6 +276,13 @@ def _build_model_optimizer( _apply_liger_kernel_to_instance(model=actor_module) + apply_monkey_patch( + model=actor_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_fused_kernels=use_fused_kernels, + ) + # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 actor_module.to(torch_dtype) @@ -285,12 +290,24 @@ def _build_model_optimizer( actor_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) + if self._is_lora: + print("Applying LoRA to actor module") + actor_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) torch.distributed.barrier() if self.rank == 0: print_model_size(actor_module) - log_gpu_memory_usage("After init from HF AutoModel", logger=logger) + log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) # We wrap FSDP for rollout as well mixed_precision_config = fsdp_config.get("mixed_precision", None) @@ -312,14 +329,13 @@ def _build_model_optimizer( ) auto_wrap_policy = get_fsdp_wrap_policy( - module=actor_module, config=fsdp_config.get("wrap_policy", None) + module=actor_module, + config=fsdp_config.get("wrap_policy", None), + is_lora=self.config.model.get("lora_rank", 0) > 0, ) - if self._is_rollout and self.config.rollout.name == "hf": - # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma - auto_wrap_policy = None - - print(f"wrap_policy: {auto_wrap_policy}") + if self.rank == 0: + print(f"wrap_policy: {auto_wrap_policy}") fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) @@ -328,139 +344,106 @@ def _build_model_optimizer( # We force reference policy to use CPUOffload to save memory. # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) - actor_module_fsdp = FSDP( - actor_module, - cpu_offload=cpu_offload, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, # zero3 - mixed_precision=mixed_precision, - sync_module_states=True, - device_mesh=self.device_mesh, - forward_prefetch=False, - ) + fsdp_strategy = self.config.actor.strategy + if fsdp_strategy == "fsdp": + actor_module_fsdp = FSDP( + actor_module, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_torch_device().current_device(), + sharding_strategy=sharding_strategy, # zero3 + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + forward_prefetch=False, + ) + elif fsdp_strategy == "fsdp2": + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + if role == "actor" and fsdp_config.offload_policy: + cpu_offload = CPUOffloadPolicy(pin_memory=True) + self._is_offload_param = False + self._is_offload_optimizer = False + else: + cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": fsdp_config.reshard_after_forward, + } + full_state = actor_module.state_dict() + apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) + actor_module_fsdp = actor_module + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") + + if enable_activation_offload: + enable_activation_offloading( + actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing + ) - log_gpu_memory_usage("After Actor FSDP init", logger=logger) + log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) # TODO: add more optimizer args into config if role == "actor" and optim_config is not None: - beta1 = optim_config.get("beta1", 0.9) - beta2 = optim_config.get("beta2", 0.999) + from verl.utils.torch_functional import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + ) + actor_optimizer = optim.AdamW( actor_module_fsdp.parameters(), lr=optim_config.lr, - betas=(beta1, beta2), + betas=optim_config.get("betas", (0.9, 0.999)), weight_decay=optim_config.get("weight_decay", 1e-2), ) total_steps = optim_config.get("total_training_steps", 0) num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) + warmup_style = optim_config.get("warmup_style", "constant") + min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) + num_cycles = optim_config.get("num_cycles", 0.5) if num_warmup_steps < 0: num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - - if optim_config.warmup_style == "constant": - from verl.utils.torch_functional import ( - get_constant_schedule_with_warmup, - ) + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + if warmup_style == "constant": actor_lr_scheduler = get_constant_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps ) - elif optim_config.warmup_style == "cosine": - from verl.utils.torch_functional import get_cosine_schedule_with_warmup - - assert ( - total_steps > 0 - ), "Cosine scheduler of actor requires total_training_steps > 0" + elif warmup_style == "cosine": actor_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, - min_lr_ratio=optim_config.min_lr_ratio, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, ) else: - raise NotImplementedError( - f"Lr scheduler style {optim_config.warmup_style} is not supported" - ) + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") + + log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) else: actor_optimizer = None actor_lr_scheduler = None - log_gpu_memory_usage("After actor optimizer init", logger=logger) - return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config - def _build_rollout(self): - from torch.distributed.device_mesh import init_device_mesh - - # TODO(sgm): support FSDP hybrid shard for larger model - infer_tp = self.config.rollout.tensor_model_parallel_size - dp = self.world_size // infer_tp - assert ( - self.world_size % infer_tp == 0 - ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh( - "cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] - ) - - if self.config.rollout.name == "hf": - from verl.workers.rollout import HFRollout - from verl.workers.sharding_manager import BaseShardingManager - - rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) - rollout_sharding_manager = BaseShardingManager() - # TODO: a sharding manager that do nothing? - elif self.config.rollout.name == "vllm": - if self.config.rollout.use_fire_sampling: - from verl.workers.rollout.vllm_rollout import ( - FIREvLLMRollout as vLLMRollout, - ) - from verl.workers.rollout.vllm_rollout import vllm_mode - else: - from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode - from verl.workers.sharding_manager import FSDPVLLMShardingManager - - log_gpu_memory_usage("Before building vllm rollout", logger=None) - local_path = copy_to_local(self.config.model.path) - if vllm_mode == "customized": - rollout = vLLMRollout( - actor_module=self.actor_module_fsdp, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - ) - elif vllm_mode == "spmd": - rollout = vLLMRollout( - model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - ) - else: - raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") - log_gpu_memory_usage("After building vllm rollout", logger=None) - if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = "dummy_hf" - rollout_sharding_manager = FSDPVLLMShardingManager( - module=self.actor_module_fsdp, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - full_params="hf" in self.config.rollout.load_format, - device_mesh=rollout_device_mesh, - ) - log_gpu_memory_usage("After building sharding manager", logger=None) - - return rollout, rollout_sharding_manager - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from .dp_actor import DataParallelPPOActor + from trinity.trainer.verl.dp_actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) @@ -472,36 +455,43 @@ def init_model(self): ) use_remove_padding = self.config.model.get("use_remove_padding", False) + use_shm = self.config.model.get("use_shm", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) - if self._is_actor or self._is_rollout: + if self._is_actor: # we need the model for actor and rollout - if self._is_actor: - optim_config = self.config.actor.optim - fsdp_config = self.config.actor.fsdp_config - else: - optim_config = None - fsdp_config = OmegaConf.create() + optim_config = self.config.actor.optim + fsdp_config = self.config.actor.fsdp_config + + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) ( self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config, ) = self._build_model_optimizer( - model_path=self.config.model.path, + model_path=local_path, fsdp_config=fsdp_config, optim_config=optim_config, override_model_config=override_model_config, use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, enable_gradient_checkpointing=self.config.model.get( "enable_gradient_checkpointing", False ), trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="actor", + enable_activation_offload=self.config.model.get("enable_activation_offload", False), ) # get the original unwrapped module - self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + if fsdp_version(self.actor_module_fsdp) == 1: + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) @@ -511,22 +501,22 @@ def init_model(self): OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding + self.config.actor.use_fused_kernels = use_fused_kernels self.actor = DataParallelPPOActor( config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer, ) - if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout() - if self._is_ref: + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) self.ref_module_fsdp = self._build_model_optimizer( - model_path=self.config.model.path, + model_path=local_path, fsdp_config=self.config.ref.fsdp_config, optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", @@ -534,6 +524,7 @@ def init_model(self): OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels self.ref_policy = DataParallelPPOActor( config=self.config.ref, actor_module=self.ref_module_fsdp ) @@ -554,8 +545,6 @@ def init_model(self): checkpoint_contents=self.config.actor.checkpoint.contents, ) - torch.cuda.empty_cache() - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def setup_weight_sync_group(self): if ( @@ -576,7 +565,9 @@ def setup_weight_sync_group(self): realname = ( name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name ) - self.state_dict_meta.append((realname, param.dtype, param.shape)) + self.state_dict_meta.append( + (realname, str(param.dtype), tuple(param.shape)) + ) param = None torch.cuda.empty_cache() @@ -586,8 +577,7 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") - explorer = ray.get_actor("explorer") - group_name = "rollout_weight_sync" + explorer = ray.get_actor(EXPLORER_NAME) setup_ref = explorer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) @@ -604,7 +594,7 @@ def setup_weight_sync_group(self): timeout=timeout, world_size=world_size, rank=0, - group_name=group_name, + group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, ) ray.get(setup_ref) @@ -620,27 +610,27 @@ def sync_weight(self): continue torch.distributed.broadcast(param, 0, group=self._model_update_group) param = None + torch.distributed.barrier() + torch.cuda.synchronize() torch.cuda.empty_cache() @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def set_mode(self, algo_type: AlgorithmType = AlgorithmType.PPO): - self.actor.set_mode(algo_type) + def set_algorithm(self, algo_config: AlgorithmConfig): + self.actor.set_algorithm(algo_config) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: load_fsdp_optimizer( - optimizer=self.actor_optimizer, device_id=torch.cuda.current_device() + optimizer=self.actor_optimizer, device_id=get_torch_device().current_device() ) - log_gpu_memory_usage("Before update policy", logger=logger) - with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) # perform training @@ -654,17 +644,17 @@ def update_actor(self, data: DataProto): metrics["perf/mfu/actor"] = ( estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size ) - metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / ( + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / ( + 1024**3 + ) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / ( 1024**3 ) - metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) - self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr - - log_gpu_memory_usage("After update policy", logger=logger) + self.actor_lr_scheduler.step() # TODO: here, we should return all metrics output = DataProto(meta_info={"metrics": metrics}) @@ -674,60 +664,27 @@ def update_actor(self, data: DataProto): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) return output - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(self, prompts: DataProto): - # Support all hardwares - prompts = prompts.to(torch.cuda.current_device()) - - assert self._is_rollout - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - - meta_info = { - "eos_token_id": self.generation_config.eos_token_id - if self.generation_config is not None - else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id - if self.generation_config is not None - else self.tokenizer.pad_token_id, - } - prompts.meta_info.update(meta_info) - with self.rollout_sharding_manager: - # after parameters sync with rollout, offload actor model to CPU - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - - log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) - - prompts = self.rollout_sharding_manager.preprocess_data(prompts) - output = self.rollout.generate_sequences(prompts=prompts) - - log_gpu_memory_usage("After rollout generation", logger=logger) - - output = self.rollout_sharding_manager.postprocess_data(output) - - output = output.to("cpu") - - # clear kv cache - torch.cuda.empty_cache() - log_gpu_memory_usage("After recompute log prob", logger=logger) - return output - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_log_prob(self, data: DataProto): + # when is_lora is True, we use the actor without lora applied to calculate the log_prob + # which is mostly used for ref log_prob calculation assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + from contextlib import nullcontext + + is_lora = data.meta_info.pop("is_lora", False) + adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() + data = data.to(get_torch_device().current_device()) # we should always recompute old_log_probs when it is HybridEngine data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -736,9 +693,10 @@ def compute_log_prob(self, data: DataProto): # perform recompute log_prob with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) - output = self.actor.compute_log_prob(data=data) + with adapter_ctx: + output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) output = DataProto.from_dict( - tensors={"old_log_probs": output}, + tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}, ) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -747,21 +705,29 @@ def compute_log_prob(self, data: DataProto): # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1: + if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: self.actor.actor_module._handle.reshard(True) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) - log_gpu_memory_usage("After compute_log_prob", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): + if self._is_lora: + # if _is_lora, actor without lora applied is the ref + data.meta_info["is_lora"] = True + data = self.compute_log_prob(data) + # this old_log_probs is in fact ref_log_prob + data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]}) + return data assert self._is_ref - + # else: + # otherwise, the class have a standalone ref model # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -770,7 +736,7 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) - output = self.ref_policy.compute_log_prob(data=data) + output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -778,17 +744,15 @@ def compute_ref_log_prob(self, data: DataProto): # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1: + if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1: self.ref_policy.actor_module._handle.reshard(True) - torch.cuda.empty_cache() return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): # only support save and load ckpt for actor assert self._is_actor - import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) @@ -799,8 +763,42 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep, ) + dist.barrier() + + if self._is_lora and hasattr( + getattr(self, "actor_module", self.actor_module_fsdp), "peft_config" + ): + lora_save_path = os.path.join(local_path, "lora_adapter") + peft_model = getattr(self, "actor_module", self.actor_module_fsdp) + peft_config = {} + if dist.get_rank() == 0: + os.makedirs(lora_save_path, exist_ok=True) + peft_config = asdict(peft_model.peft_config.get("default", {})) + peft_config["task_type"] = peft_config["task_type"].value + peft_config["peft_type"] = peft_config["peft_type"].value + peft_config["target_modules"] = list(peft_config["target_modules"]) + try: + if fsdp_version(self.actor_module_fsdp) > 0: + self.actor_module_fsdp = self.actor_module_fsdp.cuda() + lora_params = layered_summon_lora_params(self.actor_module_fsdp) + if dist.get_rank() == 0: + save_file( + lora_params, os.path.join(lora_save_path, "adapter_model.safetensors") + ) + with open( + os.path.join(lora_save_path, "adapter_config.json"), + "w", + encoding="utf-8", + ) as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + except Exception as e: + if dist.get_rank() == 0: + print(f"[rank-{self.rank}]: Save LoRA Adapter Error ({e})") + + dist.barrier() + if dist.get_rank() == 0: + print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}") - torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @@ -838,7 +836,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -853,7 +851,7 @@ def __init__(self, config): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) @@ -878,26 +876,29 @@ def __init__(self, config): ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + + if self.config.ppo_micro_batch_size_per_gpu is not None: assert ( self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" assert ( self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + self._is_lora = self.config.model.get("lora_rank", 0) > 0 def _build_critic_model_optimizer(self, config): # the following line is necessary from torch import optim - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from verl.utils.model import print_model_size from verl.utils.torch_dtypes import PrecisionType - local_path = copy_to_local(config.model.path) + use_shm = config.model.get("use_shm", False) + local_path = copy_to_local(config.model.path, use_shm=use_shm) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. - tokenizer_path = copy_to_local(config.model.tokenizer_path) + tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) self.tokenizer = hf_tokenizer( tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False) ) @@ -924,11 +925,15 @@ def _build_critic_model_optimizer(self, config): from transformers import AutoConfig, AutoModelForTokenClassification - trust_remote_code = False critic_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code + local_path, + attn_implementation="flash_attention_2", + trust_remote_code=config.model.get("trust_remote_code", False), ) critic_model_config.num_labels = 1 + # patch for kimi-vl + if getattr(critic_model_config, "model_type", None) == "kimi_vl": + critic_model_config.text_config.topk_method = "greedy" init_context = get_init_weight_context_manager( use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh @@ -936,23 +941,22 @@ def _build_critic_model_optimizer(self, config): with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(critic_model_config, "classifier_dropout", 0.0) - setattr(critic_model_config, "hidden_dropout", "0") + critic_model_config.classifier_dropout = 0.0 + critic_model_config.hidden_dropout = "0" critic_module = AutoModelForTokenClassification.from_pretrained( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=critic_model_config, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, + trust_remote_code=config.model.get("trust_remote_code", False), ) use_remove_padding = config.model.get("use_remove_padding", False) - if use_remove_padding or self.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - apply_monkey_patch( - model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size - ) + apply_monkey_patch( + model=critic_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + ) # some parameters may not in torch_dtype critic_module.to(torch_dtype) @@ -961,6 +965,20 @@ def _build_critic_model_optimizer(self, config): critic_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) + + if self._is_lora: + print("Applying LoRA to critic module") + critic_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) + if self.rank == 0: print_model_size(critic_module) @@ -986,7 +1004,9 @@ def _build_critic_model_optimizer(self, config): ) auto_wrap_policy = get_fsdp_wrap_policy( - module=critic_module, config=self.config.model.fsdp_config.wrap_policy + module=critic_module, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, ) log_gpu_memory_usage("Before critic FSDP", logger=None) @@ -995,59 +1015,87 @@ def _build_critic_model_optimizer(self, config): sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation - critic_module = FSDP( - critic_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None, - ) + if config.strategy == "fsdp": + critic_module = FSDP( + critic_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_torch_device().current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + elif config.strategy == "fsdp2": + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + offload_policy = None + if fsdp_config.offload_policy: + self._is_offload_param = False + self._is_offload_optimizer = False + offload_policy = CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": offload_policy, + "reshard_after_forward": fsdp_config.reshard_after_forward, + } + full_state = critic_module.state_dict() + apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) + else: + raise NotImplementedError(f"Unknown strategy {config.strategy}") + + if config.model.get("enable_activation_offload", False): + enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) + enable_activation_offloading( + critic_module, config.strategy, enable_gradient_checkpointing + ) log_gpu_memory_usage("After critic FSDP", logger=None) - beta1 = config.optim.get("beta1", 0.9) - beta2 = config.optim.get("beta2", 0.999) critic_optimizer = optim.AdamW( critic_module.parameters(), lr=config.optim.lr, - betas=(beta1, beta2), + betas=config.optim.get("betas", (0.9, 0.999)), weight_decay=config.optim.get("weight_decay", 1e-2), ) total_steps = config.optim.get("total_training_steps", 0) num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) + warmup_style = config.optim.get("warmup_style", "constant") if num_warmup_steps < 0: num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - if config.optim.warmup_style == "constant": - from verl.utils.torch_functional import get_constant_schedule_with_warmup + from verl.utils.torch_functional import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + ) + if warmup_style == "constant": critic_lr_scheduler = get_constant_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps ) - elif config.optim.warmup_style == "cosine": - from verl.utils.torch_functional import get_cosine_schedule_with_warmup - - assert total_steps > 0, "Cosine scheduler of critic requires total_training_steps > 0" + elif warmup_style == "cosine": critic_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, - min_lr_ratio=config.optim.min_lr_ratio, ) else: - raise NotImplementedError( - f"Lr scheduler style {config.optim.warmup_style} is not supported" - ) + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") return critic_module, critic_optimizer, critic_lr_scheduler @@ -1066,8 +1114,10 @@ def init_model(self): if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) + log_gpu_memory_usage("After offload critic model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) + log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) self.critic = DataParallelPPOCritic( config=self.config, @@ -1087,7 +1137,7 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1110,12 +1160,12 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: load_fsdp_optimizer( - optimizer=self.critic_optimizer, device_id=torch.cuda.current_device() + optimizer=self.critic_optimizer, device_id=get_torch_device().current_device() ) # perform forward computation @@ -1196,327 +1246,3 @@ def clear_optimizer_state(self): self.critic_optimizer.zero_grad() if self._is_offload_optimizer: offload_fsdp_optimizer(self.critic_optimizer) - - -# TODO(sgm): we may need to extract it to dp_reward_model.py -class RewardModelWorker(Worker): - """ - Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. - """ - - def __init__(self, config): - super().__init__() - import torch.distributed - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") - self.config = config - - # build device mesh for Ulysses Sequence Parallel - world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh - - fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) - - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - "cuda", - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=["dp", "sp"], - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - - self.use_remove_padding = self.config.model.get("use_remove_padding", False) - - # normalize config - if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= torch.distributed.get_world_size() - self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - - def _build_model(self, config): - # the following line is necessary - from torch.distributed.fsdp import CPUOffload - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from transformers import AutoConfig, AutoModelForTokenClassification - - # download the checkpoint from hdfs - local_path = copy_to_local(config.model.path) - - if self.config.model.input_tokenizer is None: - self._do_switch_chat_template = False - else: - self._do_switch_chat_template = True - input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) - self.input_tokenizer = hf_tokenizer( - input_tokenizer_local_path, - trust_remote_code=config.model.get("trust_remote_code", False), - ) - self.tokenizer = hf_tokenizer( - local_path, trust_remote_code=config.model.get("trust_remote_code", False) - ) - - trust_remote_code = config.model.get("trust_remote_code", False) - model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) - model_config.num_labels = 1 - - # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - setattr(model_config, "classifier_dropout", 0.0) - reward_module = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - if ( - config.model.get("use_remove_padding", False) - or self.ulysses_sequence_parallel_size > 1 - ): - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch( - model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size - ) - - reward_module.to(torch.bfloat16) - auto_wrap_policy = get_fsdp_wrap_policy( - module=reward_module, config=self.config.model.fsdp_config - ) - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - reward_module = FSDP( - reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, # zero3 - sync_module_states=True, - cpu_offload=CPUOffload(offload_params=True), - forward_prefetch=False, - device_mesh=self.device_mesh, - ) - - return reward_module - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - self.reward_module = self._build_model(config=self.config) - - def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import ( - index_first_axis, - pad_input, - rearrange, - unpad_input, - ) - from verl.utils.ulysses import ( - gather_outpus_and_unpad, - ulysses_pad_and_slice_inputs, - ) - - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False, - ) # prevent model thinks we are generating - reward_rmpad = output.logits - reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) - - # gather output if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outpus_and_unpad( - reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - - # pad it back - rm_score = pad_input( - reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen - ).squeeze(-1) - else: - output = self.reward_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False, - ) - rm_score = output.logits # (batch_size, seq_len, 1) - rm_score = rm_score.squeeze(-1) - - # extract the result of the last valid token - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] - return rm_score - - def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): - batch_size = data.batch.batch_size[0] - # expand as token_level_reward - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - response_length = data.batch["responses"].shape[-1] - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) - token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores - - # select the response part - token_level_scores = token_level_scores[:, -response_length:] - - return token_level_scores - - def _switch_chat_template(self, data: DataProto): - src_max_length = data.batch["attention_mask"].shape[-1] - - src_tokenizer = self.input_tokenizer - target_tokenizer = self.tokenizer - - rm_input_ids = [] - rm_attention_mask = [] - - for i in range(data.batch.batch_size[0]): - # extract raw prompt - chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() - - # extract response - response_ids = data.batch["responses"][i] - response_length = response_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - response = src_tokenizer.decode(valid_response_ids) - # remove bos and eos - response = response.replace(src_tokenizer.eos_token, "") - - chat.append({"role": "assistant", "content": response}) - - prompt_with_chat_template = target_tokenizer.apply_chat_template( - chat, add_generation_prompt=False, tokenize=False - ) - if self.rank == 0 and i == 0: - # for debugging purpose - print(f"Switch template. chat: {prompt_with_chat_template}") - - # the maximum length is actually determined by the reward model itself - max_length = self.config.get("max_length", src_max_length) - if max_length is None: - max_length = src_max_length - input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( - prompt=prompt_with_chat_template, - tokenizer=target_tokenizer, - max_length=max_length, - pad_token_id=target_tokenizer.pad_token_id, - left_pad=False, # right padding - truncation=self.config.get("truncation", "right"), - ) # truncate from the right - - rm_input_ids.append(input_ids) - rm_attention_mask.append(attention_mask) - - rm_input_ids = torch.cat(rm_input_ids, dim=0) - rm_attention_mask = torch.cat(rm_attention_mask, dim=0) - - rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - - rm_inputs = { - "input_ids": rm_input_ids, - "attention_mask": rm_attention_mask, - "position_ids": rm_position_ids, - } - - return DataProto.from_dict(rm_inputs) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - import itertools - - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches - - # Support all hardwares - data = data.to(torch.cuda.current_device()) - if self._do_switch_chat_template: - rm_data = self._switch_chat_template(data) - - # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) - - # perform forward computation - with self.ulysses_sharding_manager: - rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data) - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - use_dynamic_bsz = self.config.use_dynamic_bsz - if use_dynamic_bsz: - max_token_len = ( - self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - ) - micro_batches, indices = rearrange_micro_batches( - batch=rm_data.batch, max_token_len=max_token_len - ) - else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) - output = [] - for micro_batch in micro_batches: - rm_score = self._forward_micro_batch(micro_batch) - output.append(rm_score) - scores = torch.cat(output, dim=0) # (batch_size) - - if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - scores = scores[revert_indices] - - token_level_scores = self._expand_to_token_level(data, scores) - # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - self.reward_module._handle.reshard(True) - - output = output.to("cpu") - return output diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py deleted file mode 100644 index 7073319db0..0000000000 --- a/trinity/trainer/verl/ray_trainer.py +++ /dev/null @@ -1,1160 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Modified from ray_trainer.py -""" - -import os -import uuid -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import dataclass, field -from enum import Enum -from pprint import pprint -from typing import Dict, Type - -import numpy as np -import ray -import torch -import tqdm -from codetiming import Timer -from omegaconf import OmegaConf, open_dict -from torch.utils.data import RandomSampler, SequentialSampler -from torchdata.stateful_dataloader import StatefulDataLoader -from verl import DataProto -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker -from verl.single_controller.ray import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, -) -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - reduce_metrics, -) -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path -from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn -from verl.utils.seqlen_balancing import ( - get_seqlen_balanced_partitions, - log_seqlen_unbalance, -) -from verl.utils.torch_functional import masked_mean -from verl.utils.tracking import ValidationGenerationsLogger - -from trinity.common.constants import AlgorithmType -from trinity.trainer.verl import core_algos - -WorkerType = Type[Worker] - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - - Actor = 0 - Rollout = 1 - ActorRollout = 2 - Critic = 3 - RefPolicy = 4 - RewardModel = 5 - ActorRolloutRef = 6 - - -class AdvantageEstimator(str, Enum): - """ - Using an enumeration class to avoid spelling errors in adv_estimator - """ - - GAE = "gae" - GRPO = "grpo" - REINFORCE_PLUS_PLUS = "reinforce_plus_plus" - REMAX = "remax" - RLOO = "rloo" - - -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - Mapping - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, - use_gpu=True, - max_colocate_count=1, - name_prefix=resource_pool_name, - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum( - [ - n_gpus - for process_on_nodes in self.resource_pool_spec.values() - for n_gpus in process_on_nodes - ] - ) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = { - node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items() - } - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [ - n_gpus - for process_on_nodes in self.resource_pool_spec.values() - for n_gpus in process_on_nodes - ] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - # check each resource pool can be satisfied, O(#resource_pools * #nodes) - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) - for node, available_gpus in node_available_gpus.items(): - if available_gpus >= num_gpus: - node_available_gpus[node] -= num_gpus - num_nodes -= 1 - if num_nodes == 0: - break - if num_nodes > 0: - raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster" - ) - - -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): - responses = data.batch["responses"] - response_length = responses.size(1) - token_level_scores = data.batch["token_level_scores"] - batch_size = data.batch.batch_size[0] - attention_mask = data.batch["attention_mask"] - # response_mask = attention_mask[:, -response_length:] - response_mask = data.batch["response_mask"] - assert response_mask.shape == attention_mask[:, -response_length:].shape - - # compute kl between ref_policy and current policy - if "ref_log_prob" in data.batch.keys(): - kld = core_algos.kl_penalty( - data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty - ) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - else: - beta = 0 - kld = torch.zeros_like(response_mask, dtype=torch.float32) - - token_level_rewards = token_level_scores - beta * kld - - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch["token_level_rewards"] = token_level_rewards - - metrics = {"critic/kl": current_kl, "critic/kl_coeff": beta} - - return data, metrics - - -def compute_response_mask(data: DataProto): - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - -def compute_advantage(data: DataProto, **kwargs): - """Extend verl's original compute_advantage with OPMD""" - - algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO) - - if algorithm_type == AlgorithmType.OPMD: - tau = kwargs.get("tau", 1.0) - opmd_baseline = kwargs.get("opmd_baseline", "mean") - - return compute_advantage_opmd( - data=data, - tau=tau, - opmd_baseline=opmd_baseline, - ) - - elif algorithm_type == AlgorithmType.PAIRWISE_OPMD: - data.batch["advantages"] = None - data.batch["returns"] = None - return data - - elif algorithm_type.is_rft(): - adv_estimator = kwargs.get("adv_estimator", None) - gamma = kwargs.get("gamma", 1.0) - lam = kwargs.get("lam", 1.0) - num_repeat = kwargs.get("num_repeat", 1) - - return compute_advantage_ppo( - data=data, - adv_estimator=adv_estimator, - gamma=gamma, - lam=lam, - num_repeat=num_repeat, - ) - - else: - raise ValueError(f"Get invalid algorithm_type '{algorithm_type}'.") - - -def compute_advantage_opmd(data: DataProto, tau=1.0, opmd_baseline="mean"): - # Modified from GRPO version - token_level_rewards = data.batch["token_level_rewards"] - index = data.non_tensor_batch["uid"] - responses = data.batch["responses"] - response_length = responses.size(-1) - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - advantages, returns = core_algos.compute_opmd_outcome_advantage( - token_level_rewards=token_level_rewards, - eos_mask=response_mask, - index=index, - opmd_baseline=opmd_baseline, - tau=tau, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - - return data - - -def compute_advantage_ppo(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): - # prepare response group - # TODO: add other ways to estimate advantages - if adv_estimator == AdvantageEstimator.GAE: - advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=data.batch["token_level_rewards"], - values=data.batch["values"], - eos_mask=data.batch["response_mask"], - gamma=gamma, - lam=lam, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.GRPO: - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - eos_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - eos_mask=data.batch["response_mask"], - gamma=gamma, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REMAX: - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - reward_baselines=data.batch["reward_baselines"], - eos_mask=data.batch["response_mask"], - ) - - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - eos_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - raise NotImplementedError - return data - - -@contextmanager -def _timer(name: str, timing_raw: Dict[str, float]): - with Timer(name=name, logger=None) as timer: - yield - timing_raw[name] = timer.last - - -class RayPPOTrainer(object): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None, - ): - # assert torch.cuda.is_available(), 'cuda must be available on driver' - - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - self.validation_generations_logger = ValidationGenerationsLogger() - - # define KL control - if self.use_reference_policy: - self.kl_ctrl = core_algos.get_kl_controller(config.algorithm.kl_ctrl) - else: - self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0) - - if ( - self.config.actor_rollout_ref.actor.get("algorithm_type", AlgorithmType.PPO) - != AlgorithmType.PPO - ): - self.use_critic = False - elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - self.use_critic = True - elif self.config.algorithm.adv_estimator in [ - AdvantageEstimator.GRPO, - AdvantageEstimator.REINFORCE_PLUS_PLUS, - AdvantageEstimator.REMAX, - AdvantageEstimator.RLOO, - ]: - self.use_critic = False - else: - raise NotImplementedError - - self._validate_config() - self._create_dataloader() - - def _validate_config(self): # noqa: C901 - config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - - # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert ( - real_train_batch_size % n_gpus == 0 - ), f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". - def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - if mbs is None and mbs_per_gpu is None: - raise ValueError( - f"[{name}] Please set at least one of '{name}.micro_batch_size' or " - f"'{name}.micro_batch_size_per_gpu'." - ) - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError( - f"[{name}] You have set both '{name}.micro_batch_size' AND " - f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' " - f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated)." - ) - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor", - ) - - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref", - ) - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout", - ) - - if self.use_critic and not config.critic.use_dynamic_bsz: - # Check for critic micro-batch size conflicts - check_mutually_exclusive( - config.critic.ppo_micro_batch_size, - config.critic.ppo_micro_batch_size_per_gpu, - "critic", - ) - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, - config.reward_model.micro_batch_size_per_gpu, - "reward_model", - ) - - # Actor - # if NOT dynamic_bsz, we must ensure: - # ppo_mini_batch_size is divisible by ppo_micro_batch_size - # ppo_micro_batch_size * sequence_parallel_size >= n_gpus - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - assert ( - config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - ) - sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) - if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert ( - config.actor_rollout_ref.actor.ppo_mini_batch_size - % config.actor_rollout_ref.actor.ppo_micro_batch_size - == 0 - ) - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus - - # critic - if self.use_critic and not config.critic.use_dynamic_bsz: - assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size - sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) - if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 - assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus - - # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy == "fsdp": - if ( - config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 - or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 - ): - assert ( - config.actor_rollout_ref.model.use_remove_padding - ), "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." - - if self.use_critic and config.critic.strategy == "fsdp": - if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert ( - config.critic.model.use_remove_padding - ), "When using sequence parallelism for critic, you must enable `use_remove_padding`." - - if config.data.get("val_batch_size", None) is not None: - print( - "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves." - ) - - print("[validate_config] All configuration checks passed successfully!") - - def _create_dataloader(self): - # TODO: we have to make sure the batch size is divisible by the dp size - self.train_dataset = RLHFDataset( - parquet_files=self.config.data.train_files, - tokenizer=self.tokenizer, - processor=self.processor, - prompt_key=self.config.data.prompt_key, - image_key=self.config.data.get("image_key", "images"), - max_prompt_length=self.config.data.max_prompt_length, - filter_prompts=True, - return_raw_chat=self.config.data.get("return_raw_chat", False), - truncation="error", - ) - # use sampler for better ckpt resume - if self.config.data.shuffle: - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) - sampler = RandomSampler( - data_source=self.train_dataset, generator=train_dataloader_generator - ) - else: - sampler = SequentialSampler(data_source=self.train_dataset) - - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.train_batch_size, - num_workers=8, - drop_last=True, - collate_fn=collate_fn, - sampler=sampler, - ) - - self.val_dataset = RLHFDataset( - parquet_files=self.config.data.val_files, - tokenizer=self.tokenizer, - processor=self.processor, - prompt_key=self.config.data.prompt_key, - image_key=self.config.data.get("image_key", "images"), - max_prompt_length=self.config.data.max_prompt_length, - filter_prompts=True, - return_raw_chat=self.config.data.get("return_raw_chat", False), - truncation="error", - ) - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - # Validation datasets are sent to inference engines as a whole batch, - # which will schedule the memory themselves. - batch_size=len(self.val_dataset), - num_workers=8, - shuffle=False, - drop_last=False, - collate_fn=collate_fn, - ) - - assert len(self.train_dataloader) >= 1 - assert ( - len(self.val_dataloader) == 1 - ), "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves." - - print(f"Size of train dataloader: {len(self.train_dataloader)}") - - # inject total_training_steps to actor/critic optim_config. This is hacky. - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - self.config.critic.optim.total_training_steps = total_training_steps - - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.val_generations_to_log_to_wandb - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log( - self.config.trainer.logger, samples, self.global_steps - ) - - def _validate(self): - reward_tensor_lst = [] - data_source_lst = [] - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_scores = [] - - for test_data in self.val_dataloader: - test_batch = DataProto.from_single_dict(test_data) - - # we only do validation on rule-based rm - if ( - self.config.reward_model.enable - and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model" - ): - return {} - - # Store original inputs - input_ids = test_batch.batch["input_ids"] - input_texts = [ - self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids - ] - sample_inputs.extend(input_texts) - - if "multi_modal_inputs" in test_batch.non_tensor_batch.keys(): - test_gen_batch = test_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=[ - "raw_prompt_ids", - "multi_modal_data", - "multi_modal_inputs", - ], - ) - else: - test_gen_batch = test_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], - ) - - test_gen_batch.meta_info = { - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - "recompute_log_prob": False, - "do_sample": False, - "validate": True, - } - - # pad to be divisible by dp_size - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor( - test_gen_batch, self.actor_rollout_wg.world_size - ) - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences( - test_gen_batch_padded - ) - # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) - print("validation generation end") - - # Store generated outputs - output_ids = test_output_gen_batch.batch["responses"] - output_texts = [ - self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids - ] - sample_outputs.extend(output_texts) - - test_batch = test_batch.union(test_output_gen_batch) - - # evaluate using reward_function - reward_tensor = self.val_reward_fn(test_batch) - - # Store scores - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_tensor_lst.append(reward_tensor) - data_source_lst.append( - test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]) - ) - - self._maybe_log_val_generations( - inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores - ) - - reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) - data_sources = np.concatenate(data_source_lst, axis=0) - - # evaluate test_score based on data source - data_source_reward = {} - for i in range(reward_tensor.shape[0]): - data_source = data_sources[i] - if data_source not in data_source_reward: - data_source_reward[data_source] = [] - data_source_reward[data_source].append(reward_tensor[i].item()) - - metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards) - - return metric_dict - - def init_workers(self): - """Init resource pool and worker group""" - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = { - pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() - } - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role="actor", - ) - self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], config=self.config.critic - ) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role="ref", - ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model - ) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - self.wg_dicts = [] - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls - ) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 - self.wg_dicts.append(wg_dict) - - if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() - - if self.use_reference_policy: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg["rm"] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor"] - self.actor_rollout_wg.init_model() - - def _save_checkpoint(self): - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join( - self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor" - ) - ) - - remove_previous_ckpt_in_save = self.config.trainer.get( - "remove_previous_ckpt_in_save", False - ) - if remove_previous_ckpt_in_save: - print( - "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" - ) - max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) - if not remove_previous_ckpt_in_save - else 1 - ) - max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) - if not remove_previous_ckpt_in_save - else 1 - ) - - self.actor_rollout_wg.save_checkpoint( - actor_local_path, - actor_remote_path, - self.global_steps, - max_ckpt_to_keep=max_actor_ckpt_to_keep, - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join( - self.config.trainer.default_hdfs_dir, - f"global_step_{self.global_steps}", - "critic", - ) - ) - self.critic_wg.save_checkpoint( - critic_local_path, - critic_remote_path, - self.global_steps, - max_ckpt_to_keep=max_critic_ckpt_to_keep, - ) - - # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance( - self.config.trainer.resume_from_path, str - ), "resume ckpt must be str type" - assert ( - "global_step_" in self.config.trainer.resume_from_path - ), "resume ckpt must specify the global_steps" - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print( - f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch" - ) - - def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch["attention_mask"] - batch_size = attention_mask.shape[0] - global_seqlen_lst = ( - batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() - ) # (train_batch_size,) - world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) - # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix - ) - metrics.update(global_balance_stats) - - def fit(self): # noqa: C901 - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # add tqdm - progress_bar = tqdm( - total=self.total_training_steps, initial=self.global_steps, desc="Training Progress" - ) - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - if "multi_modal_inputs" in batch.non_tensor_batch.keys(): - gen_batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=[ - "raw_prompt_ids", - "multi_modal_data", - "multi_modal_inputs", - ], - ) - else: - gen_batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], - ) - - is_last_step = self.global_steps >= self.total_training_steps - - with _timer("step", timing_raw): - # generate a batch - with _timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences( - gen_baseline_batch - ) - - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - batch = batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True - ) - batch = batch.union(gen_batch_output) - - batch.batch["response_mask"] = compute_response_mask(batch) - - # balance the number of valid tokens on each dp rank. - # Note that this breaks the order of data inside the batch. - # Please take care when you implement group based adv computation such as GRPO and rloo - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() - - # recompute old_log_probs - with _timer("old_log_prob", timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with _timer("values", timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with _timer("adv", timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch["token_level_scores"] = reward_tensor - - # compute rewards. apply_kl_penalty if available - if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False): - batch, kl_metrics = apply_kl_penalty( - batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty, - ) - metrics.update(kl_metrics) - else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - - # compute advantages, executed on the driver process - algorithm_type = self.config.actor_rollout_ref.actor.get( - "algorithm_type", AlgorithmType.PPO - ) - tau = self.config.actor_rollout_ref.actor.get("tau", 1.0) - opmd_baseline = self.config.actor_rollout_ref.actor.get( - "opmd_baseline", "mean" - ) - batch = compute_advantage( - batch, - algorithm_type=algorithm_type, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - # additional config params for OPMD - tau=tau, - opmd_baseline=opmd_baseline, - ) - - # update critic - if self.use_critic: - with _timer("update_critic", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with _timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with _timer("testing", timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 - ): - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - # Implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update( - compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) - ) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - progress_bar.update(1) - self.global_steps += 1 diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 1041600d87..d041bea128 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -4,36 +4,41 @@ Modified from verl/trainer/ppo/ray_trainer.py """ import os -from typing import Tuple +import sys +from pprint import pprint +from typing import Dict, List import pandas as pd import ray import torch from omegaconf import OmegaConf -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_local_path_from_hdfs - -from trinity.common.config import Config -from trinity.common.constants import AlgorithmType -from trinity.common.experience import Experiences -from trinity.trainer.trainer import TrainEngineWrapper -from trinity.trainer.verl.ray_trainer import ( - DataProto, +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + RayClassWithInitArgs, RayPPOTrainer, RayWorkerGroup, ResourcePoolManager, Role, _timer, - apply_kl_penalty, - compute_advantage, - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, + create_colocated_worker_cls, find_latest_ckpt_path, - np, - pprint, - reduce_metrics, ) +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_local_path_from_hdfs + +from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY +from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm +from trinity.algorithm.algorithm_manager import AlgorithmManager +from trinity.algorithm.utils import prefix_metrics +from trinity.common.config import Config +from trinity.common.constants import TRAINER_NAME +from trinity.common.experience import Experiences +from trinity.trainer.trainer import TrainEngineWrapper from trinity.utils.monitor import MONITOR @@ -116,7 +121,24 @@ def __init__( resource_pool_manager = ResourcePoolManager( resource_pool_spec=resource_pool_spec, mapping=mapping ) - + self.algorithm_config = global_config.algorithm + self.algorithm = None + self.algorithm_manager = AlgorithmManager(global_config) + + # specify advantage function for various rft algorithms + algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) + if algorithm.use_advantage: + self.advantage_fn = ADVANTAGE_FN.get(self.algorithm_config.advantage_fn)( + **self.algorithm_config.advantage_fn_args + ) + self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)( + **self.algorithm_config.kl_penalty_fn_args + ) + self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)( + buffer_config=global_config.buffer, + trainer_type=global_config.trainer.trainer_type, + **global_config.algorithm.sample_strategy_args, + ) super().__init__( config, tokenizer, @@ -125,31 +147,130 @@ def __init__( ray_worker_group_cls, ) self.init_workers() - self.algorithm_type = ( - AlgorithmType.PPO - ) # TODO: initialize algorithm_type according to config self.logger = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, name=config.trainer.experiment_name, - role="trainer", + role=TRAINER_NAME, config=global_config, ) self.reset_experiences_example_table() + def _validate_config(self): # TODO + algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) + self.use_critic = algorithm.use_critic + super()._validate_config() + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + + Creates: + + 1. Ray resource pools from configuration + + 2. Worker groups for each role (actor, critic, etc.) + + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = { + pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() + } + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor", + ) + self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.Critic], config=self.config.critic + ) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model + ) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs[ + "ray_wait_register_center_timeout" + ] = self.config.trainer.ray_wait_register_center_timeout + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor"] + self.actor_rollout_wg.init_model() + def reset_experiences_example_table(self): - self.experiences_example_table = pd.DataFrame( - columns=["step", "reward", "prompt", "response"] - ) + self.sample_exps_to_log = [] + + @property + def train_step_num(self) -> int: + return self.global_steps def prepare(self): self.actor_rollout_wg.setup_weight_sync_group() + # The global step counter, initialized to 0 + # It represents the total number of training steps completed so far + # We increment this counter at the beginning of each training step self.global_steps = 0 - self.sft_warmup_step_num = 0 # load checkpoint before doing anything self._load_checkpoint() - self.sft_warmup_step_num = min(self.global_steps, self.config.trainer.sft_warmup_steps) # perform validation before training # currently, we only support validation using the reward_function. @@ -160,193 +281,33 @@ def prepare(self): if self.config.trainer.get("val_only", False): return - # we start from step 1 - self.global_steps += 1 - - def _create_dataloader(self): + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): self.train_dataloader = _InternalDataLoader(self.config) # TODO: compute total training steps - # if self.algorithm_type.is_dpo(): - # train_batch_size = self.config.buffer.read_batch_size - # total_epochs = self.config.trainer.total_epochs - # from math import ceil - - # self.total_training_steps = ceil( - # self.train_dataloader.size() // train_batch_size * total_epochs - # ) - # if not self.config.actor_rollout_ref.actor.optim.total_training_steps > 0: - # self.config.actor_rollout_ref.actor.optim.total_training_steps = ( - # self.total_training_steps - # ) - # if not self.config.critic.optim.total_training_steps > 0: - # self.config.critic.optim.total_training_steps = self.total_training_steps - # else: - self.total_training_steps = float("inf") - - def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]: - metrics = {} - timing_raw = {} - - with _timer("step", timing_raw): - # generate a batch - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - - batch = DataProto.from_single_dict( - { - "uid": np.array(experiences.run_ids), # useless - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") - and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - } - ) - batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - - # self._balance_batch(batch, metrics=metrics) # _balance_batch will shuffle the batch, which will break DPO - # TODO: implement a new _balance_batch for DPO - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() + self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # update actor - with _timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # collect metrics - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - self.logger.log(data=metrics, step=self.global_steps) - - # save checkpoint - if ( - self.config.trainer.save_freq > 0 - and self.global_steps % self.config.trainer.save_freq == 0 - ): - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - self.global_steps += 1 - return True, self.global_steps - 1 - - def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]: - if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps: - return False, self.global_steps - 1 + def train_step(self) -> bool: # noqa C901 metrics = {} - timing_raw = {} - - with _timer("step", timing_raw): - # generate a batch - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - - batch = DataProto.from_single_dict( - { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") - and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - } - ) - batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - - self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() - - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # update actor - with _timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # collect metrics - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - # TODO: log as sft metrics - self.logger.log(data=metrics, step=self.global_steps) - self.sft_warmup_step_num += 1 + try: + batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1) + prefix_metrics(sample_metrics, "sample", metrics) + except StopIteration: + print("No more data to train. Stop training.") + return False self.global_steps += 1 - if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps: - self.logger.log( - data={"sft_warmup_steps": self.sft_warmup_step_num}, - step=self.global_steps - 1, - ) - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - return False, self.global_steps - 1 - return True, self.global_steps - 1 - - def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: - metrics = {} timing_raw = {} + algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps) + algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type) + if self.algorithm != algorithm: + self.actor_rollout_wg.set_algorithm(algorithm_config) + if self.algorithm == SFTAlgorithm: + self.sft_to_rft() + self.algorithm = algorithm with _timer("step", timing_raw): - # Convert rewards to token_level_rewards - attention_mask = experiences.attention_masks - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - cumsum = torch.cumsum(attention_mask, dim=-1) - eos_mask_idx = cumsum.argmax(dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - - batch = DataProto.from_single_dict( - { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") - and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - if self.config.trainer.balance_batch: + if self.algorithm.can_balance_batch and self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn # compute global_valid tokens @@ -354,61 +315,37 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: batch.batch["attention_mask"], dim=-1 ).tolist() - if self.use_reference_policy: + if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob with _timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) - # compute values - if self.use_critic: + if self.algorithm.use_critic: with _timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer("adv", timing_raw): - # compute rewards. apply_kl_penalty if available - if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False): - batch, kl_metrics = apply_kl_penalty( - batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty, - ) - metrics.update(kl_metrics) - else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - - # compute advantages, executed on the driver process - kwargs = {} - algorithm_type = self.config.actor_rollout_ref.actor.get( - "algorithm_type", AlgorithmType.PPO - ) - if algorithm_type == AlgorithmType.OPMD: - tau = self.config.actor_rollout_ref.actor.get("tau", 0.0) - opmd_baseline = self.config.actor_rollout_ref.actor.get("opmd_baseline", "mean") - kwargs = { - "algorithm_type": algorithm_type, - "tau": tau, - "opmd_baseline": opmd_baseline, - } - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - **kwargs, - ) - - # update critic - if self.use_critic: + if self.algorithm.use_advantage: + with _timer("adv", timing_raw): + # compute kl penalty + batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) + metrics.update(prefix_metrics(kl_metrics, prefix="critic")) + # compute advantages, executed on the driver process + batch, _ = self.advantage_fn(batch) + + # update critic + if self.algorithm.use_critic: with _timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: + if ( + not self.algorithm.use_critic + or self.config.trainer.critic_warmup <= self.global_steps + ): # update actor with _timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) @@ -424,33 +361,29 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: self._save_checkpoint() # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + if self.algorithm.use_advantage: # TODO + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) n_gpus = self.resource_pool_manager.get_n_gpus() metrics.update( compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) ) - if self.config.enable_preview: - self._log_experiences(experiences) + if self.algorithm.use_advantage and self.config.enable_preview: # TODO + self._log_experiences(exp_samples) # TODO: make a canonical logger that supports various backend self.logger.log(data=metrics, step=self.global_steps) - self.global_steps += 1 - - if self.global_steps >= self.total_training_steps: + train_status = self.global_steps < self.total_training_steps + if not train_status or self.algorithm_manager.need_save(self.global_steps): if ( - self.config.trainer.save_freq > 0 - and (self.global_steps - 1) % self.config.trainer.save_freq != 0 + self.config.trainer.save_freq == 0 + or self.global_steps % self.config.trainer.save_freq != 0 ): with _timer("save_checkpoint", timing_raw): self._save_checkpoint() - # stop training - return False, self.global_steps - 1 - else: - # continue - return True, self.global_steps - 1 + return train_status def _log_single_experience( self, experiences: Experiences, idx: int, skip_special_tokens: bool @@ -475,21 +408,13 @@ def _log_single_experience( "response": [response_text], } ) - self.experiences_example_table = pd.concat( - [self.experiences_example_table, new_row], ignore_index=True - ) - - def _log_experiences(self, experiences: Experiences) -> None: - skip_special_tokens = False - reward_max_id = torch.argmax(experiences.rewards) - self._log_single_experience(experiences, reward_max_id, skip_special_tokens) - - reward_min_id = torch.argmin(experiences.rewards) - self._log_single_experience(experiences, reward_min_id, skip_special_tokens) + self.sample_exps_to_log = pd.concat([self.sample_exps_to_log, new_row], ignore_index=True) + def _log_experiences(self, samples: List[Dict]) -> None: + self.sample_exps_to_log.extend(samples) if self.global_steps % self.config.trainer.sync_freq == 0: self.logger.log_table( - "rollout_examples", self.experiences_example_table, self.global_steps + "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps ) self.reset_experiences_example_table() @@ -499,12 +424,6 @@ def save_checkpoint(self) -> None: def sync_weight(self) -> None: self.actor_rollout_wg.sync_weight() - def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None: - self.actor_rollout_wg.set_mode(algorithm_type) - if self.algorithm_type.is_sft() and (not algorithm_type.is_sft()): - self.sft_to_rft() - self.algorithm_type = algorithm_type - def sft_to_rft(self) -> None: # load from hdfs if self.config.trainer.default_hdfs_dir is not None: @@ -535,9 +454,9 @@ def sft_to_rft(self) -> None: global_step_folder = os.path.join(working_dir, global_step_folder) print(f"Load from checkpoint folder: {global_step_folder}") # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) + global_steps = int(global_step_folder.split("global_step_")[-1]) + assert self.global_steps == global_steps + 1 - print(f"Setting global step to {self.global_steps}") print(f"Resuming from {global_step_folder}") actor_path = os.path.join(global_step_folder, "actor") diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py index 3ad4e844fe..83cd393519 100644 --- a/trinity/utils/registry.py +++ b/trinity/utils/registry.py @@ -5,6 +5,8 @@ logger = get_logger(__name__) +# TODO: support lazy load +# e.g. @MODULES.register_module("name", lazy=True) class Registry(object): """A class for registry.""" @@ -41,7 +43,7 @@ def list(self) -> None: for m in self._modules.keys(): logger.info(f"{self._name}\t{m}") - def get(self, module_key: str) -> Any: + def get(self, module_key) -> Any: """ Get module named module_key from in current registry. If not found, return None. @@ -81,21 +83,21 @@ def register_module(self, module_name: str, module_cls: Type = None, force=False Default: False. Example: - ```python - WORKFLOWS = Registry("workflows") - - # register a module using decorator - @WORKFLOWS.register_module(name="workflow_name") - class MyWorkflow(Workflow): - pass - - # or register a module directly - WORKFLOWS.register_module( - name="workflow_name", - module_cls=MyWorkflow, - force=True, - ) - ``` + + .. code-block:: python + WORKFLOWS = Registry("workflows") + + # register a module using decorator + @WORKFLOWS.register_module(name="workflow_name") + class MyWorkflow(Workflow): + pass + + # or register a module directly + WORKFLOWS.register_module( + name="workflow_name", + module_cls=MyWorkflow, + force=True, + ) """ if not (module_name is None or isinstance(module_name, str)): diff --git a/trinity/utils/timer.py b/trinity/utils/timer.py new file mode 100644 index 0000000000..5e80f406b8 --- /dev/null +++ b/trinity/utils/timer.py @@ -0,0 +1,18 @@ +"""Timer context manager""" + +import time + + +class Timer: + def __init__(self, metrics_dict, key_name): + self.metrics = metrics_dict + self.key = key_name + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + end_time = time.time() + elapsed_time = end_time - self.start_time + self.metrics[self.key] = elapsed_time