# Developer Guide

This guide introduces how to develop new modules for Trinity-RFT and provides relevant development guidelines.

Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**.
We decouple the RL pipeline into three modules to make it easier to customize and extend.
Below is a table summarizing the modules that different types of developers need to focus on.

| Developer Type | Focus Module | Key Component |
|----------------|--------------|---------------|
| Developers aiming to extend existing RL algorithms to new environments | *Explorer* | `Workflow` |
| Developers needing to design and implement new RL algorithms for comparing training effectiveness | *Trainer* | `Algorithm` |
| Developers seeking to enhance training performance from the data perspective | *Buffer* | Data Processing Module (Coming soon) |

```{note}
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```

---

## Workflows (For RL Environment Developers)

In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments.
A qualified workflow uses the model being trained to complete the specified task and obtains feedback (a reward) from the environment. Below are the steps to create a new workflow:

---

### Step 0: Basic Concepts

Before starting development, it's important to understand several core concepts:


- **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type:
- **Math problems**: A `Task` contains the problem description and the golden answer.
- **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information.


- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to *Rollout* and *Reward* calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
- `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios, submits problems to LLM, parses LLM responses, and calculates scores (rewards).
- `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios, it involves multi-turn interactions with the environment.
- `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.
- ...


- **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
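For intuition, here is a minimal sketch of what such a record carries for PPO/GRPO-style training; the field names are illustrative, not the actual `Experience` definition:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExperienceSketch:
    tokens: List[int] = field(default_factory=list)       # prompt + response token IDs
    action_mask: List[int] = field(default_factory=list)  # 1 where the token was generated by the LLM
    logprobs: List[float] = field(default_factory=list)   # per-token log probabilities
    reward: float = 0.0                                   # scalar reward produced by the workflow
```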

---
### Step 1: Prepare the Task Dataset

The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file.
To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields.

- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.taskset.format`.
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.rollout_args`.
- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. It offers more flexibility than `format_args` and `rollout_args`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.workflow_args`. Normally, you do not need to set this field.

```{tip}
`workflow`, `workflow_args` and `raw_task` provide different levels of customization.
```
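Below is a sketch of a taskset configuration for a math dataset, assuming `math_workflow` is the registered name of `MathWorkflow` and using an illustrative dataset path:

```yaml
buffer:
  explorer_input:
    taskset:
      path: /path/to/math_dataset  # illustrative
      default_workflow_type: math_workflow
      format:
        prompt_key: question
        response_key: answer
      rollout_args:
        temperature: 1.0
```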

In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task`, and uses the `rollout_args` to generate responses.


---

### Step 2: Implement a New Workflow
To define a new workflow, inherit from the `Workflow` base class and implement the `run` method. A minimal view of the interface (the `__init__` parameters are documented below):

```python
class Workflow(ABC):
    """The base class for all workflows."""

    def __init__(self, model, task, auxiliary_models=None):
        ...

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""
```


#### Initializing Your Workflow

During initialization, `Workflow` receives the following parameters:
- `model`: The model being trained. Your workflow calls it to generate responses during rollout.
- `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
- `auxiliary_models`(`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.


```{tip}
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```
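For example, a minimal sketch of this pattern inside a workflow (the message content is illustrative):

```python
client = self.model.get_openai_client()      # an openai.OpenAI instance
model_id = client.models.list().data[0].id   # discover the served model name
completion = client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": self.question}],
)
response_text = completion.choices[0].message.content
```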
#### Implementing the `run` Method

We first call the model to generate multiple responses using the provided question.
Then we calculate the reward for each response using the `calculate_reward` function.
Finally, we construct a list of `Experience` with the responses and rewards and return it.
The sketch below assumes the `ModelWrapper` chat interface and `Experience` field names shown in the comments; check the actual definitions before relying on them.

```python
class ExampleWorkflow(Workflow):

    def run(self) -> List[Experience]:
        # Generate multiple candidate responses for the question.
        responses = self.model.chat(
            [{"role": "user", "content": self.question}]
        )
        # Score each response against the golden answer.
        for response in responses:
            response.reward = self.calculate_reward(
                response.response_text, self.answer
            )
        return responses
```
```{note}
For workflows that are not intended to be contributed to the Trinity-RFT project, you can implement them as plugin modules outside the main codebase.
You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
```
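Remember to register your workflow under a unique name so it can be referenced from the config file. A sketch, assuming the `WORKFLOWS` registry follows the same `register_module` pattern as the algorithm registries shown later in this guide:

```python
from trinity.common.workflows import WORKFLOWS  # registry import path assumed

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    ...
```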


#### Avoid Re-initialization

For heavy workflows, re-initializing on every task can incur extra computational cost.
To avoid this, implement the `resettable` and `reset` methods so that an existing workflow instance is reused and only its task-specific state is updated:

```python
class ExampleWorkflow(Workflow):

    def resettable(self):
        return True

    def reset(self, task: Task):
        # Update task-specific state instead of rebuilding the instance.
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```


#### Full Code Example

A consolidated sketch assembling the pieces above (import paths, the chat interface, and `Experience` field names are assumptions to check against the codebase):

```python
from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.workflows import Task, Workflow, WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    def __init__(
        self,
        model,
        task: Task,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(model=model, task=task, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    def calculate_reward(self, response: str, truth: str) -> float:
        # Toy reward: exact match against the golden answer.
        return 1.0 if response.strip() == truth.strip() else 0.0

    def run(self) -> List[Experience]:
        responses = self.model.chat([{"role": "user", "content": self.question}])
        for response in responses:
            response.reward = self.calculate_reward(response.response_text, self.answer)
        return responses

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```


---

### Step 3: Use Your Workflow
After implementing and registering your workflow, reference its registered name in your config file and launch training.
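A minimal sketch, assuming the workflow was registered as `example_workflow`:

```yaml
buffer:
  explorer_input:
    taskset:
      default_workflow_type: example_workflow
```

Then start the run:

```bash
trinity run --config <your_yaml_file>
```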

---

## Algorithms (For RL Algorithm Developers)

Trinity-RFT provides a standardized process for implementing new algorithms.

### Step 0: Basic Concepts of the Algorithm Module

In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss to update models based on this data.
To avoid implementing a new Trainer class each time a new algorithm is added, we decompose the representative PPO training process into multiple sub-modules that can be recombined to support various algorithms.

- **Sample Strategy**: Responsible for reading experience data from the buffer module. By customizing this module, you can implement requirements like filtering experience data or mixed sampling from multiple data sources.
- **Advantage Fn**: Responsible for calculating the Advantage and Returns of experience data.
- **Policy Loss Fn**: Responsible for calculating the loss of the policy network.
- **KL Fn**: Responsible for calculating KL Divergence, which is generally used in two places in existing RL algorithms: Reward Penalty and Actor Loss.
- **Entropy Loss Fn**: Responsible for calculating the entropy loss of the policy network.

We provide several implementations of the above modules in `trinity/algorithm`.
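Each of these slots is selected by name in the `algorithm` section of the config file. A sketch (the key names follow the `get_default_config` example in Step 2; the values shown here are illustrative):

```yaml
algorithm:
  algorithm_type: "ppo"
  sample_strategy: "warmup"
  advantage_fn: "gae"
  policy_loss_fn: "ppo"
  kl_loss_fn: "k2"
  entropy_loss_fn: "default"
```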


### Step 1: Implementing Algorithm Components


Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the OPMD algorithm as an example.

The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. Therefore, only new Advantage Fn and Policy Loss Fn modules need to be implemented.

#### Step 1.1: Implementing `AdvantageFn`

Developers need to implement the {class}`trinity.algorithm.AdvantageFn` interface, which mainly includes two methods:

- `__call__`: Calculates advantages and returns based on input experience data, records observable metrics during the calculation process, and returns the experience data containing advantages and returns as well as a metrics dictionary. The input experience data format is `verl`'s `DataProto`.
- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.

After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name.

Here's an implementation example for the OPMD algorithm's Advantage Fn:

```python
# trinity/algorithm/advantage_fn/opmd.py
from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("opmd")
class OPMDAdvantageFn(AdvantageFn):
    """OPMD advantage computation"""

    def __init__(
        self,
        opmd_baseline: str = "mean",
        tau: float = 1.0,
    ) -> None:
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def __call__(
        self,
        exps: DataProto,
        **kwargs,
    ) -> Tuple[DataProto, Dict]:
        # calculate advantages and returns based on the exps
        ...

        # record some metrics
        metrics: Dict = {}

        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "opmd_baseline": "mean",
            "tau": 1.0,
        }
```

#### Step 1.2: Implementing `PolicyLossFn`

Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interface, which is similar to `AdvantageFn` and includes two methods:

- `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.

Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`.

Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method:


```python
from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn  # import path assumed
from trinity.algorithm.utils import masked_mean  # import path assumed


@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
    def __init__(self, tau: float = 1.0) -> None:
        self.tau = tau

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        pg_losses = -advantages * logprob
        opmd_loss = masked_mean(pg_losses, action_mask)
        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {"tau": 1.0}
```

### Step 2: Register Your Algorithm

The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.

To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.


The `AlgorithmType` class includes the following attributes and methods:

- `use_critic`: Whether to use the Critic model
- `use_reference`: Whether to use the Reference model
- `use_advantage`: Whether to calculate Advantage; if False, the `AdvantageFn` call will be skipped
- `can_balance_batch`: Whether the algorithm can automatically balance batches
- `schema`: The format of experience data corresponding to the algorithm
- `get_default_config`: Returns the default configuration of the algorithm, used to fill in algorithm-related config entries that the user does not specify

Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.ALGORITHM_TYPE`.

Below is the implementation for the OPMD algorithm.
Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`.
The dictionary returned by the `get_default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.

```python
from typing import Dict

from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
from trinity.common.schema import ExperienceModel  # import path assumed


@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm(AlgorithmType):
    """OPMD algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    use_advantage: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "sample_strategy": "warmup",
            "policy_loss_fn": "opmd",
            "advantage_fn": "opmd",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }
```

### Step 3: Use Your Algorithm

After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file.

For default configurations, you just need to add the following content to your `config.yaml` file:

```yaml
# some other configs
algorithm:
algorithm_type: "opmd"
# some other configs
```

If you need to modify certain parameters, you can simply add the corresponding parameters within the `algorithm` section. For example, if you need to modify `repeat_times` and the initialization parameters of `AdvantageFn`, the modified `config.yaml` file would be as follows:

```yaml
# some other configs
algorithm:
algorithm_type: "opmd"
repeat_times: 8
advantage_fn_args:
opmd_baseline: "logavgexp"
tau: 0.99
# some other configs
```
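Here `advantage_fn_args` is forwarded to `OPMDAdvantageFn.__init__`, overriding the defaults returned by its `default_args` method (see Step 1.1).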

---

## Adding New Config Entries for the Config Generator (Advanced)

### Step 0: Understanding Streamlit
The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=<config name>` to the registered function.

For `train_batch_size`, we will use the following settings:

- Default value: 96
- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0`
- Additional config: `{"_train_batch_size_per_gpu": 16}`


Here's how the registration of the `train_batch_size` parameter might look (a sketch; the decorator's keyword names and the widget call are assumptions):

```python
@CONFIG_GENERATORS.register_config(
    default_value=96,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
    # `key="train_batch_size"` is injected automatically by the decorator.
    st.number_input(label="Train Batch Size", min_value=1, **kwargs)
```

To successfully integrate new parameters into the `config_manager.py` file, please follow these two steps.
First, incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class.

Example:

```python
class ConfigManager:
    def _expert_buffer_part(self):
        # illustrative: add the new entry to the section it belongs to
        self.get_configs("train_batch_size")
```
Second, utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML configuration.

Example:

```python
class ConfigManager:
    def _gen_buffer_config(self):
        # illustrative: read the value set on the config generator page
        train_batch_size = st.session_state["train_batch_size"]
        ...
```