Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,14 @@ Here, `<config_file_path>` is the path to a YAML configuration file, which shoul
Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

```bash
trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
```

- `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
- `<output_dir>`: Directory to save the debug output. If not specified, the output will be saved to the `debug_output` in the current working directory.
- `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
- `--enable-profiling` (optional): Enable performance profiling using [viztracer](https://github.com/gaogaotiantian/viztracer).

During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.
During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection and the output experiences will be saved to the `<output_dir>/experiences.db` file.

When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
11 changes: 6 additions & 5 deletions docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,14 @@ trinity debug --config <config_file_path> --module inference_model
模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试:

```bash
trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
```

- `config_file_path`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
- `output_file_path`:性能分析结果输出路径。调试模式会使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析,并将结果保存为 HTML 文件,便于在浏览器中查看。
- `plugin_dir`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。
- `<config_file_path>`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
- `<output_dir>`:调试输出保存目录。如果未指定,调试输出将保存在当前工作目录下的 `debug_output` 目录中。
- `<plugin_dir>`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。
- `--enable-profiling`(可选):启用性能分析,使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析。

调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端,方便查看运行结果
调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端以供观察和查看,同时产出的 Experience 会保存到 `<output_dir>/experiences.db` 数据库中

调试完成后,可在推理模型终端输入 `Ctrl+C` 以终止模型运行。
99 changes: 74 additions & 25 deletions tests/cli/launcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def setUp(self):

def tearDown(self):
sys.argv = self._orig_argv
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)

@mock.patch("trinity.cli.launcher.serve")
@mock.patch("trinity.cli.launcher.explore")
Expand Down Expand Up @@ -254,31 +255,79 @@ def test_multi_stage_run(
@mock.patch("trinity.cli.launcher.load_config")
def test_debug_mode(self, mock_load):
process = multiprocessing.Process(target=debug_inference_model_process)
process.start()
time.sleep(15) # wait for the model to be created
for _ in range(10):
try:
get_debug_inference_model(self.config)
break
except Exception:
time.sleep(3)
output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
output_file=output_file,
plugin_dir="",
),
):
launcher.main()
process.join(timeout=10)
process.terminate()
self.assertTrue(os.path.exists(output_file))
try:
process.start()
time.sleep(15) # wait for the model to be created
for _ in range(10):
try:
get_debug_inference_model(self.config)
break
except Exception:
time.sleep(3)
output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=True,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()

self.assertFalse(os.path.exists(output_file))
self.assertTrue(os.path.exists(output_dir))
self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))
# add a dummy file to test overwrite behavior
with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
f.write("not empty")

with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=False,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()

self.assertFalse(os.path.exists(output_file))
# test the original files are not overwritten
self.assertTrue(os.path.exists(output_dir))
self.assertTrue(os.path.exists(os.path.join(output_dir, "dummy.txt")))
dirs = os.listdir(self.config.checkpoint_job_dir)
target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
self.assertEqual(len(target_output_dir), 1)
self.assertFalse(
os.path.exists(
os.path.join(
self.config.checkpoint_job_dir, target_output_dir[0], "profiling.html"
)
)
)
self.assertTrue(
os.path.exists(
os.path.join(
self.config.checkpoint_job_dir, target_output_dir[0], "experiences.db"
)
)
)
finally:
process.join(timeout=10)
process.terminate()


