Commit 7bddb32

authored
Update developer guide (#53)
1 parent 43ffbc0 commit 7bddb32

File tree

1 file changed

+172
-54
lines changed

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 172 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1,57 +1,86 @@
11
# Developer Guide
22

3-
This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.
3+
This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines.
44

55
```{note}
66
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
77
```
88

99
---
1010

11-
## Creating New Task Types
11+
## Creating New Workflows
1212

13-
Trinity-RFT supports developers in registering new task types (e.g., multi-round interaction scenarios). Below are the steps for creating a new task type.
13+
Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow:
1414

1515
---
1616

1717
### Step 0: Basic Concepts
1818

1919
Before starting development, it's important to understand several core concepts:
2020

21-
- **Task**: Represents a data structure that can be converted into a `Workflow`. The `Task` data format may vary significantly depending on the type of task:
22-
- **Math problems**: `Task` contains the problem description and the standard answer.
23-
- **Programming scenarios**: `Task` includes the problem description, test cases, runtime environment, and other complex information.
2421

25-
- **Workflow**: Can be understood as the running state of a `Task`, defining the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates `Experience`. Trinity-RFT has several built-in `Workflows`:
26-
- `MathWorkflow`: For math scenarios, submits problems to LLM, parses results, and calculates scores (rewards).
22+
- **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type:
23+
- **Math problems**: A `Task` contains the problem description and the golden answer.
24+
- **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information.
25+
26+
27+
- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
28+
- `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios. Submits problems to the LLM, parses its responses, and calculates scores (rewards).
29+
- `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For WebShop scenarios. Involves multi-turn interaction with the environment.
2730
- `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.
2831
- ...
2932

30-
- **Experience**: The output of running a `Workflow`, where the internal data format depends on the algorithm used for training. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token_ids, action_mask (identifying which tokens were generated by the LLM), logprobs, rewards, etc.
33+
34+
- **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
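
To make the `Experience` fields more concrete, the sketch below shows one way an action mask can relate to `tokens` and `prompt_length`. The helper name `build_action_mask` is hypothetical and is not part of Trinity-RFT; it only illustrates the idea that prompt tokens are masked out while tokens generated by the LLM are marked.

```python
from typing import List


def build_action_mask(tokens: List[int], prompt_length: int) -> List[int]:
    """Return 0 for prompt tokens and 1 for tokens generated by the LLM."""
    if not 0 <= prompt_length <= len(tokens):
        raise ValueError("prompt_length must be between 0 and len(tokens)")
    return [0] * prompt_length + [1] * (len(tokens) - prompt_length)


# Hypothetical token IDs: the first three belong to the prompt.
mask = build_action_mask(tokens=[101, 102, 103, 7, 8], prompt_length=3)
print(mask)
```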
3135

3236
---
3337

3438
### Step 1: Prepare Task Dataset
3539

36-
Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
40+
The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file.
41+
To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields.
3742

38-
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.
43+
- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
44+
- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
45+
- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can use `raw_task` directly to initialize your `Workflow` instance without relying on the following fields.
46+
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.taskset.format`.
47+
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.rollout_args`.
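
As a rough mental model of the fields listed above, the following sketch uses a simplified stand-in class. `TaskSketch` is hypothetical and is not the real `Task` class; in particular, it uses plain dictionaries where Trinity-RFT uses `FormatConfig` and `GenerationConfig`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class TaskSketch:
    """Simplified stand-in for the `Task` interface (not the real class)."""

    workflow: str                    # registered name of the workflow class
    raw_task: Dict[str, Any]         # one raw record from the dataset
    reward_fn: Optional[str] = None  # can be omitted if the workflow computes rewards itself
    format_args: Dict[str, str] = field(default_factory=dict)   # plain dict here, for illustration
    rollout_args: Dict[str, Any] = field(default_factory=dict)  # plain dict here, for illustration


task = TaskSketch(
    workflow="math_workflow",
    raw_task={"question": "1+1=", "answer": "2"},
    format_args={"prompt_key": "question", "response_key": "answer"},
    rollout_args={"temperature": 1.0},
)
```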
48+
49+
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:
3950

4051
```
4152
{"question": "1+1=", "answer": "2"}
4253
{"question": "2+2=", "answer": "4"}
4354
...
4455
```
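
Before pointing Trinity-RFT at a dataset file, it can help to check that every record has the expected fields. The helper below is a minimal sketch using only the standard library; `parse_task_lines` is a hypothetical name, and the required keys are assumed to be `question` and `answer` as in the example above.

```python
import json
from typing import Dict, Iterable, List


def parse_task_lines(lines: Iterable[str]) -> List[Dict[str, str]]:
    """Parse JSONL lines and check that each record has the expected keys."""
    records = []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        missing = {"question", "answer"} - set(record)
        if missing:
            raise ValueError(f"line {line_no} is missing fields: {sorted(missing)}")
        records.append(record)
    return records


sample = ['{"question": "1+1=", "answer": "2"}', '{"question": "2+2=", "answer": "4"}']
records = parse_task_lines(sample)
```

To check a real file, pass an open file handle instead of the `sample` list, since iterating over a file yields its lines.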
4556

57+
Example configuration snippet:
58+
59+
```yaml
60+
# some config
61+
buffer:
62+
  explorer_input:
63+
    taskset:
64+
      default_workflow_type: "math_workflow"
65+
      path: "/PATH/TO/FILE/DIR"
66+
      format:
67+
        prompt_key: "question"
68+
        response_key: "answer"
69+
      rollout_args:
70+
        temperature: 1.0
71+
# some other configs
72+
```
73+
74+
In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task`, and uses the `rollout_args` to generate the response.
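
The point of configurable keys is that the same workflow can read datasets with different column names. The sketch below illustrates this with a hypothetical helper, `extract_prompt_and_answer`; it is not how `MathWorkflow` is implemented, and the `problem`/`solution` column names are invented for the example.

```python
from typing import Any, Dict, Tuple


def extract_prompt_and_answer(
    raw_task: Dict[str, Any], prompt_key: str, response_key: str
) -> Tuple[Any, Any]:
    """Look up the prompt and the reference answer using the configured keys."""
    return raw_task[prompt_key], raw_task[response_key]


# Dataset A uses the column names from the example above.
q1, a1 = extract_prompt_and_answer(
    {"question": "1+1=", "answer": "2"}, prompt_key="question", response_key="answer"
)

# Dataset B (hypothetical) uses different column names; only the keys change.
q2, a2 = extract_prompt_and_answer(
    {"problem": "2+2=", "solution": "4"}, prompt_key="problem", response_key="solution"
)
```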
75+
76+
4677
---
4778

48-
### Step 2: Write Workflow
79+
### Step 2: Implement a New Workflow
4980

50-
The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:
81+
The `Workflow` base class interface is as follows:
5182

5283
```python
53-
# import some packages
54-
5584
class Workflow(ABC):
5685

5786
    def __init__(
@@ -68,39 +97,48 @@ class Workflow(ABC):
6897
        """Run the workflow and return a list of Experiences."""
6998
```
7099

71-
Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.
72100

73-
```python
74-
# import some packages
75-
from trinity.common.workflows.workflow import WORKFLOWS
101+
#### Initializing Your Workflow
76102

77-
@WORKFLOWS.register_module("my_workflow")
78-
class MyWorkflow(Workflow):
79-
pass
80-
```
103+
During initialization, `Workflow` receives the following parameters:
104+
105+
- `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
106+
- `task` ({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
107+
- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
81108

82-
#### Initialization Parameters
83-
When initializing, `Workflow` receives the following parameters:
84-
- `model`: The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
85-
- `task`: An instance of `Task`, which is generated by one line of data from the `Task` dataset. The `raw_task` field contains the `Dict` format source data, which can be used to construct the `Workflow` instance.
86-
The `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k` and `top_p`.
87-
- `auxiliary_models`: A list of auxiliary models, which will not be trained. All of them provide OpenAI compatible API.
88109

89110
```{tip}
90-
The `model` also provided an OpenAI compatible API, you can switch to it by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and use `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
111+
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
91112
```
92113

93-
#### Example Code
94-
Below is a simple example demonstrating how to implement a math problem `Workflow`:
114+
Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
95115

96116
```python
97-
@WORKFLOWS.register_module("example_workflow")
98117
class ExampleWorkflow(Workflow):
99118

100-
    def __init__(self, model: ModelWrapper, task: Task, **kwargs):
101-
        super().__init__(model, **kwargs)
119+
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
120+
        super().__init__(model, task, auxiliary_models)
102121
        self.question = task.raw_task.get("question")
103122
        self.answer = task.raw_task.get("answer")
123+
        self.rollout_args = task.rollout_args
124+
        # Optional: If you want to use OpenAI API in your workflow
125+
        # self.openai_client = self.model.get_openai_client()
126+
```
127+
128+
#### Implementing the `run` method
129+
130+
The `run` method is the core of your workflow. It returns a list of `Experience`.
131+
Below is a simple implementation for a math workflow.
132+
133+
We first call the model to generate multiple responses using the provided question and rollout arguments.
134+
Then we calculate the reward for each response using the `calculate_reward` function.
135+
Finally, we construct a list of `Experience` with the responses and rewards and return it.
136+
137+
138+
```python
139+
class ExampleWorkflow(Workflow):
140+
141+
    # the __init__ function
104142

105143
    def calculate_reward(self, response: str, truth: str) -> float:
106144
        if response == truth:
@@ -109,28 +147,50 @@ class ExampleWorkflow(Workflow):
109147
            return 0.0
110148

111149
    def run(self) -> List[Experience]:
112-
        response = self.model.chat(
150+
        # call the model to generate multiple responses
151+
        responses = self.model.chat(
113152
            [
114153
                {
115154
                    "role": "user",
116155
                    "content": f"Question:\n{self.question}",
117156
                }
118157
            ],
119-
            n=self.task.rollout_args.n,
120-
            temperature=self.task.rollout_args.temperature,
158+
            n=self.rollout_args.n,
159+
            temperature=self.rollout_args.temperature,
121160
        )
122-
        reward: float = self.calculate_reward(response.response_text, self.answer)
123-
        return [
124-
            Experience(
125-
                tokens=response.tokens,
126-
                prompt_length=response.prompt_length,
127-
                reward=reward,
128-
                logprobs=response.logprobs,
161+
        experiences = []
162+
        for response in responses:
163+
            # calculate reward
164+
            reward: float = self.calculate_reward(response.response_text, self.answer)
165+
            # construct Experience
166+
            experiences.append(
167+
                Experience(
168+
                    tokens=response.tokens,
169+
                    prompt_length=response.prompt_length,
170+
                    reward=reward,
171+
                    logprobs=response.logprobs,
172+
                )
129173
            )
130-
        ]
174+
        return experiences
131175
```
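
The scoring loop in `run` can be exercised without an LLM by substituting a stub model. The sketch below is self-contained and does not import Trinity-RFT: `FakeResponse`, `FakeModel`, and `score_responses` are hypothetical stand-ins, and the assumption that `chat` returns a list of objects with a `response_text` attribute is taken from the example above.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeResponse:
    """Stand-in for one model response (hypothetical, for offline testing)."""

    response_text: str
    tokens: List[int]
    prompt_length: int
    logprobs: List[float]


class FakeModel:
    """Returns canned responses instead of calling an LLM."""

    def __init__(self, texts: List[str]):
        self.texts = texts

    def chat(self, messages: List[Dict[str, str]], n: int = 1, **kwargs) -> List[FakeResponse]:
        # Dummy tokens: two prompt tokens followed by one generated token.
        return [
            FakeResponse(text, tokens=[1, 2, 3], prompt_length=2, logprobs=[-0.1])
            for text in self.texts[:n]
        ]


def score_responses(model: FakeModel, question: str, answer: str, n: int) -> List[float]:
    """Mirror the loop in `run`: one exact-match reward per response."""
    responses = model.chat([{"role": "user", "content": f"Question:\n{question}"}], n=n)
    return [1.0 if r.response_text == answer else 0.0 for r in responses]


rewards = score_responses(FakeModel(["2", "3"]), question="1+1=", answer="2", n=2)
print(rewards)
```

Because the reward is an exact string match, a response such as `" 2"` or `"The answer is 2"` would score 0.0 under this logic.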
132176

133-
For some heavy workflows, the initialization process may be time-consuming.
177+
#### Registering Your Workflow
178+
179+
Register your workflow using the `WORKFLOWS.register_module` decorator.
180+
Ensure the name does not conflict with existing workflows.
181+
182+
```python
183+
# import some packages
184+
from trinity.common.workflows.workflow import WORKFLOWS
185+
186+
@WORKFLOWS.register_module("example_workflow")
187+
class ExampleWorkflow(Workflow):
188+
    pass
189+
```
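
To illustrate why names must be unique, here is a toy name-to-class registry with a decorator interface. `RegistrySketch` and `TOY_WORKFLOWS` are hypothetical and are not Trinity-RFT's implementation; in particular, raising an error on a duplicate name is an assumption about how such a registry might behave, so check the actual `WORKFLOWS` code for its real behavior.

```python
from typing import Callable, Dict, Type


class RegistrySketch:
    """Toy name-to-class registry (illustrative; not Trinity-RFT's implementation)."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Assumed behavior: refuse to overwrite an existing name.
            if name in self._modules:
                raise KeyError(f"'{name}' is already registered")
            self._modules[name] = cls
            return cls

        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


TOY_WORKFLOWS = RegistrySketch()


@TOY_WORKFLOWS.register_module("example_workflow")
class ToyWorkflow:
    pass
```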
190+
191+
#### Avoid Re-initialization
192+
193+
For heavy workflows, re-initializing every time can incur extra computational costs.
134194
In this case, you can implement the `resettable` and `reset` methods to avoid re-initialization.
135195

136196
```python
@@ -147,23 +207,81 @@ class ExampleWorkflow(Workflow):
147207
        self.answer = task.raw_task.get("answer")
148208
```
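
The sketch below shows how a caller might take advantage of `resettable` and `reset` to reuse a single instance across tasks. It is an assumption about the usage pattern, not Trinity-RFT's actual runner code: `ToyResettableWorkflow` and `run_all` are hypothetical, and for brevity `reset` takes the raw dictionary rather than a `Task`.

```python
from typing import Any, Dict, List, Optional


class ToyResettableWorkflow:
    """Hypothetical workflow that counts how often __init__ runs."""

    init_calls = 0

    def __init__(self, raw_task: Dict[str, Any]):
        type(self).init_calls += 1  # stands in for expensive setup
        self.reset(raw_task)

    def resettable(self) -> bool:
        return True

    def reset(self, raw_task: Dict[str, Any]) -> None:
        # Only the per-task state is replaced; expensive setup is kept.
        self.question = raw_task.get("question")
        self.answer = raw_task.get("answer")


def run_all(raw_tasks: List[Dict[str, Any]]) -> List[Any]:
    """Reuse one instance when the workflow reports that it is resettable."""
    workflow: Optional[ToyResettableWorkflow] = None
    seen = []
    for raw_task in raw_tasks:
        if workflow is not None and workflow.resettable():
            workflow.reset(raw_task)
        else:
            workflow = ToyResettableWorkflow(raw_task)
        seen.append(workflow.question)
    return seen


questions = run_all([{"question": "1+1=", "answer": "2"}, {"question": "2+2=", "answer": "4"}])
```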
149209

210+
211+
#### Full Code Example
212+
213+
```python
214+
@WORKFLOWS.register_module("example_workflow")
215+
class ExampleWorkflow(Workflow):
216+
217+
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
218+
        super().__init__(model, task, auxiliary_models)
219+
        self.question = task.raw_task.get("question")
220+
        self.answer = task.raw_task.get("answer")
221+
        self.rollout_args = task.rollout_args
222+
223+
    def calculate_reward(self, response: str, truth: str) -> float:
224+
        if response == truth:
225+
            return 1.0
226+
        else:
227+
            return 0.0
228+
229+
    def run(self) -> List[Experience]:
230+
        # call the model to generate multiple responses
231+
        responses = self.model.chat(
232+
            [
233+
                {
234+
                    "role": "user",
235+
                    "content": f"Question:\n{self.question}",
236+
                }
237+
            ],
238+
            n=self.rollout_args.n,
239+
            temperature=self.rollout_args.temperature,
240+
        )
241+
        experiences = []
242+
        for response in responses:
243+
            # calculate reward
244+
            reward: float = self.calculate_reward(response.response_text, self.answer)
245+
            # construct Experience
246+
            experiences.append(
247+
                Experience(
248+
                    tokens=response.tokens,
249+
                    prompt_length=response.prompt_length,
250+
                    reward=reward,
251+
                    logprobs=response.logprobs,
252+
                )
253+
            )
254+
        return experiences
255+
256+
    def resettable(self):
257+
        return True
258+
259+
    def reset(self, task: Task):
260+
        self.question = task.raw_task.get("question")
261+
        self.answer = task.raw_task.get("answer")
262+
```
263+
264+
150265
---
151266

152-
### Step 3: Modify Configuration File
267+
### Step 3: Use Your Workflow
153268

154-
After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `buffer.explorer_input` domain to the newly registered `Workflow` name.
269+
After implementing and registering your workflow, update the configuration file: set `default_workflow_type` under `buffer.explorer_input.taskset` to the registered name of your new `Workflow`.
155270

156271
```yaml
157272
buffer:
158273
  # Other fields
159274
  explorer_input:
160275
    taskset:
161-
      name: example_task
162-
      storage_type: file
163276
      path: /path/to/taskset
164-
# Other fields
165-
default_workflow_type: example_workflow
166-
# Other fields
277+
      default_workflow_type: example_workflow
278+
# Other fields
279+
```
280+
281+
Now you can run your workflow in Trinity-RFT using the command:
282+
283+
```
284+
trinity run --config <your_yaml_file>
167285
```
168286

169287
---
