diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md
index 2f580426ab..355e86ce0b 100644
--- a/docs/sphinx_doc/source/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -513,13 +513,14 @@ Here, `<config_file_path>` is the path to a YAML configuration file, which shoul
 Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:
 
 ```bash
-trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
 ```
 
 - `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
-- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
+- `<output_dir>`: Directory to save the debug output. If not specified, the output will be saved to the `debug_output` directory in the current working directory.
 - `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
+- `--enable-profiling` (optional): Enable performance profiling using [viztracer](https://github.com/gaogaotiantian/viztracer).
 
-During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.
+During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection, and the output experiences will be saved to the `<output_dir>/experiences.db` file.
 
 When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
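The English doc above says that debug runs write their experiences to `<output_dir>/experiences.db`. Below is a minimal sketch of how that SQLite file could be inspected, assuming the default `debug_output` directory; the table names and columns come from Trinity-RFT's experience buffer schema and are not spelled out here, so the snippet only lists whatever tables exist and their row counts.

```python
import sqlite3

# Hypothetical path: the default <output_dir> is ./debug_output per the docs above.
conn = sqlite3.connect("debug_output/experiences.db")
# The concrete table layout is defined by Trinity-RFT's buffer schema, so just
# enumerate the tables and report their sizes.
tables = [
    row[0]
    for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
]
for table in tables:
    count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
    print(f"{table}: {count} rows")
conn.close()
```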
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
index 17ab4cc22a..364c425b09 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -509,13 +509,14 @@ trinity debug --config <config_file_path> --module inference_model
 模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试:
 
 ```bash
-trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
 ```
 
-- `config_file_path`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
-- `output_file_path`:性能分析结果输出路径。调试模式会使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析,并将结果保存为 HTML 文件,便于在浏览器中查看。
-- `plugin_dir`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。
+- `<config_file_path>`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
+- `<output_dir>`:调试输出保存目录。如果未指定,调试输出将保存在当前工作目录下的 `debug_output` 目录中。
+- `<plugin_dir>`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。
+- `--enable-profiling`(可选):启用性能分析,使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析。
 
-调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端,方便查看运行结果。
+调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端以供查看,同时产出的 Experience 会保存到 `<output_dir>/experiences.db` 数据库中。
 
 调试完成后,可在推理模型终端输入 `Ctrl+C` 以终止模型运行。
diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py
index 99c3d77077..c28852e621 100644
--- a/tests/cli/launcher_test.py
+++ b/tests/cli/launcher_test.py
@@ -41,6 +41,7 @@ def setUp(self):
 
     def tearDown(self):
         sys.argv = self._orig_argv
+        shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
 
     @mock.patch("trinity.cli.launcher.serve")
     @mock.patch("trinity.cli.launcher.explore")
@@ -254,31 +255,79 @@ def test_multi_stage_run(
     @mock.patch("trinity.cli.launcher.load_config")
     def test_debug_mode(self, mock_load):
         process = multiprocessing.Process(target=debug_inference_model_process)
-        process.start()
-        time.sleep(15)  # wait for the model to be created
-        for _ in range(10):
-            try:
-                get_debug_inference_model(self.config)
-                break
-            except Exception:
-                time.sleep(3)
-        output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
-        self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
-        mock_load.return_value = self.config
-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(
-                command="debug",
-                config="dummy.yaml",
-                module="workflow",
-                output_file=output_file,
-                plugin_dir="",
-            ),
-        ):
-            launcher.main()
-        process.join(timeout=10)
-        process.terminate()
-        self.assertTrue(os.path.exists(output_file))
+        try:
+            process.start()
+            time.sleep(15)  # wait for the model to be created
+            for _ in range(10):
+                try:
+                    get_debug_inference_model(self.config)
+                    break
+                except Exception:
+                    time.sleep(3)
+            output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
+            output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
+            self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
+            mock_load.return_value = self.config
+            with mock.patch(
+                "argparse.ArgumentParser.parse_args",
+                return_value=mock.Mock(
+                    command="debug",
+                    config="dummy.yaml",
+                    module="workflow",
+                    enable_profiling=True,
+                    output_dir=output_dir,
+                    output_file=output_file,
+                    plugin_dir="",
+                ),
+            ):
+                launcher.main()
+
+            self.assertFalse(os.path.exists(output_file))
+            self.assertTrue(os.path.exists(output_dir))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))
+            # add a dummy file to test overwrite behavior
+            with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
+                f.write("not empty")
+
+            with mock.patch(
+                "argparse.ArgumentParser.parse_args",
+                return_value=mock.Mock(
+                    command="debug",
+                    config="dummy.yaml",
+                    module="workflow",
+                    enable_profiling=False,
+                    output_dir=output_dir,
+                    output_file=output_file,
+                    plugin_dir="",
+                ),
+            ):
+                launcher.main()
+
+            self.assertFalse(os.path.exists(output_file))
+            # test that the original files are not overwritten
+            self.assertTrue(os.path.exists(output_dir))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "dummy.txt")))
+            dirs = os.listdir(self.config.checkpoint_job_dir)
+            target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
+            self.assertEqual(len(target_output_dir), 1)
+            self.assertFalse(
+                os.path.exists(
+                    os.path.join(
+                        self.config.checkpoint_job_dir, target_output_dir[0], "profiling.html"
+                    )
+                )
+            )
+            self.assertTrue(
+                os.path.exists(
+                    os.path.join(
+                        self.config.checkpoint_job_dir, target_output_dir[0], "experiences.db"
+                    )
+                )
+            )
+        finally:
+            process.join(timeout=10)
+            process.terminate()
 
 
 def debug_inference_model_process():
diff --git a/tests/utils/plugins/dependencies.py b/tests/utils/plugins/dependencies.py
new file mode 100644
index 0000000000..6f3b5f76b2
--- /dev/null
+++ b/tests/utils/plugins/dependencies.py
@@ -0,0 +1,7 @@
+"""A file containing some dependencies."""
+
+DEPENDENCY_VALUE = 0
+
+
+def dependency_func():
+    return "0"
diff --git a/tests/utils/plugins/main.py b/tests/utils/plugins/main.py
new file mode 100644
index 0000000000..fc07712658
--- /dev/null
+++ b/tests/utils/plugins/main.py
@@ -0,0 +1,17 @@
+from tests.utils.plugins.dependencies import DEPENDENCY_VALUE, dependency_func
+from trinity.common.workflows.workflow import Workflow
+
+
+class MainDummyWorkflow(Workflow):
+    def __init__(self, *, task, model, auxiliary_models=None):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+
+    @property
+    def repeatable(self):
+        return True
+
+    def set_repeat_times(self, repeat_times, run_id_base):
+        pass
+
+    def run(self) -> list:
+        return [DEPENDENCY_VALUE, dependency_func()]
diff --git a/tests/utils/registry_test.py b/tests/utils/registry_test.py
new file mode 100644
index 0000000000..4bdfff1e12
--- /dev/null
+++ b/tests/utils/registry_test.py
@@ -0,0 +1,29 @@
+import unittest
+
+import ray
+
+
+class ImportUtils:
+    def run(self):
+        from trinity.common.workflows import WORKFLOWS, Workflow
+
+        workflow_cls = WORKFLOWS.get("tests.utils.plugins.main.MainDummyWorkflow")
+        assert issubclass(workflow_cls, Workflow)
+        workflow = workflow_cls(task=None, model=None)
+        res = workflow.run()
+        assert res[0] == 0
+        assert res[1] == "0"
+
+
+class TestRegistry(unittest.TestCase):
+    def setUp(self):
+        ray.init(ignore_reinit_error=True)
+
+    def tearDown(self):
+        ray.shutdown()
+
+    def test_dynamic_import(self):
+        # test local import
+        ImportUtils().run()
+        # test remote import
+        ray.get(ray.remote(ImportUtils).remote().run.remote())
diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py
index 976d23eb76..8e3fe4e69a 100644
--- a/trinity/buffer/schema/formatter.py
+++ b/trinity/buffer/schema/formatter.py
@@ -9,7 +9,7 @@
 from trinity.common.experience import Experience
 from trinity.common.models.utils import get_action_mask_method
 from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.workflows import WORKFLOWS, Task
+from trinity.common.workflows.workflow import WORKFLOWS, Task
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry
 
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index 4b333af09c..4823fa1c7b 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -25,19 +25,19 @@ def write(self, data: list) -> None:
 
     async def write_async(self, data):
         if self.wrap_in_ray:
-            ray.get(self.db_wrapper.write.remote(data))
+            await self.db_wrapper.write.remote(data)
         else:
             self.db_wrapper.write(data)
 
     async def acquire(self) -> int:
         if self.wrap_in_ray:
-            return ray.get(self.db_wrapper.acquire.remote())
+            return await self.db_wrapper.acquire.remote()
         else:
             return 0
 
     async def release(self) -> int:
         if self.wrap_in_ray:
-            return ray.get(self.db_wrapper.release.remote())
+            return await self.db_wrapper.release.remote()
         else:
             self.db_wrapper.release()
             return 0
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 468ab2df53..e7572c2ca0 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -238,7 +238,8 @@ def studio(port: int = 8501):
 def debug(
     config_path: str,
     module: str,
-    output_file: str = "debug_workflow_runner.html",
+    output_dir: str = "debug_output",
+    enable_viztracer: bool = False,
     plugin_dir: str = None,
 ):
     """Debug a module."""
@@ -247,6 +248,7 @@ def debug(
     load_plugins()
     config = load_config(config_path)
     config.check_and_update()
+    sys.path.insert(0, os.getcwd())
     config.ray_namespace = DEBUG_NAMESPACE
     ray.init(
         namespace=config.ray_namespace,
@@ -261,7 +263,7 @@ def debug(
     elif module == "workflow":
         from trinity.explorer.workflow_runner import DebugWorkflowRunner
 
-        runner = DebugWorkflowRunner(config, output_file)
+        runner = DebugWorkflowRunner(config, output_dir, enable_viztracer)
         asyncio.run(runner.debug())
     else:
         raise ValueError(
@@ -308,11 +310,22 @@ def main() -> None:
         default=None,
         help="Path to the directory containing plugin modules.",
     )
+    debug_parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="debug_output",
+        help="The output directory for debug files.",
+    )
+    debug_parser.add_argument(
+        "--enable-profiling",
+        action="store_true",
+        help="Whether to use viztracer for workflow profiling.",
+    )
     debug_parser.add_argument(
         "--output-file",
         type=str,
-        default="debug_workflow_runner.html",
-        help="The output file for viztracer.",
+        default=None,
+        help="[DEPRECATED] Please use --output-dir instead.",
     )
 
     args = parser.parse_args()
@@ -322,7 +335,7 @@ def main() -> None:
     elif args.command == "studio":
         studio(args.port)
     elif args.command == "debug":
-        debug(args.config, args.module, args.output_file, args.plugin_dir)
+        debug(args.config, args.module, args.output_dir, args.enable_profiling, args.plugin_dir)
 
 
 if __name__ == "__main__":
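The `sql_writer.py` hunk above swaps `ray.get(...)` for `await` on the actor call's `ObjectRef`, so the async writer no longer blocks the event loop while the database actor works. A self-contained sketch of that pattern, using a hypothetical `Counter` actor in place of the real `db_wrapper`:

```python
import asyncio

import ray


@ray.remote
class Counter:
    """Hypothetical actor standing in for the SQL db_wrapper."""

    def __init__(self):
        self.value = 0

    def incr(self) -> int:
        self.value += 1
        return self.value


async def main():
    counter = Counter.remote()
    # ray.get() would block the whole thread; awaiting the ObjectRef instead
    # lets other coroutines keep running while the actor call is in flight.
    print(await counter.incr.remote())


ray.init()
asyncio.run(main())
ray.shutdown()
```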
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index cb421c2fce..cba0aa96c8 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -359,7 +359,7 @@ async def sync_weight(self) -> None:
 
     async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
         for step in range(start_step, end_step + 1):
-            self.logger.info(f"Log metrics of step {step}")
+            self.logger.info(f"Waiting for step {step}")
            await self._finish_explore_step(step=step, model_version=model_version)
            await self._finish_eval_step(step=step)
 
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 8917274dd5..de64865fee 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -1,14 +1,15 @@
 # -*- coding: utf-8 -*-
 """The Workflow Runner Module."""
 import asyncio
+import os
 import time
 import traceback
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-from trinity.buffer import get_buffer_reader
-from trinity.common.config import Config
+from trinity.buffer import get_buffer_reader, get_buffer_writer
+from trinity.common.config import Config, ExperienceBufferConfig
 from trinity.common.experience import Experience
 from trinity.common.models import get_debug_inference_model
 from trinity.common.models.model import InferenceModel, ModelWrapper
@@ -220,23 +221,53 @@ class DebugWorkflowRunner(WorkflowRunner):
     def __init__(
         self,
         config: Config,
-        output_file: str,
+        output_dir: str = "debug_output",
+        enable_profiling: bool = False,
     ) -> None:
         model, auxiliary_models = get_debug_inference_model(config)
         super().__init__(config, model, auxiliary_models, 0)
         self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
-        self.output_file = output_file
+        self.output_dir = output_dir
+        self.enable_profiling = enable_profiling
+        # if output dir is not empty, change to a new dir with datetime suffix
+        if os.path.isdir(self.output_dir) and os.listdir(self.output_dir):
+            suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
+            new_output_dir = f"{self.output_dir}_{suffix}"
+            self.output_dir = new_output_dir
+        self.logger.info(f"Debug output directory: {self.output_dir}")
+        os.makedirs(self.output_dir, exist_ok=True)
+        self.output_profiling_file = os.path.join(
+            self.output_dir,
+            "profiling.html",
+        )
+        self.output_sqlite_file = "sqlite:///" + os.path.join(
+            self.output_dir,
+            "experiences.db",
+        )
+        self.sqlite_writer = get_buffer_writer(
+            ExperienceBufferConfig(
+                name="debug_buffer",
+                schema_type="experience",
+                path=self.output_sqlite_file,
+                storage_type="sql",
+                batch_size=1,
+            )
+        )
 
     async def debug(self) -> None:
         """Run the debug workflow."""
-        from viztracer import VizTracer
-
         await self.prepare()
         tasks = await self.taskset.read_async(batch_size=1)
         task = tasks[0]
-        self.logger.info(f"Read task: {task.task_id}, repeat_times: {task.repeat_times}")
-        with VizTracer(output_file=self.output_file):
-            status, exps = await self.run_task(task, task.repeat_times, 0)
+        self.logger.info(f"Start debugging task:\n{task.raw_task}")
+        if not self.enable_profiling:
+            status, exps = await self.run_task(task, 1, 0)
+        else:
+            from viztracer import VizTracer
+
+            with VizTracer(output_file=self.output_profiling_file):
+                status, exps = await self.run_task(task, 1, 0)
+        await self.sqlite_writer.write_async(exps)
         if status.ok:
             print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
             for exp in exps:
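In the `DebugWorkflowRunner` changes above, `viztracer` is imported lazily and the task run is only wrapped in a `VizTracer` context when profiling is requested. A roughly equivalent standalone pattern (a sketch, not the runner's code; `run_with_optional_profiling` is a hypothetical helper) uses `contextlib.nullcontext` so the call site stays a single `with` block:

```python
import contextlib


def run_with_optional_profiling(fn, enable_profiling: bool, output_file: str):
    """Hypothetical helper: run fn(), optionally under a viztracer trace."""
    if enable_profiling:
        # Imported lazily so viztracer is only required when profiling is on.
        from viztracer import VizTracer

        tracer = VizTracer(output_file=output_file)
    else:
        tracer = contextlib.nullcontext()
    with tracer:
        return fn()
```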
diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py
index e5f6806378..d7c0858c8f 100644
--- a/trinity/utils/registry.py
+++ b/trinity/utils/registry.py
@@ -1,8 +1,9 @@
+import traceback
 from typing import Any, Type
 
+from trinity.utils.log import get_logger
+
 
-# TODO: support lazy load
-# e.g. @MODULES.register_module("name", lazy=True)
 class Registry(object):
     """A class for registry."""
 
@@ -13,6 +14,7 @@ def __init__(self, name: str):
         """
         self._name = name
         self._modules = {}
+        self.logger = get_logger()
 
     @property
     def name(self) -> str:
@@ -45,7 +47,21 @@ def get(self, module_key) -> Any:
         Returns:
             `Any`: the module object
         """
-        return self._modules.get(module_key, None)
+        module = self._modules.get(module_key, None)
+        if module is None:
+            # try to dynamically import it
+            if isinstance(module_key, str) and "." in module_key:
+                module_path, class_name = module_key.rsplit(".", 1)
+                try:
+                    module = self._dynamic_import(module_path, class_name)
+                except Exception:
+                    self.logger.error(
+                        f"Failed to dynamically import {class_name} from {module_path}:\n"
+                        + traceback.format_exc()
+                    )
+                    raise ImportError(f"Cannot dynamically import {class_name} from {module_path}")
+                self._register_module(module_name=module_key, module_cls=module)
+        return module
 
     def _register_module(self, module_name=None, module_cls=None, force=False):
         """
@@ -56,6 +72,10 @@ def _register_module(self, module_name=None, module_cls=None, force=False):
             module_name = module_cls.__name__
 
         if module_name in self._modules and not force:
+            self.logger.warning(
+                f"{module_name} is already registered in {self._name}, "
+                f"if you want to override it, please set force=True."
+            )
             raise KeyError(f"{module_name} is already registered in {self._name}")
 
         self._modules[module_name] = module_cls
@@ -111,3 +131,20 @@ def _register(module_cls):
             return module_cls
 
         return _register
+
+    def _dynamic_import(self, module_path: str, class_name: str) -> Type:
+        """
+        Dynamically import a module class object from the specified module path.
+
+        Args:
+            module_path (`str`): The module path. For example, "my_package.my_module".
+            class_name (`str`): The class name. For example, "MyWorkflow".
+
+        Returns:
+            `Type`: The imported module class object.
+        """
+        import importlib
+
+        module = importlib.import_module(module_path)
+        module_cls = getattr(module, class_name)
+        return module_cls
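With the registry change above, `Registry.get` falls back to importing a dotted path when the key is not already registered, which is what `tests/utils/registry_test.py` exercises. A minimal usage sketch, assuming a hypothetical `my_project.workflows.MyCustomWorkflow` that is importable (for example via `--plugin-dir` or because the current working directory is on `sys.path`):

```python
from trinity.common.workflows import WORKFLOWS, Workflow

# "my_project.workflows.MyCustomWorkflow" is a made-up dotted path; Registry.get
# imports the module, fetches the class, and caches it under this key.
workflow_cls = WORKFLOWS.get("my_project.workflows.MyCustomWorkflow")
assert issubclass(workflow_cls, Workflow)
```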