def debug_inference_model_process():
Expand Down
7 changes: 7 additions & 0 deletions tests/utils/plugins/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""A file contains some dependencies."""

DEPENDENCY_VALUE = 0


def dependency_func():
return "0"
17 changes: 17 additions & 0 deletions tests/utils/plugins/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from tests.utils.plugins.dependencies import DEPENDENCY_VALUE, dependency_func
from trinity.common.workflows.workflow import Workflow


class MainDummyWorkflow(Workflow):
def __init__(self, *, task, model, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

@property
def repeatable(self):
return True

def set_repeat_times(self, repeat_times, run_id_base):
pass

def run(self) -> list:
return [DEPENDENCY_VALUE, dependency_func()]
13 changes: 13 additions & 0 deletions tests/utils/registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import unittest

from trinity.common.workflows import WORKFLOWS, Workflow


class TestRegistry(unittest.TestCase):
def test_dynamic_import(self):
workflow_cls = WORKFLOWS.get("tests.utils.plugins.main.MainDummyWorkflow")
self.assertTrue(issubclass(workflow_cls, Workflow))
workflow = workflow_cls(task=None, model=None)
res = workflow.run()
self.assertEqual(res[0], 0)
self.assertEqual(res[1], "0")
2 changes: 1 addition & 1 deletion trinity/buffer/schema/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from trinity.common.experience import Experience
from trinity.common.models.utils import get_action_mask_method
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.common.workflows.workflow import WORKFLOWS, Task
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

Expand Down
6 changes: 3 additions & 3 deletions trinity/buffer/writer/sql_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def write(self, data: list) -> None:

async def write_async(self, data):
if self.wrap_in_ray:
ray.get(self.db_wrapper.write.remote(data))
await self.db_wrapper.write.remote(data)
else:
self.db_wrapper.write(data)

async def acquire(self) -> int:
if self.wrap_in_ray:
return ray.get(self.db_wrapper.acquire.remote())
return await self.db_wrapper.acquire.remote()
else:
return 0

async def release(self) -> int:
if self.wrap_in_ray:
return ray.get(self.db_wrapper.release.remote())
return await self.db_wrapper.release.remote()
else:
self.db_wrapper.release()
return 0
22 changes: 17 additions & 5 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def studio(port: int = 8501):
def debug(
config_path: str,
module: str,
output_file: str = "debug_workflow_runner.html",
output_dir: str = "debug_output",
enable_viztracer: bool = False,
plugin_dir: str = None,
):
"""Debug a module."""
Expand All @@ -261,7 +262,7 @@ def debug(
elif module == "workflow":
from trinity.explorer.workflow_runner import DebugWorkflowRunner

runner = DebugWorkflowRunner(config, output_file)
runner = DebugWorkflowRunner(config, output_dir, enable_viztracer)
asyncio.run(runner.debug())
else:
raise ValueError(
Expand Down Expand Up @@ -308,11 +309,22 @@ def main() -> None:
default=None,
help="Path to the directory containing plugin modules.",
)
debug_parser.add_argument(
"--output-dir",
type=str,
default="debug_output",
help="The output directory for debug files.",
)
debug_parser.add_argument(
"--enable-profiling",
action="store_true",
help="Whether to use viztracer for workflow profiling.",
)
debug_parser.add_argument(
"--output-file",
type=str,
default="debug_workflow_runner.html",
help="The output file for viztracer.",
default=None,
help="[DEPRECATED] Please use --output-dir instead.",
)

args = parser.parse_args()
Expand All @@ -322,7 +334,7 @@ def main() -> None:
elif args.command == "studio":
studio(args.port)
elif args.command == "debug":
debug(args.config, args.module, args.output_file, args.plugin_dir)
debug(args.config, args.module, args.output_dir, args.enable_profiling, args.plugin_dir)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def sync_weight(self) -> None:

async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
for step in range(start_step, end_step + 1):
self.logger.info(f"Log metrics of step {step}")
self.logger.info(f"Waiting for step {step}")
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)

Expand Down
49 changes: 40 additions & 9 deletions trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.buffer import get_buffer_reader, get_buffer_writer
from trinity.common.config import Config, ExperienceBufferConfig
from trinity.common.experience import Experience
from trinity.common.models import get_debug_inference_model
from trinity.common.models.model import InferenceModel, ModelWrapper
Expand Down Expand Up @@ -220,23 +221,53 @@ class DebugWorkflowRunner(WorkflowRunner):
def __init__(
self,
config: Config,
output_file: str,
output_dir: str = "debug_output",
enable_profiling: bool = False,
) -> None:
model, auxiliary_models = get_debug_inference_model(config)
super().__init__(config, model, auxiliary_models, 0)
self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
self.output_file = output_file
self.output_dir = output_dir
self.enable_profiling = enable_profiling
# if output dir is not empty, change to a new dir with datetime suffix
if os.path.isdir(self.output_dir) and os.listdir(self.output_dir):
suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
new_output_dir = f"{self.output_dir}_{suffix}"
self.output_dir = new_output_dir
self.logger.info(f"Debug output directory: {self.output_dir}")
os.makedirs(self.output_dir, exist_ok=True)
self.output_profiling_file = os.path.join(
self.output_dir,
"profiling.html",
)
self.output_sqlite_file = "sqlite:///" + os.path.join(
self.output_dir,
"experiences.db",
)
self.sqlite_writer = get_buffer_writer(
ExperienceBufferConfig(
name="debug_buffer",
schema_type="experience",
path=self.output_sqlite_file,
storage_type="sql",
batch_size=1,
)
)

async def debug(self) -> None:
"""Run the debug workflow."""
from viztracer import VizTracer

await self.prepare()
tasks = await self.taskset.read_async(batch_size=1)
task = tasks[0]
self.logger.info(f"Read task: {task.task_id}, repeat_times: {task.repeat_times}")
with VizTracer(output_file=self.output_file):
status, exps = await self.run_task(task, task.repeat_times, 0)
self.logger.info(f"Start debugging task:\n{task.raw_task}")
if not self.enable_profiling:
status, exps = await self.run_task(task, 1, 0)
else:
from viztracer import VizTracer

with VizTracer(output_file=self.output_profiling_file):
status, exps = await self.run_task(task, 1, 0)
await self.sqlite_writer.write_async(exps)
if status.ok:
print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
for exp in exps:
Expand Down
Loading