diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst index a1b6fde647..4b4cab2aa9 100644 --- a/docs/sphinx_doc/source/index.rst +++ b/docs/sphinx_doc/source/index.rst @@ -14,7 +14,7 @@ Welcome to Trinity-RFT's documentation! :maxdepth: 1 :glob: :hidden: - :caption: Tutorial + :caption: Examples tutorial/example_reasoning_basic.md tutorial/example_reasoning_advanced.md @@ -22,8 +22,15 @@ Welcome to Trinity-RFT's documentation! tutorial/example_multi_turn.md tutorial/example_dpo.md tutorial/example_data_functionalities.md - tutorial/trinity_configs.md + +.. toctree:: + :maxdepth: 2 + :glob: + :hidden: + :caption: Guidelines + tutorial/trinity_programming_guide.md + tutorial/trinity_configs.md tutorial/example_mix_algo.md .. toctree:: @@ -34,6 +41,7 @@ Welcome to Trinity-RFT's documentation! build_api/trinity.buffer build_api/trinity.explorer build_api/trinity.trainer + build_api/trinity.algorithm build_api/trinity.manager build_api/trinity.common build_api/trinity.utils diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index ee0010ba24..61ecec33b1 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -1,5 +1,8 @@ -# Integrate A New Algorithm +# Algorithm Development +```{note} +This guide is an advanced version of the {ref}`Algorithms ` section in the Developer Guide. +``` This guide introduces how to integrate a new algorithm to Trinity-RFT. As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX , which optimizes the following policy objective: @@ -19,13 +22,10 @@ The first term corresponds to the standard GRPO objective, which aims to maximiz ## Step 0: Prepare the Expert Data -We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected dta are viewed as some experiences from an expert. We store them in a JSON file `expert_data.json` with the following format: +We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected data are viewed as some experiences from an expert. We store them in a `jsonl` file `expert_data.jsonl` with the following format: ```json -{ - "question": "What is the average of 4, 6, and 8?", - "response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6." -} +{"question": "What is the average of 4, 6, and 8?","response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."} ... 
``` @@ -42,7 +42,6 @@ class MIXAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = True use_advantage: bool = True - use_rollout: bool = True can_balance_batch: bool = True schema: type = ExperienceModel diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md index a80032bc12..aa4439e866 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md @@ -6,7 +6,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a - +(OPMD)= ## OPMD: a native off-policy RL algorithm diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 4d158f86b9..931cb81506 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -1,6 +1,16 @@ # Developer Guide -This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines. +This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines. + +Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**. +We decouple the RL pipeline into three modules to make it easier to customize and extend. +Below is a table summarizing the modules and components that developers with different targets need to focus on. + +| Development Target | Core Module | Key Component | |--------------------|-------------|---------------| +| Apply existing RL algorithms to new environments. | *Explorer* | `Workflow` | +| Design new RL algorithms. | *Trainer* | `Algorithm` | +| Enhance the RL process from the data perspective. | *Buffer* | Data Processing Module (Coming soon) | ```{note} Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code. @@ -8,9 +18,10 @@ Trinity-RFT is still under development, and the following interfaces may change. --- -## Creating New Workflows +## Workflows (For RL Environment Developers) -Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow: +In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. +A qualified workflow needs to use the trained model to complete the specified task and obtain feedback information (reward) from the environment. Below are the steps to create a new workflow: --- @@ -18,19 +29,16 @@ Trinity-RFT allows developers to register new workflows (e.g., for multi-turn in ### Step 0: Basic Concepts Before starting development, it's important to understand several core concepts: - - **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type: - **Math problems**: A `Task` contains the problem description and the golden answer. - **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information. - -- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks.
After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows: +- **Workflow** ({class}`trinity.common.workflows.Workflow`): Describes how a `Task` is executed. It defines the interaction flow between Agents and Environments, including logic similar to *Rollout* and *Reward* calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows: - `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios, submits problems to LLM, parses LLM responses, and calculates scores (rewards). - `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios, it contains multi-turn interaction with environment. - `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results. - ... - - **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc. --- @@ -40,12 +48,12 @@ Before starting development, it's important to understand several core concepts: The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file. To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields. - - **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file. - - **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field. - - **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields. - - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`. - - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`. - - **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field. +- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file. +- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. 
You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field. +- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields. +- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`. +- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`. +- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field. ```{tip} `workflow`, `workflow_args` and `raw_task` provide different levels of customization. @@ -82,7 +90,6 @@ buffer: In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task` and use the `rollout_args` to generate the response. - --- ### Step 2: Implement a New Workflow @@ -106,8 +113,7 @@ class Workflow(ABC): """Run the workflow and return a list of Experiences.""" ``` - -#### Initializing Your Workflow +#### Initialize Your Workflow During initialization, `Workflow` receives the following parameters: @@ -115,7 +121,6 @@ During initialization, `Workflow` receives the following parameters: - `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset. - `auxiliary_models`(`List[openai.OpenAI]`):A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs. - ```{tip} You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. ``` @@ -143,7 +148,6 @@ We first call the model to generate multiple response using the provided questio Then we calculate the reward for each response using the `calculate_reward` function. Finally, we construct a list of `Experience` with the responses and rewards and return it. - ```python class ExampleWorkflow(Workflow): @@ -215,7 +219,6 @@ For workflows that are not intended to be contributed to Trinity-RFT project, yo You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `/trinity/plugins` as the default directory. ``` - #### Avoid Re-initialization For heavy workflows, re-initializing every time can incurs extra computational costs. 
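A minimal sketch of the reusable-workflow pattern this section describes, assuming the `resettable` flag and `reset(task)` hook that the surrounding example relies on (the attribute names and fields are illustrative, not a definitive API):

```python
class ExampleWorkflow(Workflow):
    # Sketch: keep one workflow instance alive across tasks and only refresh the
    # cheap, task-specific fields instead of rebuilding heavy state every time.
    resettable: bool = True  # assumption: signals the framework that this instance can be reused

    def reset(self, task: Task):
        # Refresh only what changes between tasks.
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```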
@@ -235,7 +238,6 @@ class ExampleWorkflow(Workflow): self.answer = task.raw_task.get("answer") ``` - #### Full Code Example ```python @@ -289,7 +291,6 @@ class ExampleWorkflow(Workflow): self.answer = task.raw_task.get("answer") ``` - --- ### Step 3: Use Your Workflow @@ -314,6 +315,198 @@ trinity run --config --- +(Algorithms)= +## Algorithms (For RL Algorithm Developers) + +Trinity-RFT provides a standardized process for implementing new algorithms. + +### Step 0: Basic Concepts of Algorithm Module + +In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss to update models based on this data. +To avoid implementing a new Trainer class each time a new algorithm is added, we have decomposed the representative PPO algorithm process into multiple sub-modules to adapt to various algorithms. + +- **Sample Strategy** ({class}`trinity.algorithm.SampleStrategy`): Responsible for sampling experience data from the buffer module. By customizing this module, you can implement functionalities like filtering experience data or mixed sampling from multiple data sources. +- **Advantage Fn**({class}`trinity.algorithm.AdvantageFn`): Responsible for calculating the Advantage and Returns of experience data. +- **Policy Loss Fn**({class}`trinity.algorithm.PolicyLossFn`): Responsible for calculating the core training loss of the policy network. +- **KL Fn**({class}`trinity.algorithm.KLFn`): Responsible for calculating KL Divergence, which is generally used in two places in existing RL algorithms: Reward Penalty and Actor Loss. +- **Entropy Loss Fn**({class}`trinity.algorithm.EntropyLossFn`): Responsible for calculating the entropy loss of the policy network. + +We provide several implementations of above modules in `trinity/algorithm`. + +--- + +### Step 1: Implement Algorithm Components + + +Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the {ref}`OPMD ` algorithm as an example. + +The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. Therefore, only new Advantage Fn and Policy Loss Fn modules need to be implemented. + +--- + +#### Step 1.1: Implement `AdvantageFn` + +Developers need to implement the {class}`trinity.algorithm.AdvantageFn` interface, which mainly includes two methods: + +- `__call__`: Calculates advantages and returns based on input experience data, records observable metrics during the calculation process, and returns the experience data containing advantages and returns as well as a metrics dictionary. The input experience data format is [verl](https://github.com/volcengine/verl)'s `DataProto`. +- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file. + +After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name. 
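Before reading the full implementation, it may help to see the kind of computation an `AdvantageFn` performs with the `DataProto` plumbing stripped away. The snippet below is a self-contained illustration of a group-mean baseline on plain tensors (the values and the plain-tensor interface are assumptions for illustration only; the real module operates on verl's `DataProto`):

```python
import torch

# Rewards of several responses sampled for the same prompt.
rewards = torch.tensor([0.0, 1.0, 1.0, 0.0])

baseline = rewards.mean()        # "mean" baseline over the group
advantages = rewards - baseline  # per-response advantages passed on to the policy loss

print(advantages)  # tensor([-0.5000,  0.5000,  0.5000, -0.5000])
```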
+ +Here's an implementation example for the OPMD algorithm's Advantage Fn: + +```python +# trinity/algorithm/advantage_fn/opmd.py +# import some modules +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("opmd") +class OPMDAdvantageFn(AdvantageFn): + """OPMD advantage computation""" + + def __init__( + self, + opmd_baseline: str = "mean", + tau: float = 1.0, + ) -> None: + self.opmd_baseline = opmd_baseline + self.tau = tau + + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + # calculate advantages and returns based on the exps + + # record some metrics + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "opmd_baseline": "mean", + "tau": 1.0, + } +``` + +#### Step 1.2: Implement `PolicyLossFn` + +Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interface, which is similar to `AdvantageFn` and includes two methods: + +- `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`. +- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file. + +Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`. + +Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method: + + +```python +@POLICY_LOSS_FN.register_module("opmd") +class OPMDPolicyLossFn(PolicyLossFn): + def __init__(self, tau: float = 1.0) -> None: + self.tau = tau + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + pg_losses = -advantages * logprob + opmd_loss = masked_mean(pg_losses, action_mask) + opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta) + return opmd_loss, {"opmd_loss": opmd_loss.detach().item()} + + @classmethod + def default_args(cls) -> Dict: + return {"tau": 1.0} +``` + +--- + +### Step 2: Register Your Algorithm + +The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect. + +To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration. 
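In other words, once an algorithm is registered, a single name in the configuration file is enough to pull in its whole bundle of defaults. A rough sketch of that resolution, reusing the registry `get` lookup and the `get_default_config` classmethod described in this guide (how the trainer wires this internally is an assumption):

```python
from trinity.algorithm import ALGORITHM_TYPE

algo_cls = ALGORITHM_TYPE.get("opmd")      # resolve the registered algorithm by name
defaults = algo_cls.get_default_config()   # e.g. {"advantage_fn": "opmd", "policy_loss_fn": "opmd", ...}
# Values set under the `algorithm:` section of the YAML file would override these defaults.
```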
+ +The `AlgorithmType` class includes the following attributes and methods: + +- `use_critic`: Whether to use the Critic model +- `use_reference`: Whether to use the Reference model +- `use_advantage`: Whether to calculate Advantage; if False, the `AdvantageFn` call will be skipped +- `can_balance_batch`: Whether the algorithm allows automatic balancing when splitting a batch into microbatches (which permute the order of samples) +- `schema`: The format of experience data corresponding to the algorithm +- `get_default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE` + +Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`. + +Below is the implementation for the OPMD algorithm. +Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`. +The dictionary returned by the `get_default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss. + +```python +@ALGORITHM_TYPE.register_module("opmd") +class OPMDAlgorithm(AlgorithmType): + """OPMD algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "opmd", + "advantage_fn": "opmd", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } +``` + +--- + +### Step 3: Use Your Algorithm + +After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file. + +For default configurations, you just need to add the following content to your `config.yaml` file: + +```yaml +# some other configs +algorithm: + algorithm_type: "opmd" +# some other configs +``` + +If you need to modify certain parameters, you can simply add the corresponding parameters within the `algorithm` section. For example, if you need to modify `repeat_times` and the initialization parameters of `AdvantageFn` and `PolicyLossFn`, the modified `config.yaml` file would be as follows: + +```yaml +# some other configs +algorithm: + algorithm_type: "opmd" + repeat_times: 8 + advantage_fn_args: + opmd_baseline: "logavgexp" + tau: 0.99 + policy_loss_fn_args: + tau: 0.99 +# some other configs +``` + +--- + ## Adding New Config Entries for the Config Generator (Advanced) ### Step 0: Understanding Streamlit @@ -344,11 +537,11 @@ The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=conf ``` For `train_batch_size`, we will use the following settings: + - Default value: 96 - Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0` - Additional config: `{"_train_batch_size_per_gpu": 16}` - Here's the complete code for the `train_batch_size` parameter: ```python @@ -408,6 +601,7 @@ To successfully integrate new parameters into the `config_manager.py` file, plea Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class. 
Example: + ```python class ConfigManager: def _expert_buffer_part(self): @@ -421,6 +615,7 @@ To successfully integrate new parameters into the `config_manager.py` file, plea Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML. Example: + ```python class ConfigManager: def _gen_buffer_config(self): diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index ff52f609e5..667aa10d74 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -1,10 +1,13 @@ from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn from trinity.algorithm.kl_fn import KL_FN, KLFn from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy __all__ = [ + "ALGORITHM_TYPE", + "AlgorithmType", "AdvantageFn", "ADVANTAGE_FN", "PolicyLossFn", diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 6f0a2d19a7..805dd8f213 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -26,7 +26,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta): use_critic: bool use_reference: bool use_advantage: bool - use_rollout: bool can_balance_batch: bool schema: type @@ -50,7 +49,6 @@ class SFTAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = False use_advantage: bool = False - use_rollout: bool = False can_balance_batch: bool = True schema: type = SFTDataModel @@ -71,7 +69,6 @@ class PPOAlgorithm(AlgorithmType): use_critic: bool = True use_reference: bool = True use_advantage: bool = True - use_rollout: bool = True can_balance_batch: bool = True schema: type = ExperienceModel @@ -95,7 +92,6 @@ class GRPOAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = True use_advantage: bool = True - use_rollout: bool = True can_balance_batch: bool = True schema: type = ExperienceModel @@ -119,7 +115,6 @@ class OPMDAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = True use_advantage: bool = True - use_rollout: bool = True can_balance_batch: bool = True schema: type = ExperienceModel @@ -143,7 +138,6 @@ class DPOAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = True use_advantage: bool = False - use_rollout: bool = False can_balance_batch: bool = False schema: type = DPODataModel diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 3e054d58c6..49ef0ec9b3 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -2,7 +2,6 @@ import ray -from trinity.algorithm.algorithm import ALGORITHM_TYPE from trinity.buffer.buffer_writer import BufferWriter from trinity.buffer.db_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig @@ -15,9 +14,6 @@ class SQLWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL # we only support write RFT algorithm buffer for now - # TODO: support other algorithms - algorithm = ALGORITHM_TYPE.get(meta.algorithm_type) - assert algorithm.use_rollout, "Only RFT buffer is supported for writing." 
self.wrap_in_ray = meta.wrap_in_ray self.db_wrapper = DBWrapper.get_wrapper(meta, config) diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py index d5ee37f36e..83cd393519 100644 --- a/trinity/utils/registry.py +++ b/trinity/utils/registry.py @@ -83,21 +83,21 @@ def register_module(self, module_name: str, module_cls: Type = None, force=False Default: False. Example: - ```python - WORKFLOWS = Registry("workflows") - - # register a module using decorator - @WORKFLOWS.register_module(name="workflow_name") - class MyWorkflow(Workflow): - pass - - # or register a module directly - WORKFLOWS.register_module( - name="workflow_name", - module_cls=MyWorkflow, - force=True, - ) - ``` + + .. code-block:: python + WORKFLOWS = Registry("workflows") + + # register a module using decorator + @WORKFLOWS.register_module(name="workflow_name") + class MyWorkflow(Workflow): + pass + + # or register a module directly + WORKFLOWS.register_module( + name="workflow_name", + module_cls=MyWorkflow, + force=True, + ) """ if not (module_name is None or isinstance(module_name, str)):