Merged
Changes from 4 commits
226 changes: 172 additions & 54 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -1,57 +1,86 @@
# Developer Guide

This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.
This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines.

```{note}
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```

---

## Creating New Task Types
## Creating New Workflows

Trinity-RFT supports developers in registering new task types (e.g., multi-round interaction scenarios). Below are the steps for creating a new task type.
Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow:

---

### Step 0: Basic Concepts

Before starting development, it's important to understand several core concepts:

- **Task**: Represents a data structure that can be converted into a `Workflow`. The `Task` data format may vary significantly depending on the type of task:
- **Math problems**: `Task` contains the problem description and the standard answer.
- **Programming scenarios**: `Task` includes the problem description, test cases, runtime environment, and other complex information.

- **Workflow**: Can be understood as the running state of a `Task`, defining the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates `Experience`. Trinity-RFT has several built-in `Workflows`:
- `MathWorkflow`: For math scenarios, submits problems to LLM, parses results, and calculates scores (rewards).
- **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type:
- **Math problems**: A `Task` contains the problem description and the standard answer.
- **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information.


- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
- `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios; submits problems to the LLM, parses the results, and calculates scores (rewards).
- `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios; involves multi-turn interaction with the environment.
- `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.
- ...

- **Experience**: The output of running a `Workflow`, where the internal data format depends on the algorithm used for training. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token_ids, action_mask (identifying which tokens were generated by the LLM), logprobs, rewards, etc.

- **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
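To make the role of the action mask concrete, here is a minimal sketch for a single-turn rollout. It assumes the prompt occupies the first `prompt_length` positions of `tokens` and everything after it was generated by the LLM. The helper `build_action_mask` and the token IDs are hypothetical and not part of Trinity-RFT; multi-turn workflows may need a non-contiguous mask.

```python
from typing import List


def build_action_mask(tokens: List[int], prompt_length: int) -> List[int]:
    """Mark prompt tokens with 0 and generated tokens with 1 (single-turn assumption)."""
    if not 0 <= prompt_length <= len(tokens):
        raise ValueError("prompt_length must be between 0 and len(tokens)")
    return [0] * prompt_length + [1] * (len(tokens) - prompt_length)


# Made-up token IDs: the first four belong to the prompt, the last two to the reply.
tokens = [101, 7592, 2088, 102, 3437, 2003]
mask = build_action_mask(tokens, prompt_length=4)
```

Only the positions marked `1` would contribute to the policy loss under this assumption.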

---

### Step 1: Prepare Task Dataset

Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file.
To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields.

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.
- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.taskset.format`.
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.rollout_args`.

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:

```
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
```
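As a rough illustration of how each line of such a file maps to one `raw_task` dict, the sketch below parses JSON Lines text with the standard library only. The function `load_raw_tasks` is a hypothetical helper for this guide, not Trinity-RFT's actual loader.

```python
import json
from typing import Dict, List


def load_raw_tasks(jsonl_text: str) -> List[Dict]:
    """Parse JSON Lines text into one dict per non-empty line."""
    return [json.loads(line) for line in jsonl_text.splitlines() if line.strip()]


sample = '{"question": "1+1=", "answer": "2"}\n{"question": "2+2=", "answer": "4"}\n'
raw_tasks = load_raw_tasks(sample)
```

Each resulting dict carries the `question` and `answer` keys that the math scenario expects.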

Example configuration snippet:

```yaml
# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
# some other configs
```

In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task`, and uses the `rollout_args` to generate the response.
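The key lookup described above can be sketched as follows. `FormatArgsSketch` is a hypothetical stand-in that models only the two keys mentioned in this guide; the real `FormatConfig` may contain more fields, and `extract_prompt_and_truth` is an illustrative helper rather than a Trinity-RFT function.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class FormatArgsSketch:
    """Hypothetical stand-in for the `format` section of the config."""
    prompt_key: str = "question"
    response_key: str = "answer"


def extract_prompt_and_truth(raw_task: Dict, fmt: FormatArgsSketch) -> Tuple[str, str]:
    """Look up the prompt and the reference answer using the configured keys."""
    return raw_task[fmt.prompt_key], raw_task[fmt.response_key]


prompt, truth = extract_prompt_and_truth(
    {"question": "1+1=", "answer": "2"}, FormatArgsSketch()
)
```

Because the keys come from configuration, the same workflow can read datasets whose columns are named differently.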


---

### Step 2: Write Workflow
### Step 2: Implement a New Workflow

The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:
The `Workflow` base class interface is as follows:

```python
# import some packages

class Workflow(ABC):

def __init__(
@@ -68,39 +97,48 @@ class Workflow(ABC):
"""Run the workflow and return a list of Experiences."""
```

Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS
#### Initializing Your Workflow

@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
pass
```
During initialization, `Workflow` receives the following parameters:

- `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained. It provides an interface similar to OpenAI's: it receives a list of conversation messages and returns the content generated by the LLM (including the reply text `response_text`, the full sequence of token IDs `tokens`, the prompt token length `prompt_length`, and a list of output token log probabilities `logprobs`).
- `task` ({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.

#### Initialization Parameters
When initializing, `Workflow` receives the following parameters:
- `model`: The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `task`: An instance of `Task`, which is generated by one line of data from the `Task` dataset. The `raw_task` field contains the `Dict` format source data, which can be used to construct the `Workflow` instance.
The `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k` and `top_p`.
- `auxiliary_models`: A list of auxiliary models, which will not be trained. All of them provide OpenAI compatible API.

```{tip}
The `model` also provided an OpenAI compatible API, you can switch to it by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and use `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```

#### Example Code
Below is a simple example demonstrating how to implement a math problem `Workflow`:
Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

def __init__(self, model: ModelWrapper, task: Task, **kwargs):
super().__init__(model, **kwargs)
def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
super().__init__(model, task, auxiliary_models)
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
self.rollout_args = task.rollout_args
# Optional: If you want to use OpenAI API in your workflow
# self.openai_client = self.model.get_openai_client()
```

#### Implementing the `run` method

The `run` method is the core of your workflow. It returns a list of `Experience`.
Below is a simple implementation for a math workflow.

We first call the model to generate multiple responses using the provided question and rollout arguments.
Then we calculate the reward for each response using the `calculate_reward` function.
Finally, we construct a list of `Experience` with the responses and rewards and return it.


```python
class ExampleWorkflow(Workflow):

# the __init__ function

def calculate_reward(self, response: str, truth: str) -> float:
if response == truth:
@@ -109,28 +147,50 @@ class ExampleWorkflow(Workflow):
return 0.0

def run(self) -> List[Experience]:
response = self.model.chat(
# call the model to generate multiple responses
responses = self.model.chat(
[
{
"role": "user",
"content": f"Question:\n{self.question}",
}
],
n=self.task.rollout_args.n,
temperature=self.task.rollout_args.temperature,
n=self.rollout_args.n,
temperature=self.rollout_args.temperature,
)
reward: float = self.calculate_reward(response.response_text, self.answer)
return [
Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
experiences = []
for response in responses:
# calculate reward
reward: float = self.calculate_reward(response.response_text, self.answer)
# construct Experience
experiences.append(
Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)
)
]
return experiences
```
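To try out the control flow of `run` without a real model, the following self-contained sketch replaces the model with a stub that returns canned replies. `FakeModel`, `FakeResponse`, `FakeExperience`, and `run_once` are hypothetical names that only mirror the attributes used in this guide; they are not Trinity-RFT classes.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeResponse:
    response_text: str
    tokens: List[int]
    prompt_length: int
    logprobs: List[float]


@dataclass
class FakeExperience:
    tokens: List[int]
    prompt_length: int
    reward: float
    logprobs: List[float]


class FakeModel:
    """Returns canned replies instead of calling an LLM."""

    def __init__(self, texts: List[str]) -> None:
        self.texts = texts

    def chat(self, messages: List[Dict], n: int = 1, temperature: float = 1.0) -> List[FakeResponse]:
        # Token IDs and log probabilities are placeholders.
        return [
            FakeResponse(text, tokens=[1, 2, 3], prompt_length=2, logprobs=[-0.1])
            for text in self.texts[:n]
        ]


def run_once(model: FakeModel, question: str, answer: str, n: int) -> List[FakeExperience]:
    """Generate n replies, score each by exact match, and wrap them as experiences."""
    responses = model.chat([{"role": "user", "content": f"Question:\n{question}"}], n=n)
    return [
        FakeExperience(r.tokens, r.prompt_length, 1.0 if r.response_text == answer else 0.0, r.logprobs)
        for r in responses
    ]


experiences = run_once(FakeModel(["2", "3"]), "1+1=", "2", n=2)
rewards = [e.reward for e in experiences]
```

A stub like this makes it possible to unit-test the reward logic before connecting the workflow to an actual rollout model.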

For some heavy workflows, the initialization process may be time-consuming.
#### Registering Your Workflow

Register your workflow using the `WORKFLOWS.register_module` decorator.
Ensure the name does not conflict with existing workflows.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
```
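For readers unfamiliar with decorator-based registries, here is a minimal sketch of the general pattern. `RegistrySketch` is hypothetical and not Trinity-RFT's implementation; in particular, raising `KeyError` on a duplicate name is an assumption of this sketch, and the real registry may warn or overwrite instead.

```python
from typing import Callable, Dict, Type


class RegistrySketch:
    """Maps string names to classes via a decorator."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Assumption of this sketch: duplicate names are rejected.
            if name in self._modules:
                raise KeyError(f"'{name}' is already registered")
            self._modules[name] = cls
            return cls

        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


WORKFLOWS_SKETCH = RegistrySketch()


@WORKFLOWS_SKETCH.register_module("example_workflow")
class ExampleWorkflowSketch:
    pass
```

Looking a class up by its registered name is what lets a YAML string such as `example_workflow` select the workflow to instantiate.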

#### Avoid Re-initialization

For heavy workflows, re-initializing resources for every task can be time-consuming.
In this case, you can implement the `resettable` and `reset` methods so that an existing workflow instance is reused with a new task.

```python
@@ -147,23 +207,81 @@ class ExampleWorkflow(Workflow):
self.answer = task.raw_task.get("answer")
```
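The sketch below shows how a caller might take advantage of `resettable` and `reset` to construct a workflow once and reuse it across tasks. `ReusableWorkflowSketch` and `run_all` are hypothetical; how Trinity-RFT's runner actually schedules and reuses workflows is not described here.

```python
from typing import Dict, Iterable, List


class ReusableWorkflowSketch:
    init_count = 0  # counts how often the (pretend) expensive setup runs

    def __init__(self, raw_task: Dict) -> None:
        type(self).init_count += 1  # stands in for heavy initialization
        self.reset(raw_task)

    def resettable(self) -> bool:
        return True

    def reset(self, raw_task: Dict) -> None:
        # Only the per-task state is replaced; heavy resources are kept.
        self.question = raw_task.get("question")
        self.answer = raw_task.get("answer")

    def run(self) -> str:
        return f"{self.question}{self.answer}"


def run_all(raw_tasks: Iterable[Dict]) -> List[str]:
    """Create the workflow once, then reset it for each subsequent task."""
    workflow = None
    outputs = []
    for raw_task in raw_tasks:
        if workflow is not None and workflow.resettable():
            workflow.reset(raw_task)
        else:
            workflow = ReusableWorkflowSketch(raw_task)
        outputs.append(workflow.run())
    return outputs


outputs = run_all([
    {"question": "1+1=", "answer": "2"},
    {"question": "2+2=", "answer": "4"},
])
```

In this sketch the constructor runs once for two tasks, which is the saving that `reset` is meant to provide.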


#### Full Code Example

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```


---

### Step 3: Modify Configuration File
### Step 3: Use Your Workflow

After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `buffer.explorer_input` domain to the newly registered `Workflow` name.
After implementing and registering your workflow, update the configuration file: set `default_workflow_type` under `buffer.explorer_input.taskset` to the newly registered `Workflow` name.

```yaml
buffer:
# Other fields
explorer_input:
taskset:
name: example_task
storage_type: file
path: /path/to/taskset
# Other fields
default_workflow_type: example_workflow
# Other fields
default_workflow_type: example_workflow
# Other fields
```

Now you can run your workflow in Trinity-RFT using the command:

```
trinity run --config <your_yaml_file>
```

---