Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ During initialization, `Workflow` receives the following parameters:
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```

Heres an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.

```python
class ExampleWorkflow(Workflow):
Expand Down Expand Up @@ -197,6 +197,25 @@ class ExampleWorkflow(Workflow):
pass
```

For workflows that you plan to contribute to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`:

```python
# existing import lines
from .example_workflow import ExampleWorkflow

__all__ = [
# existing __all__ lines
"ExampleWorkflow",
]
```

For workflows that are not intended to be contributed to the Trinity-RFT project, you can simply place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.

```{tip}
You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
```


#### Avoid Re-initialization

For heavy workflows, re-initializing every time can incur extra computational costs.
Expand Down
Empty file added tests/utils/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions tests/utils/plugin_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest
from pathlib import Path

import ray

from trinity.common.workflows import WORKFLOWS
from trinity.utils.plugin_loader import load_plugins


@ray.remote
class PluginActor:
    """Ray actor used to check that the `my_workflow` plugin is usable from an actor."""

    def run(self):
        """Instantiate the workflow registered as `my_workflow` and return its `run()` result."""
        workflow_cls = WORKFLOWS.get("my_workflow")
        workflow = workflow_cls(None, None)
        return workflow.run()


class TestPluginLoader(unittest.TestCase):
    """Tests for `trinity.utils.plugin_loader.load_plugins`."""

    def test_load_plugins(self):
        """Load the sibling `plugins` directory and use the plugin locally and from a Ray actor."""
        ray.init(ignore_reinit_error=True)
        # Shut Ray down even when an assertion (or the plugin load itself) fails,
        # so a failure here does not leave a running Ray runtime behind for other tests.
        try:
            # The workflow must not be registered before the plugins are loaded.
            my_plugin_cls = WORKFLOWS.get("my_workflow")
            self.assertIsNone(my_plugin_cls)

            load_plugins(Path(__file__).resolve().parent / "plugins")

            my_plugin_cls = WORKFLOWS.get("my_workflow")
            self.assertIsNotNone(my_plugin_cls)
            my_plugin = my_plugin_cls(None, None, None)
            # Plugins are loaded under the `trinity.plugins` namespace.
            self.assertTrue(my_plugin.__module__.startswith("trinity.plugins"))
            res = my_plugin.run()
            self.assertEqual(res[0], "Hello world")
            self.assertEqual(res[1], "Hi")

            # The same workflow must also be resolvable inside a Ray actor.
            remote_plugin = PluginActor.remote()
            remote_res = ray.get(remote_plugin.run.remote())
            self.assertEqual(remote_res[0], "Hello world")
            self.assertEqual(remote_res[1], "Hi")
        finally:
            ray.shutdown(_exiting_interpreter=True)
Empty file added tests/utils/plugins/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions tests/utils/plugins/my_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import List

from trinity.common.workflows import WORKFLOWS, Workflow


@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
    """Minimal workflow registered as `my_workflow`; used as a plugin fixture."""

    def __init__(self, model, task, auxiliary_models=None):
        """Forward all arguments unchanged to `Workflow.__init__`."""
        super().__init__(model, task, auxiliary_models)

    def run(self) -> List:
        """Return a fixed two-element list of greetings."""
        greetings = ["Hello world", "Hi"]
        return greetings
2 changes: 1 addition & 1 deletion trinity/buffer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
file_read_type = file_read_type.value
else:
file_read_type = "rollout"
return FILE_READERS.get(file_read_type)(storage_config, buffer_config)
return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore
else:
raise ValueError(f"{storage_config.storage_type} not supported.")

Expand Down
6 changes: 3 additions & 3 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.reward_fn_key = meta.format.reward_fn_key

self.task_type = meta.task_type
self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore
self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1

def __len__(self):
Expand All @@ -216,7 +216,7 @@ def read(self, strategy: Optional[ReadStrategy] = None):
if self.reward_fn_key in sample
else self.default_reward_fn_cls
)
assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required"
assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
task = Task(
workflow=workflow_class,
format_args=self.meta.format,
Expand Down
14 changes: 10 additions & 4 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins

logger = get_logger(__name__)

Expand Down Expand Up @@ -157,7 +158,8 @@ def activate_data_module(data_workflow_url: str, config_path: str):
return


def run(config_path: str, dlc: bool = False):
def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
load_plugins(plugin_dir)
config = load_config(config_path)
config.check_and_update()
pprint(config)
Expand Down Expand Up @@ -219,6 +221,12 @@ def main() -> None:
# run command
run_parser = subparsers.add_parser("run", help="Run RFT process.")
run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
run_parser.add_argument(
"--plugin-dir",
type=str,
default=None,
help="Path to the directory containing plugin modules.",
)
run_parser.add_argument(
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
)
Expand All @@ -229,12 +237,10 @@ def main() -> None:
"--port", type=int, default=8501, help="The port for Trinity-Studio."
)

