You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/sphinx_doc/source/tutorial/develop_algorithm.md
+5-8Lines changed: 5 additions & 8 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -47,9 +47,8 @@ For convenience, Trinity-RFT provides an abstract class {class}`trinity.algorith
47
47
Here's an implementation example for the OPMD algorithm's advantage function:
48
48
49
49
```python
50
-
from trinity.algorithm.advantage_fn importADVANTAGE_FN, GroupAdvantage
50
+
from trinity.algorithm.advantage_fn import GroupAdvantage
51
51
52
-
@ADVANTAGE_FN.register_module("opmd")
53
52
classOPMDGroupAdvantage(GroupAdvantage):
54
53
"""OPMD Group Advantage computation"""
55
54
@@ -90,7 +89,7 @@ class OPMDGroupAdvantage(GroupAdvantage):
90
89
return {"opmd_baseline": "mean", "tau": 1.0}
91
90
```
92
91
93
-
After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name.
92
+
After implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`. Once registered, the module can be configured in the configuration file using the registered name.
94
93
95
94
96
95
#### Step 1.2: Implement `PolicyLossFn`
@@ -100,13 +99,12 @@ Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interfa
100
99
-`__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
101
100
-`default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.
102
101
103
-
Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`.
102
+
Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/policy_loss_fn/__init__.py`.
104
103
105
104
Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method:
106
105
107
106
108
107
```python
109
-
@POLICY_LOSS_FN.register_module("opmd")
110
108
classOPMDPolicyLossFn(PolicyLossFn):
111
109
def__init__(self, tau: float=1.0) -> None:
112
110
self.tau = tau
@@ -134,7 +132,7 @@ class OPMDPolicyLossFn(PolicyLossFn):
134
132
135
133
The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.
136
134
137
-
To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
135
+
To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in `trinity/algorithm/__init__.py`, enabling one-click configuration.
138
136
139
137
The `AlgorithmType` class includes the following attributes and methods:
140
138
@@ -145,14 +143,13 @@ The `AlgorithmType` class includes the following attributes and methods:
145
143
-`schema`: The format of experience data corresponding to the algorithm
146
144
-`default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`
147
145
148
-
Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
146
+
Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`.
149
147
150
148
Below is the implementation for the OPMD algorithm.
151
149
Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`.
152
150
The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.
Copy file name to clipboardExpand all lines: docs/sphinx_doc/source/tutorial/develop_overview.md
+11-1Lines changed: 11 additions & 1 deletion
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -17,13 +17,23 @@ The table below lists the main functions of each extension interface, its target
17
17
Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code.
18
18
You can place your module code in the `trinity/plugins` directory. Trinity-RFT will automatically load all Python files in that directory at runtime and register the custom modules within them.
19
19
Trinity-RFT also supports specifying other directories at runtime by setting the `--plugin-dir` option, for example: `trinity run --config <config_file> --plugin-dir <your_plugin_dir>`.
20
+
Alternatively, you can use the relative path to the custom module in the YAML configuration file, for example: `default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'`.
20
21
```
21
22
22
23
For modules you plan to contribute to Trinity-RFT, please follow these steps:
23
24
24
25
1. Implement your code in the appropriate directory, such as `trinity/common/workflows` for `Workflow`, `trinity/algorithm` for `Algorithm`, and `trinity/buffer/operators` for `Operator`.
25
26
26
-
2. Register your module in the corresponding `__init__.py` file of the directory.
27
+
2. Register your module in the corresponding mapping dictionary in the `__init__.py` file of the directory.
28
+
For example, if you want to register a new workflow class `ExampleWorkflow`, you need to modify the `default_mapping` dictionary of `WORKFLOWS` in the `trinity/common/workflows/__init__.py` file:
Copy file name to clipboardExpand all lines: docs/sphinx_doc/source/tutorial/develop_selector.md
+11-5Lines changed: 11 additions & 5 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -60,7 +60,6 @@ To create a new selector, inherit from `BaseSelector` and implement the followin
60
60
This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks.
> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.
127
+
> 🔁 After defining your class, remember to register it in the `default_mapping` of `trinity/buffer/selector/__init__.py` so it can be referenced by name in configs.
Copy file name to clipboardExpand all lines: docs/sphinx_doc/source/tutorial/develop_workflow.md
+7-24Lines changed: 7 additions & 24 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -176,28 +176,16 @@ class ExampleWorkflow(Workflow):
176
176
177
177
#### Registering Your Workflow
178
178
179
-
Register your workflow using the `WORKFLOWS.register_module` decorator.
179
+
Register your workflow using the `default_mapping` in `trinity/common/workflows/__init__.py`.
180
180
Ensure the name does not conflict with existing workflows.
181
181
182
182
```python
183
-
# import some packages
184
-
from trinity.common.workflows.workflow importWORKFLOWS
185
-
186
-
@WORKFLOWS.register_module("example_workflow")
187
-
classExampleWorkflow(Workflow):
188
-
pass
189
-
```
190
-
191
-
For workflows that are prepared to be contributed to Trinity-RFT project, you need to place the above code in `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`. And add the following line to `trinity/common/workflows/__init__.py`:
192
-
193
-
```python
194
-
# existing import lines
195
-
from trinity.common.workflows.example_workflow import ExampleWorkflow
@@ -212,7 +200,6 @@ The `can_reset` is a class property that indicates whether the workflow supports
212
200
The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.
213
201
214
202
```python
215
-
@WORKFLOWS.register_module("example_workflow")
216
203
classExampleWorkflow(Workflow):
217
204
can_reset: bool=True
218
205
@@ -234,7 +221,6 @@ The `can_repeat` is a class property that indicates whether the workflow support
234
221
The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).
235
222
236
223
```python
237
-
@WORKFLOWS.register_module("example_workflow")
238
224
classExampleWorkflow(Workflow):
239
225
can_repeat: bool=True
240
226
# some code
@@ -275,7 +261,6 @@ class ExampleWorkflow(Workflow):
275
261
#### Full Code Example
276
262
277
263
```python
278
-
@WORKFLOWS.register_module("example_workflow")
279
264
classExampleWorkflow(Workflow):
280
265
can_reset: bool=True
281
266
can_repeat: bool=True
@@ -359,7 +344,6 @@ trinity run --config <your_yaml_file>
359
344
The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain changed.
Copy file name to clipboardExpand all lines: docs/sphinx_doc/source/tutorial/example_mix_algo.md
-2Lines changed: 0 additions & 2 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -47,7 +47,6 @@ The path to expert data is passed to `buffer.trainer_input.auxiliary_buffers.sft
47
47
In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.
48
48
49
49
```python
50
-
@ALGORITHM_TYPE.register_module("mix")
51
50
classMIXAlgorithm(AlgorithmType):
52
51
"""MIX algorithm."""
53
52
@@ -159,7 +158,6 @@ Here we use the `custom_fields` argument of `Experiences.gather_experiences` to
159
158
We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively.
0 commit comments