12 changes: 10 additions & 2 deletions docs/sphinx_doc/source/index.rst
@@ -14,16 +14,23 @@ Welcome to Trinity-RFT's documentation!
:maxdepth: 1
:glob:
:hidden:
:caption: Tutorial
:caption: Examples

tutorial/example_reasoning_basic.md
tutorial/example_reasoning_advanced.md
tutorial/example_async_mode.md
tutorial/example_multi_turn.md
tutorial/example_dpo.md
tutorial/example_data_functionalities.md
tutorial/trinity_configs.md

.. toctree::
:maxdepth: 2
:glob:
:hidden:
:caption: Guidelines

tutorial/trinity_programming_guide.md
tutorial/trinity_configs.md
tutorial/example_mix_algo.md

.. toctree::
@@ -34,6 +41,7 @@ Welcome to Trinity-RFT's documentation!
build_api/trinity.buffer
build_api/trinity.explorer
build_api/trinity.trainer
build_api/trinity.algorithm
build_api/trinity.manager
build_api/trinity.common
build_api/trinity.utils
13 changes: 6 additions & 7 deletions docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -1,5 +1,8 @@
# Integrate A New Algorithm
# Algorithm Development

```{note}
This guide is an advanced version of the {ref}`Algorithms <Algorithms>` section in the Developer Guide.
```

This guide introduces how to integrate a new algorithm into Trinity-RFT.
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:
@@ -19,13 +22,10 @@ The first term corresponds to the standard GRPO objective, which aims to maximiz

## Step 0: Prepare the Expert Data

We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected data are viewed as some experiences from an expert. We store them in a JSON file `expert_data.json` with the following format:
We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected data are viewed as some experiences from an expert. We store them in a `jsonl` file `expert_data.jsonl` with the following format:

```json
{
"question": "What is the average of 4, 6, and 8?",
"response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
}
{"question": "What is the average of 4, 6, and 8?","response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."}
...
```
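
A minimal sketch of reading such a file back, assuming one JSON object per line with the `question` and `response` keys shown above. The helper name `load_expert_data` is hypothetical and not part of Trinity-RFT; it uses only the Python standard library.

```python
import json
from pathlib import Path


def load_expert_data(path):
    """Hypothetical helper: read one JSON object per non-empty line."""
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines between records.
                continue
            record = json.loads(line)
            # Check for the two keys used in the example record above.
            missing = {"question", "response"} - set(record)
            if missing:
                raise ValueError(f"line {line_no}: missing keys {sorted(missing)}")
            records.append(record)
    return records
```

Parsing line by line lets a malformed record be reported with its line number rather than failing on the whole file at once.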

@@ -42,7 +42,6 @@ class MIXAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@@ -6,7 +6,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a




(OPMD)=
## OPMD: a native off-policy RL algorithm


239 changes: 217 additions & 22 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions trinity/algorithm/__init__.py
@@ -1,10 +1,13 @@
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
from trinity.algorithm.kl_fn import KL_FN, KLFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy

__all__ = [
"ALGORITHM_TYPE",
"AlgorithmType",
"AdvantageFn",
"ADVANTAGE_FN",
"PolicyLossFn",
6 changes: 0 additions & 6 deletions trinity/algorithm/algorithm.py
@@ -26,7 +26,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
use_critic: bool
use_reference: bool
use_advantage: bool
use_rollout: bool
can_balance_batch: bool
schema: type

@@ -50,7 +49,6 @@ class SFTAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = False
use_advantage: bool = False
use_rollout: bool = False
can_balance_batch: bool = True
schema: type = SFTDataModel

@@ -71,7 +69,6 @@ class PPOAlgorithm(AlgorithmType):
use_critic: bool = True
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@@ -95,7 +92,6 @@ class GRPOAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@@ -119,7 +115,6 @@ class OPMDAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@@ -143,7 +138,6 @@ class DPOAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = False
use_rollout: bool = False
can_balance_batch: bool = False
schema: type = DPODataModel

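The class-level flags that remain after this change can be read directly off an algorithm class. The following is a simplified sketch under stated assumptions: the classes below are stand-ins that copy only the flag values shown in the diff (the real `AlgorithmType` also uses a `ConstantMeta` metaclass and a `schema` attribute), and `required_components` is a hypothetical helper, not a Trinity-RFT function.

```python
from abc import ABC
from typing import List


class AlgorithmType(ABC):
    """Simplified stand-in: only the boolean flags kept by this change."""

    use_critic: bool
    use_reference: bool
    use_advantage: bool
    can_balance_batch: bool


class GRPOAlgorithm(AlgorithmType):
    # Flag values copied from the GRPOAlgorithm hunk above.
    use_critic = False
    use_reference = True
    use_advantage = True
    can_balance_batch = True


class DPOAlgorithm(AlgorithmType):
    # Flag values copied from the DPOAlgorithm hunk above.
    use_critic = False
    use_reference = True
    use_advantage = False
    can_balance_batch = False


def required_components(algorithm: type) -> List[str]:
    """Hypothetical helper: list the components an algorithm's flags ask for."""
    components = ["actor"]
    if algorithm.use_critic:
        components.append("critic")
    if algorithm.use_reference:
        components.append("reference")
    if algorithm.use_advantage:
        components.append("advantage_fn")
    return components


print(required_components(GRPOAlgorithm))
print(required_components(DPOAlgorithm))
```

Because `use_rollout` no longer exists on any algorithm class, callers that previously branched on it (such as the assertion removed from `SQLWriter` in this change) need a different condition or none at all.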
4 changes: 0 additions & 4 deletions trinity/buffer/writer/sql_writer.py
@@ -2,7 +2,6 @@

import ray

from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.db_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
@@ -15,9 +14,6 @@ class SQLWriter(BufferWriter):
def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
# we only support write RFT algorithm buffer for now
# TODO: support other algorithms
algorithm = ALGORITHM_TYPE.get(meta.algorithm_type)
assert algorithm.use_rollout, "Only RFT buffer is supported for writing."
self.wrap_in_ray = meta.wrap_in_ray
self.db_wrapper = DBWrapper.get_wrapper(meta, config)

30 changes: 15 additions & 15 deletions trinity/utils/registry.py
@@ -83,21 +83,21 @@ def register_module(self, module_name: str, module_cls: Type = None, force=False
Default: False.

Example:
```python
WORKFLOWS = Registry("workflows")

# register a module using decorator
@WORKFLOWS.register_module(name="workflow_name")
class MyWorkflow(Workflow):
pass

# or register a module directly
WORKFLOWS.register_module(
name="workflow_name",
module_cls=MyWorkflow,
force=True,
)
```

.. code-block:: python

    WORKFLOWS = Registry("workflows")

    # register a module using decorator
    @WORKFLOWS.register_module(module_name="workflow_name")
    class MyWorkflow(Workflow):
        pass

    # or register a module directly
    WORKFLOWS.register_module(
        module_name="workflow_name",
        module_cls=MyWorkflow,
        force=True,
    )

"""
if not (module_name is None or isinstance(module_name, str)):
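The two registration styles described in the docstring can be illustrated with a small self-contained sketch. This `Registry` is a simplified stand-in written for illustration, assuming only the `register_module(module_name, module_cls, force)` signature and the type check visible in the diff; the `get` method, the `_register` helper, and the duplicate-name behaviour are assumptions, not the actual code in `trinity/utils/registry.py`.

```python
from typing import Dict, Optional, Type


class Registry:
    """Simplified stand-in for a name -> class registry."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def get(self, module_name: str) -> Optional[Type]:
        return self._modules.get(module_name)

    def _register(self, module_name: Optional[str], module_cls: Type, force: bool) -> None:
        # Assumption: fall back to the class name when no name is given.
        key = module_name or module_cls.__name__
        if key in self._modules and not force:
            raise KeyError(f"{key} is already registered in {self.name}")
        self._modules[key] = module_cls

    def register_module(self, module_name: Optional[str], module_cls: Optional[Type] = None, force: bool = False):
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f"module_name must be None or str, got {type(module_name)}")
        if module_cls is not None:
            # Direct call: register immediately and return the class.
            self._register(module_name, module_cls, force)
            return module_cls

        def decorator(cls: Type) -> Type:
            # Decorator call: register when the class body has been evaluated.
            self._register(module_name, cls, force)
            return cls

        return decorator


WORKFLOWS = Registry("workflows")


@WORKFLOWS.register_module("workflow_name")
class MyWorkflow:
    pass


class OtherWorkflow:
    pass


# force=True overwrites the entry registered by the decorator above.
WORKFLOWS.register_module("workflow_name", module_cls=OtherWorkflow, force=True)
```

Returning the class from both paths keeps the decorated name bound to the class itself, so registration has no effect on how the class is used elsewhere.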