diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py
index be97e1a131..4842a34557 100644
--- a/docs/sphinx_doc/source/conf.py
+++ b/docs/sphinx_doc/source/conf.py
@@ -54,6 +54,14 @@
     "navigation_depth": 3,
 }

+html_context = {
+    "display_github": True,
+    "github_user": "modelscope",
+    "github_repo": "Trinity-RFT",
+    "github_version": "main",
+    "conf_py_path": "/docs/sphinx_doc/source/",
+}
+
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index 35c70389cb..d62f56de3f 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -1,4 +1,4 @@
-# Data processing functionalities
+# Data Processing

 ## Example: reasoning task

diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index fc599e206b..09377e1f66 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -1,86 +1,132 @@
-# Trinity-RFT Configuration
+# Configuration Guide

-The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.
+This section provides a detailed description of the configuration files used in **Trinity-RFT**.

-## Global Config
+## Overview
+
+The configuration for **Trinity-RFT** is defined in a `YAML` file and organized into multiple sections based on different modules. Here's an example of a basic configuration file:

 ```yaml
 project: Trinity-RFT
-name: example
+name: tutorial
 mode: both
 checkpoint_root_dir: /PATH/TO/CHECKPOINT
+
+algorithm:
+  # Algorithm-related parameters
+  ...
+model:
+  # Model-specific configurations
+  ...
+cluster:
+  # Cluster node and GPU settings
+  ...
+buffer:
+  # Data buffer configurations
+  ...
+explorer:
+  # Explorer-related settings (rollout models, workflow runners)
+  ...
+trainer:
+  # Trainer-specific parameters
+  ...
+synchronizer:
+  # Model weight synchronization settings
+  ...
+monitor:
+  # Monitoring configurations (e.g., WandB or TensorBoard)
+  ...
+data_processor:
+  # Preprocessing data settings
+  ...
 ```

-- `project`: The name of the project.
-- `name`: The name of the experiment.
-- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
-- `checkpoint_root_dir`: The root directory to save the checkpoints. Sepcifically, the generated checkpoints will be saved in `///.
+Each of these sections will be explained in detail below.

-## Algorithm
+```{note}
+For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).
+```
+
+---
+
+## Global Configuration
+
+These are general settings that apply to the entire experiment.

 ```yaml
-algorithm:
-  algorithm_type: grpo
-  repeat_times: 1
+project: Trinity-RFT
+name: example
+mode: both
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 ```

-- `algorithm.algorithm_type`: The type of the algorithm. Support `ppo`, `grpo`, `opmd` and `dpo`.
-- `algorithm.repeat_times`: The number of times to repeat each task. Used for GRPO-like algorithm. Default is `1`.
+- `project`: The name of the project.
+- `name`: The name of the current experiment.
+- `mode`: Running mode of Trinity-RFT. Options include:
+  - `both`: Launches both the trainer and explorer (default).
+  - `train`: Only launches the trainer.
+  - `explore`: Only launches the explorer.
+  - `bench`: Used for benchmarking.
+- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>`.
+
+---

-## Monitor
+## Algorithm Configuration
+
+Specifies the algorithm type and its related hyperparameters.

 ```yaml
-monitor:
-  monitor_type: MonitorType.WANDB
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 1
+  gamma: 1.0
+  lam: 1.0
 ```

-- `monitor.monitor_type`: The type of the monitor. For now, `MonitorType.WANDB` and `MonitorType.TENSORBOARD` are supported.
+- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`.
+- `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`.
+- `gamma`: Discount factor for future rewards. Default is `1.0`.
+- `lam`: Lambda value for Generalized Advantage Estimation (GAE). Default is `1.0`.
+
+---

-## Data Processing
+## Monitor Configuration

+Used to log training metrics during execution.

 ```yaml
-data_processor:
-  source_data_path: '/PATH/TO/DATASET'
-  load_kwargs:
-    split: 'train' # only need the train split
-  format:
-    prompt_key: 'question'
-    response_key: 'answer'
-
-  # cleaner related
-  dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml'
-  clean_strategy: 'iterative'
-  # db related
-  db_url: 'postgresql://{username}@localhost:5432/{db_name}'
+monitor:
+  monitor_type: wandb
 ```

-- `data.source_data_path`: The path to the source dataset.
-- `data.load_kwargs`: The kwargs used in `datasets.load_dataset`.
-- `data.format`: The format of the source dataset. It includes `prompt_key` and `response_key`.
-- `data.dj_config_path`: The path to the Data-Juicer configuration.
-- `data.clean_strategy`: The cleaning strategy used for `DataCleaner`, which iteratively cleans dataset until targets are met.
-- `data.db_url`: The URL of the database.
+- `monitor_type`: Type of monitoring system. Options:
+  - `wandb`: Logs to Weights & Biases. Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
+  - `tensorboard`: Logs to TensorBoard. Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
+
+---

-## Model
+## Model Configuration

-The `model` configuration specifies the model used for training. It includes the path to the model checkpoint, the maximum number of tokens in the prompt, the maximum number of tokens in the response, the path to the checkpoint of the model, and whether to load the checkpoint of the model.
+Defines the model paths and token limits.

 ```yaml
 model:
   model_path: '/PATH/TO/MODEL/CHECKPOINT/'
   critic_model_path: ''
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
 ```

-- `model.model_path`: The path to the model checkpoint. It must be set manually.
-- `model.critic_model_path`: The path to the critic model checkpoint. If not set, the `model.critic_model_path` will be set to `model.model_path`.
+- `model_path`: Path to the model checkpoint being trained.
+- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
+- `max_prompt_tokens`: Maximum number of tokens allowed in input prompts.
+- `max_response_tokens`: Maximum number of tokens allowed in generated responses.
+
+---

-## Cluster
+## Cluster Configuration

-The `cluster` configuration specifies the cluster configuration. It includes the number of nodes and the number of GPUs per node.
+Defines how many nodes and GPUs per node are used.

 ```yaml
 cluster:
@@ -88,86 +134,152 @@
   gpu_per_node: 8
 ```

-- `cluster.node_num`: The number of nodes used for training.
-- `cluster.gpu_per_node`: The number of GPUs per node used for training.
+- `node_num`: Total number of compute nodes.
+- `gpu_per_node`: Number of GPUs available per node.
+
+---

-## Buffer
+## Buffer Configuration
+
+Configures the data buffers used by the explorer and trainer.

 ```yaml
 buffer:
-  max_retry_times: 3
-  max_retry_interval: 1
+  batch_size: 32
+  total_epochs: 100
+
   explorer_input:
     taskset:
-      name: countdown
-      path: 'countdown_dataset/oneshot-split'
+      ...
+    eval_tasksets:
+      ...
+
+  trainer_input:
+    experience_buffer:
+      ...
+    sft_warmup_dataset:
+      ...
+
+  default_workflow_type: 'math_workflow'
+  default_reward_fn_type: 'countdown_reward'
+```
+
+- `batch_size`: Number of samples used per training step. *Please do not manually multiply this value by `algorithm.repeat_times`.*
+- `total_epochs`: Total number of training epochs. Not applicable for streaming datasets (e.g., queue-based buffers).
+
+### Explorer Input
+
+Defines the dataset(s) used by the explorer for training and evaluation.
+
+```yaml
+buffer:
+  ...
+  explorer_input:
+    taskset:
+      name: countdown_train
+      storage_type: file
+      path: /PATH/TO/DATA
       split: train
       format:
         prompt_key: 'question'
         response_key: 'answer'
       rollout_args:
-        n: 1
         temperature: 1.0
-        logprobs: 0
-    eval_tasksets: []
-    default_workflow_type: 'math_workflow'
-    default_reward_fn_type: 'countdown_reward'
+      default_workflow_type: 'math_workflow'
+      default_reward_fn_type: 'countdown_reward'
+
+    eval_tasksets:
+      - name: countdown_eval
+        storage_type: file
+        path: /PATH/TO/DATA
+        split: test
+        format:
+          prompt_key: 'question'
+          response_key: 'answer'
+        rollout_args:
+          temperature: 0.1
+        default_workflow_type: 'math_workflow'
+        default_reward_fn_type: 'countdown_reward'
+```
+
+- `buffer.explorer_input.taskset`: Task dataset used for training exploration policies.
+- `buffer.explorer_input.eval_tasksets`: List of task datasets used for evaluation.
+
+The configuration for each task dataset is defined as follows:
+
+- `name`: Name of the dataset; it must be unique.
+- `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`.
+  - `file`: The dataset is stored in `jsonl`/`parquet` files. The data files must be organized according to the Hugging Face `datasets` standard. *We recommend this storage type for most cases.*
+  - `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task datasets unless you know what you are doing.*
+  - `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in future versions.*
+- `path`: The path to the task dataset.
+  - For `file` storage type, the path is the path to the directory that contains the task dataset files.
+  - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying an SQLite database path here.
+  - For `sql` storage type, the path is the path to the SQLite database file.
+- `format`: Defines keys for prompts and responses in the dataset.
+  - `prompt_key`: Specifies which column in the dataset contains the prompt data.
+  - `response_key`: Specifies which column in the dataset contains the response data.
+- `rollout_args`: The parameters for rollout.
+  - `temperature`: The temperature for sampling.
+- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
+- `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
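+
+For illustration only, the following sketch shows how a tiny `file`-type taskset matching the `format` above (`prompt_key: 'question'`, `response_key: 'answer'`) could be written as a JSONL file; the file name and records are hypothetical:
+
+```python
+# Hypothetical example: create a minimal JSONL taskset whose columns match
+# prompt_key='question' and response_key='answer'.
+import json
+
+records = [
+    {"question": "Compute 12 + 7.", "answer": "19"},
+    {"question": "Compute 9 * 6.", "answer": "54"},
+]
+
+with open("train.jsonl", "w", encoding="utf-8") as f:
+    for record in records:
+        f.write(json.dumps(record, ensure_ascii=False) + "\n")
+```
+
+The taskset's `path` would then point to the directory containing such files.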
+
+### Trainer Input
+
+Defines the experience buffer and optional SFT warm-up dataset.
+
+```yaml
+buffer:
+  ...
   trainer_input:
     experience_buffer:
       name: countdown_buffer
       storage_type: queue
-      path: 'sqlite:///countdown.db'
-    sft_warmup_dataset: null
+      path: sqlite:///countdown_buffer.db
+
+    sft_warmup_dataset:
+      name: warmup_data
+      storage_type: file
+      path: /PATH/TO/WARMUP_DATA
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+
+    sft_warmup_steps: 0
 ```

-- `buffer.max_retry_times`: The maximum number of retries when loading the data from database.
-- `buffer.max_retry_interval`: The maximum interval between retries when loading the data from database.
-- `buffer.explorer_input.taskset`: The configuration of the taskset.
-- `buffer.explorer_input.taskset.name`: The name of the taskset.
-- `buffer.explorer_input.taskset.path`: The path to the taskset.
-- `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
-- `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
-- `buffer.explorer_input.taskset.rollout_args.n`: The number of times to repeat each task. This field is automatically set to `algorithm.repeat_times`.
-- `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
-- `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
-- `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default.
-- `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`.
-- `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`.
-- `buffer.trainer_input.experience_buffer`: The configuration of experience_buffer.
-- `buffer.trainer_input.experience_buffer.name`: The name of the experience buffer.
-- `buffer.trainer_input.experience_buffer.storage_type`: The storage type of the experience buffer. Default is `queue`.
-- `buffer.trainer_input.experience_buffer.path`: The sql path to store the experience buffer. It can be empty to indicate not saving to the database.
-- `buffer.trainer_input.sft_warmup_dataset`: The configuration of the SFT warmup dataset. The structure of `sft_warmup_dataset` is the similar to `buffer.explorer_input.taskset`.
-
-## Explorer
-
-The `explorer` configuration specifies the explorer configuration. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, the maximum number of pending requests, and the maximum number of waitingsteps.
+- `experience_buffer`: Experience replay buffer used by the trainer.
+- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
+- `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
+
+---
+
+## Explorer Configuration
+
+Controls the rollout models and workflow execution.

 ```yaml
 explorer:
   runner_num: 32
   rollout_model:
     engine_type: vllm_async
-    engine_num: 2
+    engine_num: 1
+    tensor_parallel_size: 1
+  auxiliary_models:
+    - model_path: /PATH/TO/MODEL
       tensor_parallel_size: 1
-    enable_prefix_caching: false
-    enforce_eager: true
-    dtype: bfloat16
-    seed: 42
 ```

-- `explorer.engine_type`: The type of the engine, Support `vllm_async` and `vllm_sync`. Default is `vllm_async`.
-- `explorer.engine_num`: The number of engines. Default is `2`. It should be set manually.
-- `explorer.runner_num`: The number of workflow runners. Default is `32`.
-- `explorer.tensor_parallel_size`: The tensor parallel size used in vLLM. Default is `1`.
-- `explorer.enable_prefix_caching`: Whether to enable prefix caching. Default is `False`.
-- `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
-- `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
-- `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.rollout_model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
-- `explorer.rollout_model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
+- `runner_num`: Number of parallel workflow runners.
+- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
+- `rollout_model.engine_num`: Number of inference engines.
+- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
+- `auxiliary_models`: Additional models used for custom workflows.
+
+---
+
+## Synchronizer Configuration

-## Synchronizer
+Controls how model weights are synchronized between trainer and explorer.

 ```yaml
 synchronizer:
@@ -176,36 +288,64 @@
   sync_timeout: 1200
 ```

-- `synchronizer.sync_method`: The synchronization method between `trainer` and `explorer`.
-Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explorer` will be synchronized from `trainer` through `nccl`,
-`checkpoint` represents that `explorer` will load the newest checkpoints saved by `trainer` then update its model weights. Default is `nccl`.
-- `synchronizer.sync_interval`: The interval steps between two synchronizations. Default is `10`. It should be set manually.
-- `synchronizer.sync_timeout`: The timeout of the synchronization. Default is `1200`.
+- `sync_method`: Method of synchronization. Options:
+  - `nccl`: Uses NCCL for fast synchronization.
+  - `checkpoint`: Loads latest model from disk.
+- `sync_interval`: Interval (in steps) between synchronizations.
+- `sync_timeout`: Timeout duration for synchronization.
+
+---

-## Trainer
+## Trainer Configuration
+
+Specifies the backend and behavior of the trainer.
 ```yaml
 trainer:
   trainer_type: 'verl'
   save_interval: 100
   trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
+  trainer_config: null
 ```

-- `trainer.trainer_type`: The backend of the trainer, Only `verl` is supported.
-- `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`.
+- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
+- `save_interval`: Frequency (in steps) at which to save model checkpoints.
+- `trainer_config_path`: The path to the trainer configuration file.
+- `trainer_config`: Inline trainer configuration. Only one of `trainer.trainer_config` and `trainer.trainer_config_path` needs to be set.
+
+---
+
+## Data Processor Configuration
+
+Configures preprocessing and data cleaning pipelines.
+
+```yaml
+data_processor:
+  source_data_path: '/PATH/TO/DATASET'
+  load_kwargs:
+    split: 'train'
+  format:
+    prompt_key: 'question'
+    response_key: 'answer'
+  dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml'
+  clean_strategy: 'iterative'
+  db_url: 'postgresql://{username}@localhost:5432/{db_name}'
+```
+
+- `source_data_path`: Path to the raw dataset.
+- `load_kwargs`: Arguments passed to HuggingFace’s `load_dataset()`.
+- `dj_config_path`: Path to Data-Juicer configuration for cleaning.
+- `clean_strategy`: Strategy for iterative data cleaning.
+- `db_url`: Database URL if using SQL backend.
+
+---
-
-- `trainer.actor_grad_clip`: Gradient clip for actor model training.
-- `trainer.actor_clip_ratio`: Used for compute policy loss.
-- `trainer.actor_entropy_coeff`: Used for compute policy loss.
-- `trainer.actor_use_kl_loss`: Whether to enable kl loss.
-- `trainer.actor_kl_loss_coef`: The coefficient of kl loss.
+
+## veRL Trainer Configuration (Advanced)
-
-- `trainer.train_config`: The configuration of the trainer. Only one needs to be set for `trainer.trainer_config` and `trainer.trainer_config_path`
-- `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
+
+This section is for advanced users working with the `verl` trainer backend. It includes fine-grained settings for actor/critic models, optimizer parameters, and training loops.
-### veRL Trainer Configuration
+> For full parameter meanings, refer to the [veRL documentation](https://github.com/volcengine/verl/blob/v0.3.0.post1/docs/examples/config.rst).
-Here we mainly introduce the parameters that can be set in veRL. For the specific meaning of the parameters, please refer to the official document of [veRL](https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/docs/examples/config.rst).

 ```yaml
 actor_rollout_ref:
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 7b7cc1dc76..58e467b20a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -1,4 +1,4 @@
-# Trinity-RFT Developer Guide
+# Developer Guide

 This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.

@@ -87,7 +87,7 @@ The `rollout_args` field contains the parameters for the rollout process, such a
 - `auxiliary_models`: A list of auxiliary models, which will not be trained. All of them provide OpenAI compatible API.

 ```{tip}
-The `model` also provided an OpenAI compatible API, you can switch to it by setting `explorer.enable_openai_api` to `true` in your config file and use `model.get_openai_client()` to get an `openai.OpenAI` instance.
+The `model` also provides an OpenAI-compatible API; you can switch to it by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and use `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
 ```
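+
+As a minimal, hypothetical sketch (not taken from the repository), the OpenAI client obtained this way could be used roughly as follows:
+
+```python
+# Hypothetical sketch: query the rollout model through its OpenAI-compatible API.
+# Assumes `explorer.rollout_model.enable_openai_api` is `true` and that `model` is
+# the rollout model handle passed to the workflow constructor.
+def ask_via_openai_api(model, question: str) -> str:
+    client = model.get_openai_client()             # an openai.OpenAI instance
+    model_name = client.models.list().data[0].id   # model name registered by the engine
+    completion = client.chat.completions.create(
+        model=model_name,
+        messages=[{"role": "user", "content": question}],
+    )
+    return completion.choices[0].message.content
+```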

 #### Example Code
@@ -130,6 +130,23 @@ class ExampleWorkflow(Workflow):
         ]
 ```

+For some heavy workflows, the initialization process may be time-consuming.
+In this case, you can implement the `resettable` and `reset` methods to avoid re-initialization.
+
+```python
+@WORKFLOWS.register_module("example_workflow")
+class ExampleWorkflow(Workflow):
+    # some code
+    # ...
+
+    def resettable(self):
+        return True
+
+    def reset(self, task: Task):
+        self.question = task.raw_task.get("question")
+        self.answer = task.raw_task.get("answer")
+```
+
 ---

 ### Step 3: Modify Configuration File
@@ -141,10 +158,10 @@ buffer:
   # Other fields
   explorer_input:
     taskset:
-      name: taskset_name
-      path: 'path/to/taskset'
+      name: example_task
+      storage_type: file
+      path: /path/to/taskset
       # Other fields
-    eval_tasksets: []
     default_workflow_type: example_workflow
   # Other fields
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 1ec92220ba..dcf86f8349 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,9 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "trinity-rft"
 version = "0.1.0"
-authors = []
+authors = [
+    {name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
+]
 description = "Trinity-RFT: A Framework for Training Large Language Models with Reinforcement Fine-Tuning"
 readme = "README.md"
 classifiers = [
@@ -97,3 +99,7 @@ exclude = '''

 [tool.isort]
 known_third_party = ["wandb"]
+
+[project.urls]
+"Homepage" = "https://github.com/modelscope/Trinity-RFT"
+"Documentation" = "https://modelscope.github.io/Trinity-RFT/"
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index baf5ffdd8c..b5104f2cc7 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -108,7 +108,7 @@ def get_openai_client(self) -> openai.OpenAI:
         if not ray.get(self.model.has_api_server.remote()):
             raise ValueError(
                 "OpenAI API server is not running on current model."
-                "Please set `explorer.enable_openai_api` to `True`."
+                "Please set `enable_openai_api` to `True`."
             )
         api_address = None
         while True: