diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index da5b447cdc..1b8e3fc56b 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -120,7 +120,7 @@ During initialization, `Workflow` receives the following parameters: You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. ``` -Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization. +Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization. ```python class ExampleWorkflow(Workflow): @@ -197,6 +197,25 @@ class ExampleWorkflow(Workflow): pass ``` +For workflows that are to be contributed to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`: + +```python +# existing import lines +from .example_workflow import ExampleWorkflow + +__all__ = [ + # existing __all__ lines + "ExampleWorkflow", +] +``` + +For workflows that are not intended to be contributed to the Trinity-RFT project, you can just place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder. + +```{tip} +You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use the `trinity/plugins` directory inside the Trinity-RFT package as the default directory. 
+``` + + #### Avoid Re-initialization For heavy workflows, re-initializing every time can incurs extra computational costs. diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py new file mode 100644 index 0000000000..01aa2f3967 --- /dev/null +++ b/tests/utils/plugin_test.py @@ -0,0 +1,34 @@ +import unittest +from pathlib import Path + +import ray + +from trinity.common.workflows import WORKFLOWS +from trinity.utils.plugin_loader import load_plugins + + +@ray.remote +class PluginActor: + def run(self): + my_plugin_cls = WORKFLOWS.get("my_workflow") + return my_plugin_cls(None, None).run() + + +class TestPluginLoader(unittest.TestCase): + def test_load_plugins(self): + ray.init(ignore_reinit_error=True) + my_plugin_cls = WORKFLOWS.get("my_workflow") + self.assertIsNone(my_plugin_cls) + load_plugins(Path(__file__).resolve().parent / "plugins") + my_plugin_cls = WORKFLOWS.get("my_workflow") + self.assertIsNotNone(my_plugin_cls) + my_plugin = my_plugin_cls(None, None, None) + self.assertTrue(my_plugin.__module__.startswith("trinity.plugins")) + res = my_plugin.run() + self.assertEqual(res[0], "Hello world") + self.assertEqual(res[1], "Hi") + remote_plugin = PluginActor.remote() + remote_res = ray.get(remote_plugin.run.remote()) + self.assertEqual(remote_res[0], "Hello world") + self.assertEqual(remote_res[1], "Hi") + ray.shutdown(_exiting_interpreter=True) diff --git a/tests/utils/plugins/__init__.py b/tests/utils/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py new file mode 100644 index 0000000000..b999590a01 --- /dev/null +++ b/tests/utils/plugins/my_workflow.py @@ -0,0 +1,12 @@ +from typing import List + +from trinity.common.workflows import WORKFLOWS, Workflow + + +@WORKFLOWS.register_module("my_workflow") +class MyWorkflow(Workflow): 
+ def __init__(self, model, task, auxiliary_models=None): + super().__init__(model, task, auxiliary_models) + + def run(self) -> List: + return ["Hello world", "Hi"] diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 09ff663c47..32a5fb85a8 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -46,7 +46,7 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig file_read_type = file_read_type.value else: file_read_type = "rollout" - return FILE_READERS.get(file_read_type)(storage_config, buffer_config) + return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore else: raise ValueError(f"{storage_config.storage_type} not supported.") diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index cb69b5e017..69472a3547 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -195,8 +195,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.reward_fn_key = meta.format.reward_fn_key self.task_type = meta.task_type - self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) - self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) + self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore + self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1 def __len__(self): @@ -216,7 +216,7 @@ def read(self, strategy: Optional[ReadStrategy] = None): if self.reward_fn_key in sample else self.default_reward_fn_cls ) - assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required" + assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required" task = Task( workflow=workflow_class, format_args=self.meta.format, diff --git a/trinity/cli/launcher.py 
b/trinity/cli/launcher.py index b0fb4b856f..d3156ebd6f 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -12,6 +12,7 @@ from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger +from trinity.utils.plugin_loader import load_plugins logger = get_logger(__name__) @@ -157,7 +158,8 @@ def activate_data_module(data_workflow_url: str, config_path: str): return -def run(config_path: str, dlc: bool = False): +def run(config_path: str, dlc: bool = False, plugin_dir: str = None): + load_plugins(plugin_dir) config = load_config(config_path) config.check_and_update() pprint(config) @@ -219,6 +221,12 @@ def main() -> None: # run command run_parser = subparsers.add_parser("run", help="Run RFT process.") run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.") + run_parser.add_argument( + "--plugin-dir", + type=str, + default=None, + help="Path to the directory containing plugin modules.", + ) run_parser.add_argument( "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC." ) @@ -229,12 +237,10 @@ def main() -> None: "--port", type=int, default=8501, help="The port for Trinity-Studio." 
) - # TODO: add more commands like `monitor`, `label` - args = parser.parse_args() if args.command == "run": # TODO: support parse all args from command line - run(args.config, args.dlc) + run(args.config, args.dlc, args.plugin_dir) elif args.command == "studio": studio(args.port) diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 25cb927799..8b80a71bd6 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -64,9 +64,9 @@ def create_inference_models( else: raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}") - main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)] + main_bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)] auxiliary_bundles = [ - {"GPU": 1, "CPU": 1} + {"GPU": 1} for _ in range( sum( [ diff --git a/trinity/plugins/__init__.py b/trinity/plugins/__init__.py new file mode 100644 index 0000000000..1b8629c9ca --- /dev/null +++ b/trinity/plugins/__init__.py @@ -0,0 +1 @@ +"""Add your custom modules to this directory.""" diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 3edcc9539f..b250d856d6 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -103,8 +103,11 @@ def setup_ray_cluster(namespace: str): ).remote() while True: if ray.get(cluster_status.running.remote()): + ret = subprocess.run("ray status", shell=True, capture_output=True) + print(ret.stdout.decode()) time.sleep(5) else: + logger.info("Ray cluster is not running, exiting.") break sys.exit(0) @@ -118,3 +121,4 @@ def stop_ray_cluster(): get_if_exists=True, ).remote() ray.get(cluster_status.finish.remote()) + logger.info("Stopping ray cluster...") diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py new file mode 100644 index 0000000000..a5a779ae83 --- /dev/null +++ b/trinity/utils/plugin_loader.py @@ -0,0 +1,65 @@ +"""Load modules from custom directory""" + +import 
importlib +import os +import shutil +import sys +from pathlib import Path + +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +def load_plugins(plugin_dir: str) -> None: + """ + Load plugin modules from a directory. + """ + if plugin_dir is None: + plugin_dir = Path(__file__).parent.parent / "plugins" + if not os.path.exists(plugin_dir): + logger.error(f"--plugin-dir [{plugin_dir}] does not exist.") + return None + if not os.path.isdir(plugin_dir): + logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.") + return None + + logger.info(f"Loading plugin modules from [{plugin_dir}]...") + for file in Path(plugin_dir).glob("*.py"): + if file.name.startswith("__"): + continue + logger.info(f"Loading plugin modules from [{file}]...") + # load modules from file + load_from_file(os.path.join(plugin_dir, file)) + + +def load_from_file(file_path: str): + """ + Load modules from a Python file + + Args: + file_path (`str`): The python file path. + + Returns: + `Any`: The loaded module. + """ + module_name = os.path.splitext(os.path.basename(file_path))[0] + + full_module_name = f"trinity.plugins.{module_name}" + + spec = importlib.util.spec_from_file_location(full_module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load module from {file_path}") + + module = importlib.util.module_from_spec(spec) + + module.__package__ = "trinity.plugins" + + spec.loader.exec_module(module) + + if full_module_name in sys.modules: + raise ImportError(f"Module {module_name} already exists.") + sys.modules[full_module_name] = module + shutil.copy2(file_path, Path(__file__).parent.parent / "plugins") + logger.info(f"Load {file_path} as {full_module_name}") + return module diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py index 70fb2930c9..3ad4e844fe 100644 --- a/trinity/utils/registry.py +++ b/trinity/utils/registry.py @@ -1,21 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -------------------------------------------------------- -# Most of the code here has been modified from: -# https://github.com/modelscope/modelscope/blob/master/modelscope/utils/registry.py -# -------------------------------------------------------- +from typing import Any, Type from trinity.utils.log import get_logger @@ -23,59 +6,57 @@ class Registry(object): - """This class is used to register some modules to registry by a repo - name.""" + """A class for registry.""" def __init__(self, name: str): """ - Initialization method. - - :param name: a registry repo name + Args: + name (`str`): The name of the registry. """ self._name = name self._modules = {} @property - def name(self): + def name(self) -> str: """ Get name of current registry. - :return: name of current registry. + Returns: + `str`: The name of current registry. """ return self._name @property - def modules(self): + def modules(self) -> dict: """ Get all modules in current registry. - :return: a dict storing modules in current registry. + Returns: + `dict`: A dict storing modules in current registry. """ return self._modules - def list(self): + def list(self) -> None: """Logging the list of module in current registry.""" for m in self._modules.keys(): logger.info(f"{self._name}\t{m}") - def get(self, module_key): + def get(self, module_key: str) -> Any: """ Get module named module_key from in current registry. If not found, return None. 
- :param module_key: specified module name - :return: module named module_key + Args: + module_key (`str`): specified module name + + Returns: + `Any`: the module object """ return self._modules.get(module_key, None) def _register_module(self, module_name=None, module_cls=None, force=False): """ Register module to registry. - - :param module_name: module name - :param module_cls: module class object - :param force: Whether to override an existing class with the - same name. Default: False. """ if module_name is None: @@ -87,25 +68,35 @@ def _register_module(self, module_name=None, module_cls=None, force=False): self._modules[module_name] = module_cls module_cls._name = module_name - def register_module(self, module_name: str = None, module_cls: type = None, force=False): + def register_module(self, module_name: str, module_cls: Type = None, force=False, lazy=False): """ - Register module class object to registry with the specified modulename. + Register module class object to registry with the specified module name. - :param module_name: module name - :param module_cls: module class object - :param force: Whether to override an existing class with - the same name. Default: False. + Args: + module_name (`str`): The module name. + module_cls (`Type`): module class object + force (`bool`): Whether to override an existing class with + the same name. Default: False. + lazy (`bool`): Whether to register the module class object lazily. + Default: False. 
Example: - >>> registry = Registry() - >>> @registry.register_module() - >>> class TextFormatter: - >>> pass - - >>> class TextFormatter2: - >>> pass - >>> registry.register_module( module_name='text_formatter2', - module_cls=TextFormatter2) + ```python + WORKFLOWS = Registry("workflows") + + # register a module using decorator + @WORKFLOWS.register_module(module_name="workflow_name") + class MyWorkflow(Workflow): + pass + + # or register a module directly + WORKFLOWS.register_module( + module_name="workflow_name", + module_cls=MyWorkflow, + force=True, + ) + ``` + """ if not (module_name is None or isinstance(module_name, str)): raise TypeError(f"module_name must be either of None, str," f"got {type(module_name)}") @@ -118,8 +109,10 @@ def _register(module_cls): """ Register module class object to registry. - :param module_cls: module class object - :return: module class object. + Args: + module_cls (`Type`): module class object + Returns: + `Type`: Decorated module class object. """ self._register_module(module_name=module_name, module_cls=module_cls, force=force) return module_cls