# TODO: add more commands like `monitor`, `label`

args = parser.parse_args()
if args.command == "run":
# TODO: support parse all args from command line
run(args.config, args.dlc)
run(args.config, args.dlc, args.plugin_dir)
elif args.command == "studio":
studio(args.port)

Expand Down
4 changes: 2 additions & 2 deletions trinity/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def create_inference_models(
else:
raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}")

main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)]
main_bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)]
auxiliary_bundles = [
{"GPU": 1, "CPU": 1}
{"GPU": 1}
for _ in range(
sum(
[
Expand Down
1 change: 1 addition & 0 deletions trinity/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Add your custom modules to this directory."""
4 changes: 4 additions & 0 deletions trinity/utils/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,11 @@ def setup_ray_cluster(namespace: str):
).remote()
while True:
if ray.get(cluster_status.running.remote()):
ret = subprocess.run("ray status", shell=True, capture_output=True)
print(ret.stdout.decode())
time.sleep(5)
else:
logger.info("Ray cluster is not running, exiting.")
break
sys.exit(0)

Expand All @@ -118,3 +121,4 @@ def stop_ray_cluster():
get_if_exists=True,
).remote()
ray.get(cluster_status.finish.remote())
logger.info("Stopping ray cluster...")
65 changes: 65 additions & 0 deletions trinity/utils/plugin_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Load modules from custom directory"""

import importlib
import os
import shutil
import sys
from pathlib import Path

from trinity.utils.log import get_logger

logger = get_logger(__name__)


def load_plugins(plugin_dir: str) -> None:
    """
    Load plugin modules from a directory.

    Every top-level `*.py` file in the directory whose name does not start
    with `__` is passed to `load_from_file`. If the directory does not exist
    or is not a directory, an error is logged and nothing is loaded.

    Args:
        plugin_dir (`str`): The directory to scan. When `None`, the `plugins`
            directory located two levels above this file is used.
    """
    if plugin_dir is None:
        plugin_dir = Path(__file__).parent.parent / "plugins"
    if not os.path.exists(plugin_dir):
        logger.error(f"--plugin-dir [{plugin_dir}] does not exist.")
        return None
    if not os.path.isdir(plugin_dir):
        logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.")
        return None

    logger.info(f"Loading plugin modules from [{plugin_dir}]...")
    # Sort so the load order does not depend on file-system enumeration order.
    for file in sorted(Path(plugin_dir).glob("*.py")):
        if file.name.startswith("__"):
            continue
        logger.info(f"Loading plugin modules from [{file}]...")
        # `glob` already yields paths prefixed with `plugin_dir`; joining the
        # prefix again would duplicate it for a relative `plugin_dir`
        # (e.g. `plugins/plugins/x.py`), so pass the path through as-is.
        load_from_file(str(file))


def load_from_file(file_path: str):
    """
    Load modules from a Python file

    The file is executed as the module `trinity.plugins.<file stem>`,
    registered in `sys.modules`, and copied into the `plugins` directory
    located two levels above this file.

    Args:
        file_path (`str`): The python file path.

    Returns:
        `Any`: The loaded module.

    Raises:
        ImportError: If no import spec/loader can be created for the file, or
            if a module with the same full name is already in `sys.modules`.
    """
    # `import importlib` alone does not guarantee that the `util` submodule is
    # bound as an attribute, so import it explicitly.
    import importlib.util

    module_name = os.path.splitext(os.path.basename(file_path))[0]
    full_module_name = f"trinity.plugins.{module_name}"

    # Check for a name clash *before* executing the file, so a rejected plugin
    # does not run its module-level side effects (e.g. registry decorators).
    if full_module_name in sys.modules:
        raise ImportError(f"Module {module_name} already exists.")

    spec = importlib.util.spec_from_file_location(full_module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {file_path}")

    module = importlib.util.module_from_spec(spec)
    module.__package__ = "trinity.plugins"

    # Register before executing (as the import system itself does) so code that
    # looks the module up by name during execution can find it; roll back on failure.
    sys.modules[full_module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(full_module_name, None)
        raise

    # NOTE(review): the copy presumably makes the plugin importable as
    # `trinity.plugins.<name>` from other processes -- confirm.
    target_dir = Path(__file__).parent.parent / "plugins"
    target_file = target_dir / os.path.basename(file_path)
    # Skip the copy when the file already lives in the target directory
    # (the default plugin dir); `shutil.copy2` would raise `SameFileError`.
    if not (target_file.exists() and os.path.samefile(file_path, target_file)):
        shutil.copy2(file_path, target_dir)
    logger.info(f"Load {file_path} as {full_module_name}")
    return module
Loading