Skip to content

Commit a592af7

Browse files
authored
Add guideline for adding new algorithm (#85)
1 parent 69ddbd0 commit a592af7

File tree

8 files changed

+252
-57
lines changed

8 files changed

+252
-57
lines changed

docs/sphinx_doc/source/index.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,23 @@ Welcome to Trinity-RFT's documentation!
1414
:maxdepth: 1
1515
:glob:
1616
:hidden:
17-
:caption: Tutorial
17+
:caption: Examples
1818

1919
tutorial/example_reasoning_basic.md
2020
tutorial/example_reasoning_advanced.md
2121
tutorial/example_async_mode.md
2222
tutorial/example_multi_turn.md
2323
tutorial/example_dpo.md
2424
tutorial/example_data_functionalities.md
25-
tutorial/trinity_configs.md
25+
26+
.. toctree::
27+
:maxdepth: 2
28+
:glob:
29+
:hidden:
30+
:caption: Guidelines
31+
2632
tutorial/trinity_programming_guide.md
33+
tutorial/trinity_configs.md
2734
tutorial/example_mix_algo.md
2835

2936
.. toctree::
@@ -34,6 +41,7 @@ Welcome to Trinity-RFT's documentation!
3441
build_api/trinity.buffer
3542
build_api/trinity.explorer
3643
build_api/trinity.trainer
44+
build_api/trinity.algorithm
3745
build_api/trinity.manager
3846
build_api/trinity.common
3947
build_api/trinity.utils

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
# Integrate A New Algorithm
1+
# Algorithm Development
22

3+
```{note}
4+
This guide is an advanced version of the {ref}`Algorithms <Algorithms>` section in the Developer Guide.
5+
```
36

47
This guide introduces how to integrate a new algorithm into Trinity-RFT.
58
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:
@@ -19,13 +22,10 @@ The first term corresponds to the standard GRPO objective, which aims to maximiz
1922

2023
## Step 0: Prepare the Expert Data
2124

22-
We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected data are treated as experiences from an expert. We store them in a JSON file `expert_data.json` with the following format:
25+
We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected data are treated as experiences from an expert. We store them in a `jsonl` file `expert_data.jsonl` with the following format:
2326

2427
```json
25-
{
26-
"question": "What is the average of 4, 6, and 8?",
27-
"response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
28-
}
28+
{"question": "What is the average of 4, 6, and 8?","response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."}
2929
...
3030
```
3131

@@ -42,7 +42,6 @@ class MIXAlgorithm(AlgorithmType):
4242
use_critic: bool = False
4343
use_reference: bool = True
4444
use_advantage: bool = True
45-
use_rollout: bool = True
4645
can_balance_batch: bool = True
4746
schema: type = ExperienceModel
4847

docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a
66

77

88

9-
9+
(OPMD)=
1010
## OPMD: a native off-policy RL algorithm
1111

1212

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 217 additions & 22 deletions
Large diffs are not rendered by default.

trinity/algorithm/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
2+
from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
23
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
34
from trinity.algorithm.kl_fn import KL_FN, KLFn
45
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
56
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
67

78
__all__ = [
9+
"ALGORITHM_TYPE",
10+
"AlgorithmType",
811
"AdvantageFn",
912
"ADVANTAGE_FN",
1013
"PolicyLossFn",

trinity/algorithm/algorithm.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
2626
use_critic: bool
2727
use_reference: bool
2828
use_advantage: bool
29-
use_rollout: bool
3029
can_balance_batch: bool
3130
schema: type
3231

@@ -50,7 +49,6 @@ class SFTAlgorithm(AlgorithmType):
5049
use_critic: bool = False
5150
use_reference: bool = False
5251
use_advantage: bool = False
53-
use_rollout: bool = False
5452
can_balance_batch: bool = True
5553
schema: type = SFTDataModel
5654

@@ -71,7 +69,6 @@ class PPOAlgorithm(AlgorithmType):
7169
use_critic: bool = True
7270
use_reference: bool = True
7371
use_advantage: bool = True
74-
use_rollout: bool = True
7572
can_balance_batch: bool = True
7673
schema: type = ExperienceModel
7774

@@ -95,7 +92,6 @@ class GRPOAlgorithm(AlgorithmType):
9592
use_critic: bool = False
9693
use_reference: bool = True
9794
use_advantage: bool = True
98-
use_rollout: bool = True
9995
can_balance_batch: bool = True
10096
schema: type = ExperienceModel
10197

@@ -119,7 +115,6 @@ class OPMDAlgorithm(AlgorithmType):
119115
use_critic: bool = False
120116
use_reference: bool = True
121117
use_advantage: bool = True
122-
use_rollout: bool = True
123118
can_balance_batch: bool = True
124119
schema: type = ExperienceModel
125120

@@ -143,7 +138,6 @@ class DPOAlgorithm(AlgorithmType):
143138
use_critic: bool = False
144139
use_reference: bool = True
145140
use_advantage: bool = False
146-
use_rollout: bool = False
147141
can_balance_batch: bool = False
148142
schema: type = DPODataModel
149143

trinity/buffer/writer/sql_writer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import ray
44

5-
from trinity.algorithm.algorithm import ALGORITHM_TYPE
65
from trinity.buffer.buffer_writer import BufferWriter
76
from trinity.buffer.db_wrapper import DBWrapper
87
from trinity.common.config import BufferConfig, StorageConfig
@@ -15,9 +14,6 @@ class SQLWriter(BufferWriter):
1514
def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
1615
assert meta.storage_type == StorageType.SQL
1716
# we only support writing the RFT algorithm buffer for now
18-
# TODO: support other algorithms
19-
algorithm = ALGORITHM_TYPE.get(meta.algorithm_type)
20-
assert algorithm.use_rollout, "Only RFT buffer is supported for writing."
2117
self.wrap_in_ray = meta.wrap_in_ray
2218
self.db_wrapper = DBWrapper.get_wrapper(meta, config)
2319

trinity/utils/registry.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,21 @@ def register_module(self, module_name: str, module_cls: Type = None, force=False
8383
Default: False.
8484
8585
Example:
86-
```python
87-
WORKFLOWS = Registry("workflows")
88-
89-
# register a module using decorator
90-
@WORKFLOWS.register_module(name="workflow_name")
91-
class MyWorkflow(Workflow):
92-
pass
93-
94-
# or register a module directly
95-
WORKFLOWS.register_module(
96-
name="workflow_name",
97-
module_cls=MyWorkflow,
98-
force=True,
99-
)
100-
```
86+
87+
.. code-block:: python
88+
WORKFLOWS = Registry("workflows")
89+
90+
# register a module using decorator
91+
@WORKFLOWS.register_module(name="workflow_name")
92+
class MyWorkflow(Workflow):
93+
pass
94+
95+
# or register a module directly
96+
WORKFLOWS.register_module(
97+
name="workflow_name",
98+
module_cls=MyWorkflow,
99+
force=True,
100+
)
101101
102102
"""
103103
if not (module_name is None or isinstance(module_name, str)):

0 commit comments

Comments
 (0)