Merged
Changes from 3 commits
214 changes: 164 additions & 50 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -1,57 +1,84 @@
# Developer Guide

This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.
This guide will introduce how to add new workflows to Trinity-RFT and provide relevant development guidelines.

```{note}
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```

---

## Creating New Task Types
## Creating New Workflows

Trinity-RFT supports developers in registering new task types (e.g., multi-round interaction scenarios). Below are the steps for creating a new task type.
Trinity-RFT supports developers in registering new workflows (e.g., multi-round interaction scenarios). Below are the steps to create a new workflow:

---

### Step 0: Basic Concepts

Before starting development, it's important to understand several core concepts:

- **Task**: Represents a data structure that can be converted into a `Workflow`. The `Task` data format may vary significantly depending on the type of task:

- **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The `Task` data format may vary significantly depending on the type of task:
- **Math problems**: `Task` contains the problem description and the standard answer.
- **Programming scenarios**: `Task` includes the problem description, test cases, runtime environment, and other complex information.

- **Workflow**: Can be understood as the running state of a `Task`, defining the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates `Experience`. Trinity-RFT has several built-in `Workflows`:
- `MathWorkflow`: For math scenarios, submits problems to LLM, parses results, and calculates scores (rewards).

- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`, defining the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates `Experience`. Trinity-RFT has several built-in workflows:
- `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios. It submits problems to the LLM, parses the results, and calculates scores (rewards).
- `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios. It involves multi-turn interaction with the environment.
- `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.
- ...

- **Experience**: The output of running a `Workflow`, where the internal data format depends on the algorithm used for training. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token_ids, action_mask (identifying which tokens were generated by the LLM), logprobs, rewards, etc.

- **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`, where the internal data format depends on the algorithm used for training. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token ids, action_mask (identifying which tokens were generated by the LLM), logprobs, rewards, etc.
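
To make the relationship between these concepts more concrete, here is a minimal, self-contained sketch. `ToyTask` and `ToyExperience` are hypothetical stand-ins written for illustration, not the real `Task` and `Experience` classes. Only the field names `raw_task`, `tokens`, `prompt_length`, `logprobs`, and `reward` come from this guide; the way `action_mask` is derived here is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyTask:
    """Hypothetical stand-in for a task: one raw record from the dataset."""
    raw_task: Dict[str, str]


@dataclass
class ToyExperience:
    """Hypothetical stand-in for an experience produced by a workflow."""
    tokens: List[int]      # prompt tokens followed by response tokens
    prompt_length: int     # number of leading tokens that belong to the prompt
    reward: float
    logprobs: List[float] = field(default_factory=list)

    @property
    def action_mask(self) -> List[int]:
        # Assumption: 1 marks tokens generated by the LLM, 0 marks prompt tokens.
        return [0] * self.prompt_length + [1] * (len(self.tokens) - self.prompt_length)


task = ToyTask(raw_task={"question": "1+1=", "answer": "2"})
exp = ToyExperience(tokens=[11, 12, 13, 21], prompt_length=3, reward=1.0, logprobs=[-0.1])
print(exp.action_mask)
```

The sketch keeps the prompt and the response in one token list, so the mask can be rebuilt from `prompt_length` alone.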

---

### Step 1: Prepare Task Dataset

Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
The explorer loads the task dataset through `buffer.explorer_input.taskset` in the configuration file.
To handle the differences in `Task` data format, Trinity-RFT provides a unified `Task` interface, which contains the following fields.

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.
- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your yaml config file.
- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Some workflows already integrate the reward calculation; in such cases you can ignore this field.
- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can use `raw_task` directly to initialize your `Workflow` instance without the following fields.
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. The `format_args` come from the yaml configuration file, and you can set them in `buffer.explorer_input.taskset.format`.
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters to facilitate the rollout process, e.g., the `temperature`. This field also comes from the yaml configuration file, and you can set it in `buffer.explorer_input.taskset.rollout_args`.

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively. For example:

```
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
```

```yaml
# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
# some other configs
```

In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`), and the `MathWorkflow` will use the `prompt_key` and `response_key` to extract the question and answer from the `raw_task` and use the `rollout_args` to generate the response.
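
The key-based extraction described above can be sketched in plain Python. This is only an illustration under stated assumptions: `extract_fields` is a hypothetical helper, and `FORMAT_ARGS` is a plain dictionary that mirrors the `format` section of the YAML config, not the real `FormatConfig` class.

```python
import json
from typing import Dict, List, Tuple

# Two records in the JSONL format shown above.
JSONL_TEXT = '{"question": "1+1=", "answer": "2"}\n{"question": "2+2=", "answer": "4"}\n'

# Mirrors the `format` section of the YAML config above (plain dict, an assumption).
FORMAT_ARGS = {"prompt_key": "question", "response_key": "answer"}


def extract_fields(raw_task: Dict[str, str], format_args: Dict[str, str]) -> Tuple[str, str]:
    """Hypothetical helper: look up the prompt and reference answer by configured key."""
    return raw_task[format_args["prompt_key"]], raw_task[format_args["response_key"]]


# Each non-empty line of the file becomes one raw_task dictionary.
raw_tasks: List[Dict[str, str]] = [
    json.loads(line) for line in JSONL_TEXT.splitlines() if line.strip()
]
pairs = [extract_fields(raw, FORMAT_ARGS) for raw in raw_tasks]
print(pairs)
```

Because the keys are read from configuration, the same workflow can consume datasets whose columns are named differently, as long as the config points at the right keys.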


---

### Step 2: Write Workflow
### Step 2: Implement a New Workflow

The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:
The `Workflow` base class interface is as follows:

```python
# import some packages

class Workflow(ABC):

    def __init__(
@@ -68,39 +95,47 @@ class Workflow(ABC):
        """Run the workflow and return a list of Experiences."""
```

Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
    pass
```
#### Initialize Your Workflow

#### Initialization Parameters
When initializing, `Workflow` receives the following parameters:
- `model`: The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `task`: An instance of `Task`, which is generated by one line of data from the `Task` dataset. The `raw_task` field contains the `Dict` format source data, which can be used to construct the `Workflow` instance.
The `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k` and `top_p`.
- `auxiliary_models`: A list of auxiliary models, which will not be trained. All of them provide OpenAI compatible API.
- `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `task` ({class}`trinity.common.workflows.Task`): A data item generated from one line of the task dataset.
- `auxiliary_models` (`List` of `openai.OpenAI`): A list of auxiliary models, which will not be trained. All of them are provided as OpenAI-compatible APIs.


```{tip}
The `model` also provides an OpenAI-compatible API. To enable it, set `explorer.rollout_model.enable_openai_api` to `true` in your config file, then call `model.get_openai_client()` in your workflow to get an `openai.OpenAI` instance.
```
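
As a rough sketch of how such a client might be used inside a workflow, the helper below sends one user message through an `openai.OpenAI`-style object. The `chat.completions.create` call is the standard interface of the `openai` Python client; the helper name `ask_via_openai_client` and the `model_name` argument are assumptions made for illustration, and the name to pass for the rollout model depends on your setup.

```python
from typing import Any


def ask_via_openai_client(client: Any, model_name: str, question: str) -> str:
    """Send one user message through an `openai.OpenAI`-style client.

    `client` is expected to behave like the object returned by
    `model.get_openai_client()`; `model_name` is a placeholder assumption.
    """
    completion = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": f"Question:\n{question}"}],
    )
    # Return the text of the first generated choice.
    return completion.choices[0].message.content
```

The helper only relies on duck typing, so it can be exercised with a fake client in unit tests before it is pointed at a real model.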

#### Example Code
Below is a simple example demonstrating how to implement a math problem `Workflow`:
In the example below, we only use the `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` in `Task` to further customize the initialization.

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, **kwargs):
        super().__init__(model, **kwargs)
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()
```

#### Implement the `run` method

The `run` method is the core of your workflow. It returns a list of `Experience`.
Below is a simple example demonstrating how to implement the `run` method for a math workflow.

We first call the model to generate multiple responses using the provided question and rollout arguments.
Then we use the `calculate_reward` function to calculate the reward for each response.
Finally, we construct a list of `Experience` with the responses and rewards and return it.


```python
class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
@@ -109,27 +144,48 @@ class ExampleWorkflow(Workflow):
            return 0.0

    def run(self) -> List[Experience]:
        response = self.model.chat(
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.task.rollout_args.n,
            temperature=self.task.rollout_args.temperature,
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        reward: float = self.calculate_reward(response.response_text, self.answer)
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        ]
        return experiences
```
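
To try the control flow of `run` without a real model, the loop can be exercised against a stub. Everything in this sketch is a hypothetical stand-in: `StubResponse` and `StubModel` only imitate the attributes and the `chat(..., n=..., temperature=...)` call used in the example above, and the canned answers are invented for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubResponse:
    """Hypothetical stand-in for one model response."""
    response_text: str
    tokens: List[int]
    prompt_length: int
    logprobs: List[float]


class StubModel:
    """Hypothetical model that returns `n` canned answers instead of calling an LLM."""

    def chat(self, messages: List[dict], n: int = 1, temperature: float = 1.0) -> List[StubResponse]:
        canned = ["2", "3"]  # alternate between a right and a wrong answer
        return [
            StubResponse(canned[i % len(canned)], tokens=[1, 2, 3 + i], prompt_length=2, logprobs=[-0.5])
            for i in range(n)
        ]


def calculate_reward(response: str, truth: str) -> float:
    # Same exact-match rule as the example workflow above.
    return 1.0 if response == truth else 0.0


responses = StubModel().chat([{"role": "user", "content": "Question:\n1+1="}], n=4)
rewards = [calculate_reward(r.response_text, "2") for r in responses]
print(rewards)
```

One reward is produced per response, which is why the workflow returns a list of `Experience` objects rather than a single one.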

#### Register Your Workflow

Register your `Workflow` with the `WORKFLOWS.register_module` decorator, and make sure the name does not conflict with any existing `Workflow` class.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
```
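
For readers unfamiliar with decorator-based registries, the following sketch shows the general pattern. `ToyRegistry` is a hypothetical class, not the implementation behind `WORKFLOWS`; in particular, raising `ValueError` on a duplicate name is an assumption, since this guide does not say how the real registry reacts to a conflict.

```python
from typing import Callable, Dict, Type


class ToyRegistry:
    """Hypothetical name-to-class registry with a decorator interface."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Assumption: a duplicate name is treated as an error.
            if name in self._modules:
                raise ValueError(f"'{name}' is already registered")
            self._modules[name] = cls
            return cls
        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


TOY_WORKFLOWS = ToyRegistry()


@TOY_WORKFLOWS.register_module("example_workflow")
class ToyWorkflow:
    pass


# Registering a second class under the same name triggers the conflict check.
try:
    @TOY_WORKFLOWS.register_module("example_workflow")
    class DuplicateWorkflow:
        pass
    conflict_detected = False
except ValueError:
    conflict_detected = True
```

Looking a class up by its registered name is what allows a string in the config file to select the workflow at run time.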

#### Avoid Re-initialization

For some heavy workflows, the initialization process may be time-consuming.
In this case, you can implement the `resettable` and `reset` methods to avoid re-initialization.
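
The reuse pattern can be sketched as follows. This is a toy illustration under stated assumptions: `ToyResettableWorkflow` and `run_all` are hypothetical, the runner loop is not Trinity-RFT's actual scheduling code, and `reset` takes a plain dictionary here instead of a `Task` instance to keep the example self-contained.

```python
from typing import Dict, List


class ToyResettableWorkflow:
    """Hypothetical workflow whose constructor is expensive."""

    init_count = 0  # counts how many times the expensive setup ran

    def __init__(self, raw_task: Dict[str, str]) -> None:
        ToyResettableWorkflow.init_count += 1  # stands in for heavy setup
        self.reset(raw_task)

    def resettable(self) -> bool:
        return True

    def reset(self, raw_task: Dict[str, str]) -> None:
        # Only the per-task state is replaced; the heavy setup is kept.
        self.question = raw_task.get("question")
        self.answer = raw_task.get("answer")


def run_all(raw_tasks: List[Dict[str, str]]) -> List[str]:
    """Hypothetical runner loop: reuse the workflow when it is resettable."""
    workflow = None
    seen = []
    for raw in raw_tasks:
        if workflow is not None and workflow.resettable():
            workflow.reset(raw)
        else:
            workflow = ToyResettableWorkflow(raw)
        seen.append(workflow.question)
    return seen


questions = run_all([{"question": "1+1=", "answer": "2"}, {"question": "2+2=", "answer": "4"}])
print(questions, ToyResettableWorkflow.init_count)
```

In this sketch the constructor runs once for two tasks, which is the saving that `resettable` and `reset` are meant to provide.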

@@ -147,23 +203,81 @@ class ExampleWorkflow(Workflow):
        self.answer = task.raw_task.get("answer")
```


#### Full Code

```python
# Import paths follow the class references used in this guide.
from typing import List

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        # exact-match reward: 1.0 if the response equals the reference answer
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        # tell the framework that this instance can be reused across tasks
        return True

    def reset(self, task: Task):
        # replace only the per-task state
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```


---

### Step 3: Modify Configuration File
### Step 3: Use Your Workflow

After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `buffer.explorer_input` domain to the newly registered `Workflow` name.
After developing your `Workflow`, set `default_workflow_type` in the `buffer.explorer_input.taskset` section of the configuration file to the newly registered `Workflow` name.

```yaml
buffer:
# Other fields
explorer_input:
taskset:
name: example_task
storage_type: file
path: /path/to/taskset
# Other fields
default_workflow_type: example_workflow
# Other fields
default_workflow_type: example_workflow
# Other fields
```

Then you can run your workflow in the RFT process with the following command:

```
trinity run --config <your_yaml_file>
```

---