Skip to content

Commit f862e11

Browse files
authored
Support loading user-written plugin modules automatically (#74)
1 parent e773560 commit f862e11

File tree

13 files changed

+198
-64
lines changed

13 files changed

+198
-64
lines changed

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ During initialization, `Workflow` receives the following parameters:
120120
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
121121
```
122122

123-
Heres an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
123+
Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
124124

125125
```python
126126
class ExampleWorkflow(Workflow):
@@ -197,6 +197,25 @@ class ExampleWorkflow(Workflow):
197197
pass
198198
```
199199

200+
For workflows that are prepared to be contributed to Trinity-RFT project, you need to place the above code in `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`. And add the following line to `trinity/common/workflows/__init__.py`:
201+
202+
```python
203+
# existing import lines
204+
from .example_workflow import ExampleWorkflow
205+
206+
__all__ = [
207+
# existing __all__ lines
208+
"ExampleWorkflow",
209+
]
210+
```
211+
212+
For workflows that are not intended to be contributed to Trinity-RFT project, you can just place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.
213+
214+
```{tip}
215+
You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
216+
```
217+
218+
200219
#### Avoid Re-initialization
201220

202221
For heavy workflows, re-initializing every time can incurs extra computational costs.

tests/utils/__init__.py

Whitespace-only changes.

tests/utils/plugin_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import unittest
2+
from pathlib import Path
3+
4+
import ray
5+
6+
from trinity.common.workflows import WORKFLOWS
7+
from trinity.utils.plugin_loader import load_plugins
8+
9+
10+
@ray.remote
11+
class PluginActor:
12+
def run(self):
13+
my_plugin_cls = WORKFLOWS.get("my_workflow")
14+
return my_plugin_cls(None, None).run()
15+
16+
17+
class TestPluginLoader(unittest.TestCase):
18+
def test_load_plugins(self):
19+
ray.init(ignore_reinit_error=True)
20+
my_plugin_cls = WORKFLOWS.get("my_workflow")
21+
self.assertIsNone(my_plugin_cls)
22+
load_plugins(Path(__file__).resolve().parent / "plugins")
23+
my_plugin_cls = WORKFLOWS.get("my_workflow")
24+
self.assertIsNotNone(my_plugin_cls)
25+
my_plugin = my_plugin_cls(None, None, None)
26+
self.assertTrue(my_plugin.__module__.startswith("trinity.plugins"))
27+
res = my_plugin.run()
28+
self.assertEqual(res[0], "Hello world")
29+
self.assertEqual(res[1], "Hi")
30+
remote_plugin = PluginActor.remote()
31+
remote_res = ray.get(remote_plugin.run.remote())
32+
self.assertEqual(remote_res[0], "Hello world")
33+
self.assertEqual(remote_res[1], "Hi")
34+
ray.shutdown(_exiting_interpreter=True)

tests/utils/plugins/__init__.py

Whitespace-only changes.

tests/utils/plugins/my_workflow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import List
2+
3+
from trinity.common.workflows import WORKFLOWS, Workflow
4+
5+
6+
@WORKFLOWS.register_module("my_workflow")
7+
class MyWorkflow(Workflow):
8+
def __init__(self, model, task, auxiliary_models=None):
9+
super().__init__(model, task, auxiliary_models)
10+
11+
def run(self) -> List:
12+
return ["Hello world", "Hi"]

trinity/buffer/buffer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
4646
file_read_type = file_read_type.value
4747
else:
4848
file_read_type = "rollout"
49-
return FILE_READERS.get(file_read_type)(storage_config, buffer_config)
49+
return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore
5050
else:
5151
raise ValueError(f"{storage_config.storage_type} not supported.")
5252

trinity/buffer/reader/file_reader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
195195
self.reward_fn_key = meta.format.reward_fn_key
196196

197197
self.task_type = meta.task_type
198-
self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
199-
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
198+
self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore
199+
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore
200200
self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1
201201

202202
def __len__(self):
@@ -216,7 +216,7 @@ def read(self, strategy: Optional[ReadStrategy] = None):
216216
if self.reward_fn_key in sample
217217
else self.default_reward_fn_cls
218218
)
219-
assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required"
219+
assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
220220
task = Task(
221221
workflow=workflow_class,
222222
format_args=self.meta.format,

trinity/cli/launcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from trinity.explorer.explorer import Explorer
1313
from trinity.trainer.trainer import Trainer
1414
from trinity.utils.log import get_logger
15+
from trinity.utils.plugin_loader import load_plugins
1516

1617
logger = get_logger(__name__)
1718

@@ -157,7 +158,8 @@ def activate_data_module(data_workflow_url: str, config_path: str):
157158
return
158159

159160

160-
def run(config_path: str, dlc: bool = False):
161+
def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
162+
load_plugins(plugin_dir)
161163
config = load_config(config_path)
162164
config.check_and_update()
163165
pprint(config)
@@ -219,6 +221,12 @@ def main() -> None:
219221
# run command
220222
run_parser = subparsers.add_parser("run", help="Run RFT process.")
221223
run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
224+
run_parser.add_argument(
225+
"--plugin-dir",
226+
type=str,
227+
default=None,
228+
help="Path to the directory containing plugin modules.",
229+
)
222230
run_parser.add_argument(
223231
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
224232
)
@@ -229,12 +237,10 @@ def main() -> None:
229237
"--port", type=int, default=8501, help="The port for Trinity-Studio."
230238
)
231239

232-
# TODO: add more commands like `monitor`, `label`
233-
234240
args = parser.parse_args()
235241
if args.command == "run":
236242
# TODO: support parse all args from command line
237-
run(args.config, args.dlc)
243+
run(args.config, args.dlc, args.plugin_dir)
238244
elif args.command == "studio":
239245
studio(args.port)
240246

trinity/common/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def create_inference_models(
6464
else:
6565
raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}")
6666

67-
main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)]
67+
main_bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)]
6868
auxiliary_bundles = [
69-
{"GPU": 1, "CPU": 1}
69+
{"GPU": 1}
7070
for _ in range(
7171
sum(
7272
[

trinity/plugins/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Add your custom modules to this directory."""

0 commit comments

Comments
 (0)