Commit cb89407

Merge main into dev/on-policy-distillation
2 parents 2dd56e0 + e412fbe commit cb89407

159 files changed, +1674 -889 lines changed

benchmark/bench.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 import yaml
 
-from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm import ALGORITHM_TYPE
 from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
 from trinity.utils.dlc_utils import get_dlc_env_vars

benchmark/plugins/guru_math/reward.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from typing import Optional
 
+from trinity.common.rewards import REWARD_FUNCTIONS
 from trinity.common.rewards.math_reward import MathBoxedRewardFn
-from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
 
 
 @REWARD_FUNCTIONS.register_module("math_boxed_reward_naive_dapo")

benchmark/reports/gsm8k.md

Lines changed: 1 addition & 2 deletions
@@ -174,12 +174,11 @@ from typing import List, Optional
 import openai
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+from trinity.common.workflows.workflow import Task, Workflow
 
 from verl.utils.reward_score import gsm8k
 
 
-@WORKFLOWS.register_module("verl_gsm8k_workflow")
 class VerlGSM8kWorkflow(Workflow):
     can_reset: bool = True
     can_repeat: bool = True

docs/sphinx_doc/source/tutorial/develop_algorithm.md

Lines changed: 5 additions & 8 deletions
@@ -47,9 +47,8 @@ For convenience, Trinity-RFT provides an abstract class {class}`trinity.algorith
 Here's an implementation example for the OPMD algorithm's advantage function:
 
 ```python
-from trinity.algorithm.advantage_fn import ADVANTAGE_FN, GroupAdvantage
+from trinity.algorithm.advantage_fn import GroupAdvantage
 
-@ADVANTAGE_FN.register_module("opmd")
 class OPMDGroupAdvantage(GroupAdvantage):
     """OPMD Group Advantage computation"""
 
@@ -90,7 +89,7 @@ class OPMDGroupAdvantage(GroupAdvantage):
         return {"opmd_baseline": "mean", "tau": 1.0}
 ```
 
-After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name.
+After implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`. Once registered, the module can be configured in the configuration file using the registered name.
 
 
 #### Step 1.2: Implement `PolicyLossFn`
@@ -100,13 +99,12 @@ Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interfa
 - `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
 - `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.
 
-Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`.
+Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/policy_loss_fn/__init__.py`.
 
 Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method:
 
 
 ```python
-@POLICY_LOSS_FN.register_module("opmd")
 class OPMDPolicyLossFn(PolicyLossFn):
     def __init__(self, tau: float = 1.0) -> None:
         self.tau = tau
@@ -134,7 +132,7 @@ class OPMDPolicyLossFn(PolicyLossFn):
 
 The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.
 
-To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in `trinity/algorithm/__init__.py`, enabling one-click configuration.
 
 The `AlgorithmType` class includes the following attributes and methods:
 
@@ -145,14 +143,13 @@ The `AlgorithmType` class includes the following attributes and methods:
 - `schema`: The format of experience data corresponding to the algorithm
 - `default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`
 
-Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
+Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`.
 
 Below is the implementation for the OPMD algorithm.
 Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`.
 The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.
 
 ```python
-@ALGORITHM_TYPE.register_module("opmd")
 class OPMDAlgorithm(AlgorithmType):
     """OPMD algorithm."""

docs/sphinx_doc/source/tutorial/develop_operator.md

Lines changed: 1 addition & 2 deletions
@@ -40,11 +40,10 @@ class ExperienceOperator(ABC):
 Here is an implementation of a simple operator that filters out experiences with rewards below a certain threshold:
 
 ```python
-from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.buffer.operators import ExperienceOperator
 from trinity.common.experience import Experience
 
 
-@EXPERIENCE_OPERATORS.register_module("reward_filter")
 class RewardFilter(ExperienceOperator):
 
     def __init__(self, threshold: float = 0.0) -> None:

docs/sphinx_doc/source/tutorial/develop_overview.md

Lines changed: 11 additions & 1 deletion
@@ -17,13 +17,23 @@ The table below lists the main functions of each extension interface, its target
 Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code.
 You can place your module code in the `trinity/plugins` directory. Trinity-RFT will automatically load all Python files in that directory at runtime and register the custom modules within them.
 Trinity-RFT also supports specifying other directories at runtime by setting the `--plugin-dir` option, for example: `trinity run --config <config_file> --plugin-dir <your_plugin_dir>`.
+Alternatively, you can use the relative path to the custom module in the YAML configuration file, for example: `default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'`.
 ```
 
 For modules you plan to contribute to Trinity-RFT, please follow these steps:
 
 1. Implement your code in the appropriate directory, such as `trinity/common/workflows` for `Workflow`, `trinity/algorithm` for `Algorithm`, and `trinity/buffer/operators` for `Operator`.
 
-2. Register your module in the corresponding `__init__.py` file of the directory.
+2. Register your module in the corresponding mapping dictionary in the `__init__.py` file of the directory.
+For example, if you want to register a new workflow class `ExampleWorkflow`, you need to modify the `default_mapping` dictionary of `WORKFLOWS` in the `trinity/common/workflows/__init__.py` file:
+```python
+WORKFLOWS: Registry = Registry(
+    "workflows",
+    default_mapping={
+        "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+    },
+)
+```
 
 3. Add tests for your module in the `tests` directory, following the naming conventions and structure of existing tests.
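
Note on the mechanism: this change replaces decorator-based registration with a `default_mapping` from registered names to dotted class paths, but the `Registry` implementation itself is not among the files shown here. The following is a minimal sketch of how such a lazily resolved mapping could work, assuming resolution through `importlib`. `LazyRegistry`, its `get` method, and the dotted-path fallback are hypothetical names for illustration, not Trinity-RFT's API, and standard-library classes stand in for workflow classes so the sketch runs without the package installed.

```python
import importlib
from typing import Any, Dict, Optional


class LazyRegistry:
    """Hypothetical stand-in for a name -> class registry (not Trinity-RFT's Registry)."""

    def __init__(self, name: str, default_mapping: Optional[Dict[str, str]] = None) -> None:
        self.name = name
        # Registered name -> dotted "package.module.ClassName" path.
        self._mapping: Dict[str, str] = dict(default_mapping or {})
        # Cache of objects that have already been imported.
        self._resolved: Dict[str, Any] = {}

    @staticmethod
    def _import_path(path: str) -> Any:
        # Split "package.module.ClassName" into module path and attribute name.
        module_path, _, attr = path.rpartition(".")
        if not module_path:
            raise ValueError(f"{path!r} is not a dotted path")
        module = importlib.import_module(module_path)
        return getattr(module, attr)

    def get(self, key: str) -> Any:
        if key not in self._resolved:
            # Assumption: an unregistered key is treated as a dotted path itself.
            path = self._mapping.get(key, key)
            self._resolved[key] = self._import_path(path)
        return self._resolved[key]


DEMO = LazyRegistry(
    "demo",
    default_mapping={"ordered_dict": "collections.OrderedDict"},
)
ordered_dict_cls = DEMO.get("ordered_dict")      # resolved through the mapping
fraction_cls = DEMO.get("fractions.Fraction")    # resolved as a dotted path
```

Storing paths as strings defers each import until a name is first requested, which is presumably why the classes no longer need a decorator at definition time.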

docs/sphinx_doc/source/tutorial/develop_selector.md

Lines changed: 11 additions & 5 deletions
@@ -60,7 +60,6 @@ To create a new selector, inherit from `BaseSelector` and implement the followin
 This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks.
 
 ```python
-@SELECTORS.register_module("difficulty_based")
 class DifficultyBasedSelector(BaseSelector):
     def __init__(self, data_source, config: TaskSelectorConfig) -> None:
         super().__init__(data_source, config)
@@ -125,7 +124,15 @@ class DifficultyBasedSelector(BaseSelector):
         self.current_index = state_dict.get("current_index", 0)
 ```
 
-> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.
+> 🔁 After defining your class, remember to register it in the `default_mapping` of `trinity/buffer/selector/__init__.py` so it can be referenced by name in configs.
+```python
+SELECTORS = Registry(
+    "selectors",
+    default_mapping={
+        "difficulty_based": "trinity.buffer.selector.selector.DifficultyBasedSelector",
+    },
+)
+```
 
 
 
@@ -152,7 +159,6 @@ The operator must output a metric under the key `trinity.common.constants.SELECT
 #### Example: Pass Rate Calculator
 
 ```python
-@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
 class PassRateCalculator(ExperienceOperator):
     def __init__(self, **kwargs):
         pass
@@ -194,7 +200,7 @@ After implementing your selector and operator, register them in the config file.
 data_processor:
   experience_pipeline:
     operators:
-      - name: pass_rate_calculator # Must match @register_module name
+      - name: pass_rate_calculator
 ```
 
 #### Configure the Taskset with Your Selector
@@ -207,7 +213,7 @@ buffer:
       storage_type: file
       path: ./path/to/tasks
       task_selector:
-        selector_type: difficulty_based # Matches @register_module name
+        selector_type: difficulty_based
         feature_keys: ["correct", "uncertainty"]
         kwargs:
           m: 16

docs/sphinx_doc/source/tutorial/develop_workflow.md

Lines changed: 7 additions & 24 deletions
@@ -176,28 +176,16 @@ class ExampleWorkflow(Workflow):
 
 #### Registering Your Workflow
 
-Register your workflow using the `WORKFLOWS.register_module` decorator.
+Register your workflow using the `default_mapping` in `trinity/common/workflows/__init__.py`.
 Ensure the name does not conflict with existing workflows.
 
 ```python
-# import some packages
-from trinity.common.workflows.workflow import WORKFLOWS
-
-@WORKFLOWS.register_module("example_workflow")
-class ExampleWorkflow(Workflow):
-    pass
-```
-
-For workflows that are prepared to be contributed to Trinity-RFT project, you need to place the above code in `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`. And add the following line to `trinity/common/workflows/__init__.py`:
-
-```python
-# existing import lines
-from trinity.common.workflows.example_workflow import ExampleWorkflow
-
-__all__ = [
-    # existing __all__ lines
-    "ExampleWorkflow",
-]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+    },
+)
 ```
 
 #### Performance Optimization
@@ -212,7 +200,6 @@ The `can_reset` is a class property that indicates whether the workflow supports
 The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_reset: bool = True
 
@@ -234,7 +221,6 @@ The `can_repeat` is a class property that indicates whether the workflow support
 The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_repeat: bool = True
     # some code
@@ -275,7 +261,6 @@ class ExampleWorkflow(Workflow):
 #### Full Code Example
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_reset: bool = True
     can_repeat: bool = True
@@ -359,7 +344,6 @@ trinity run --config <your_yaml_file>
 The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain unchanged.
 
 ```python
-@WORKFLOWS.register_module("example_workflow_async")
 class ExampleWorkflowAsync(Workflow):
 
     is_async: bool = True
@@ -386,7 +370,6 @@ explorer:
 ```
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
 
     def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 0 additions & 2 deletions
@@ -47,7 +47,6 @@ The path to expert data is passed to `buffer.trainer_input.auxiliary_buffers.sft
 In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.
 
 ```python
-@ALGORITHM_TYPE.register_module("mix")
 class MIXAlgorithm(AlgorithmType):
     """MIX algorithm."""
 
@@ -159,7 +158,6 @@ Here we use the `custom_fields` argument of `Experiences.gather_experiences` to
 We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively.
 
 ```python
-@POLICY_LOSS_FN.register_module("mix")
 class MIXPolicyLossFn(PolicyLossFn):
     def __init__(
         self,

docs/sphinx_doc/source/tutorial/example_multi_turn.md

Lines changed: 7 additions & 21 deletions
@@ -126,28 +126,14 @@ class AlfworldWorkflow(MultiTurnWorkflow):
         return self.generate_env_inference_samples(env, rollout_n)
 ```
 
-Also, remember to register your workflow:
+Also, remember to register your workflow in the `default_mapping` of `trinity/common/workflows/__init__.py`.
 ```python
-@WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(MultiTurnWorkflow):
-    """A workflow for alfworld task."""
-    ...
-```
-
-and include it in the init file `trinity/common/workflows/__init__.py`
-
-```diff
-# -*- coding: utf-8 -*-
-"""Workflow module"""
-from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
-+from trinity.common.workflows.envs.alfworld.alfworld_workflow import AlfworldWorkflow
-
-__all__ = [
-    "WORKFLOWS",
-    "SimpleWorkflow",
-    "MathWorkflow",
-+    "AlfworldWorkflow",
-]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "alfworld_workflow": "trinity.common.workflows.envs.alfworld.alfworld_workflow.AlfworldWorkflow",
+    },
+)
 ```
 
 Then you are all set! It should be pretty simple😄, and the training processes in both environments converge.
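
One consequence of string-based registration is that a mistyped path in `default_mapping` would only surface when that entry is first resolved. A test along the following lines could catch it earlier. `find_unresolvable` is a hypothetical helper, not part of Trinity-RFT, and standard-library paths stand in for real workflow paths so the sketch runs without the package installed.

```python
import importlib
from typing import Dict, List, Tuple


def find_unresolvable(mapping: Dict[str, str]) -> List[Tuple[str, str]]:
    """Return (name, reason) pairs for entries whose dotted path cannot be imported."""
    problems: List[Tuple[str, str]] = []
    for name, path in mapping.items():
        # Split "package.module.ClassName" into module path and attribute name.
        module_path, _, attr = path.rpartition(".")
        try:
            module = importlib.import_module(module_path)
        except (ImportError, ValueError) as exc:
            # ValueError covers a path with no module part (empty module name).
            problems.append((name, f"cannot import module: {exc}"))
            continue
        if not hasattr(module, attr):
            problems.append((name, f"{module_path} has no attribute {attr}"))
    return problems


# Demo mapping: one valid entry and two deliberately broken ones.
demo_mapping = {
    "ok": "collections.OrderedDict",
    "bad_module": "no_such_package_xyz.module.Thing",
    "bad_attr": "collections.NoSuchClass",
}
problems = find_unresolvable(demo_mapping)
```

In a real test suite the same function would be pointed at the actual `default_mapping` dictionaries, at the cost of importing every registered module.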
