Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,14 @@ Here, `<config_file_path>` is the path to a YAML configuration file, which shoul
Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

```bash
trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
```

- `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
- `<output_dir>`: Directory to save the debug output. If not specified, the output will be saved to the `debug_output` in the current working directory.
- `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
- `--enable-profiling` (optional): Enable performance profiling using [viztracer](https://github.com/gaogaotiantian/viztracer).

During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.
During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection and the output experiences will be saved to the `output_dir/experiences.db` directory.

When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
99 changes: 74 additions & 25 deletions tests/cli/launcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def setUp(self):

def tearDown(self):
sys.argv = self._orig_argv
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)

@mock.patch("trinity.cli.launcher.serve")
@mock.patch("trinity.cli.launcher.explore")
Expand Down Expand Up @@ -254,31 +255,79 @@ def test_multi_stage_run(
@mock.patch("trinity.cli.launcher.load_config")
def test_debug_mode(self, mock_load):
process = multiprocessing.Process(target=debug_inference_model_process)
process.start()
time.sleep(15) # wait for the model to be created
for _ in range(10):
try:
get_debug_inference_model(self.config)
break
except Exception:
time.sleep(3)
output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
output_file=output_file,
plugin_dir="",
),
):
launcher.main()
process.join(timeout=10)
process.terminate()
self.assertTrue(os.path.exists(output_file))
try:
process.start()
time.sleep(15) # wait for the model to be created
for _ in range(10):
try:
get_debug_inference_model(self.config)
break
except Exception:
time.sleep(3)
output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=True,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()

self.assertFalse(os.path.exists(output_file))
self.assertTrue(os.path.exists(output_dir))
self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))
# add a dummy file to test overwrite behavior
with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
f.write("not empty")

with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=False,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()

self.assertFalse(os.path.exists(output_file))
# test the original files are not overwritten
self.assertTrue(os.path.exists(output_dir))
self.assertTrue(os.path.exists(os.path.join(output_dir, "dummy.txt")))
dirs = os.listdir(self.config.checkpoint_job_dir)
target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
self.assertEqual(len(target_output_dir), 1)
self.assertFalse(
os.path.exists(
os.path.join(
self.config.checkpoint_job_dir, target_output_dir[0], "profiling.html"
)
)
)
self.assertTrue(
os.path.exists(
os.path.join(
self.config.checkpoint_job_dir, target_output_dir[0], "experiences.db"
)
)
)
finally:
process.join(timeout=10)
process.terminate()


def debug_inference_model_process():
Expand Down
7 changes: 7 additions & 0 deletions tests/utils/plugins/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""A file contains some dependencies."""

DEPENDENCY_VALUE = 0


def dependency_func():
return "0"
17 changes: 17 additions & 0 deletions tests/utils/plugins/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from tests.utils.plugins.dependencies import DEPENDENCY_VALUE, dependency_func
from trinity.common.workflows.workflow import Workflow


class MainDummyWorkflow(Workflow):
def __init__(self, *, task, model, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

@property
def repeatable(self):
return True

def set_repeat_times(self, repeat_times, run_id_base):
pass

def run(self) -> list:
return [DEPENDENCY_VALUE, dependency_func()]
14 changes: 14 additions & 0 deletions tests/utils/registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import unittest

from trinity.common.config import Config
from trinity.common.workflows import WORKFLOWS, Workflow


class TestRegistry(unittest.TestCase):
def test_dynamic_import(self):
workflow_cls = WORKFLOWS.get("tests.utils.plugins.main.MainDummyWorkflow")
self.assertTrue(issubclass(workflow_cls, Workflow))
workflow = workflow_cls(task=None, model=None)
res = workflow.run()
self.assertEqual(res[0], 0)
self.assertEqual(res[1], "0")
2 changes: 1 addition & 1 deletion trinity/buffer/schema/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from trinity.common.experience import Experience
from trinity.common.models.utils import get_action_mask_method
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.common.workflows.workflow import WORKFLOWS, Task
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

Expand Down
6 changes: 3 additions & 3 deletions trinity/buffer/writer/sql_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def write(self, data: list) -> None:

async def write_async(self, data):
if self.wrap_in_ray:
ray.get(self.db_wrapper.write.remote(data))
return await self.db_wrapper.write.remote(data)
else:
self.db_wrapper.write(data)

async def acquire(self) -> int:
if self.wrap_in_ray:
return ray.get(self.db_wrapper.acquire.remote())
return await self.db_wrapper.acquire.remote()
else:
return 0

async def release(self) -> int:
if self.wrap_in_ray:
return ray.get(self.db_wrapper.release.remote())
return await self.db_wrapper.release.remote()
else:
self.db_wrapper.release()
return 0
22 changes: 17 additions & 5 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def studio(port: int = 8501):
def debug(
config_path: str,
module: str,
output_file: str = "debug_workflow_runner.html",
output_dir: str = "debug_output",
enable_viztracer: bool = False,
plugin_dir: str = None,
):
"""Debug a module."""
Expand All @@ -261,7 +262,7 @@ def debug(
elif module == "workflow":
from trinity.explorer.workflow_runner import DebugWorkflowRunner

runner = DebugWorkflowRunner(config, output_file)
runner = DebugWorkflowRunner(config, output_dir, enable_viztracer)
asyncio.run(runner.debug())
else:
raise ValueError(
Expand Down Expand Up @@ -308,11 +309,22 @@ def main() -> None:
default=None,
help="Path to the directory containing plugin modules.",
)
debug_parser.add_argument(
"--output-dir",
type=str,
default="debug_output",
help="The output directory for debug files.",
)
debug_parser.add_argument(
"--enable-profiling",
action="store_true",
help="Whether to use viztracer for workflow profiling.",
)
debug_parser.add_argument(
"--output-file",
type=str,
default="debug_workflow_runner.html",
help="The output file for viztracer.",
default=None,
help="[DEPRECATED] Please use --output-dir instead.",
)

args = parser.parse_args()
Expand All @@ -322,7 +334,7 @@ def main() -> None:
elif args.command == "studio":
studio(args.port)
elif args.command == "debug":
debug(args.config, args.module, args.output_file, args.plugin_dir)
debug(args.config, args.module, args.output_dir, args.enable_profiling, args.plugin_dir)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def sync_weight(self) -> None:

async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
for step in range(start_step, end_step + 1):
self.logger.info(f"Log metrics of step {step}")
self.logger.info(f"Waiting for step {step}")
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)

Expand Down
49 changes: 40 additions & 9 deletions trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.buffer import get_buffer_reader, get_buffer_writer
from trinity.common.config import Config, ExperienceBufferConfig
from trinity.common.experience import Experience
from trinity.common.models import get_debug_inference_model
from trinity.common.models.model import InferenceModel, ModelWrapper
Expand Down Expand Up @@ -220,23 +221,53 @@ class DebugWorkflowRunner(WorkflowRunner):
def __init__(
self,
config: Config,
output_file: str,
output_dir: str = "debug_output",
enable_profiling: bool = False,
) -> None:
model, auxiliary_models = get_debug_inference_model(config)
super().__init__(config, model, auxiliary_models, 0)
self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
self.output_file = output_file
self.output_dir = output_dir
self.enable_profiling = enable_profiling
# if output dir is not empty, change to a new dir with datetime suffix
if os.path.isdir(self.output_dir) and os.listdir(self.output_dir):
suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
new_output_dir = f"{self.output_dir}_{suffix}"
self.output_dir = new_output_dir
self.logger.info(f"Debug output directory: {self.output_dir}")
os.makedirs(self.output_dir, exist_ok=True)
self.output_profiling_file = os.path.join(
self.output_dir,
"profiling.html",
)
self.output_sqlite_file = "sqlite:///" + os.path.join(
self.output_dir,
"experiences.db",
)
self.sqlite_writer = get_buffer_writer(
ExperienceBufferConfig(
name="debug_buffer",
schema_type="experience",
path=self.output_sqlite_file,
storage_type="sql",
batch_size=1,
)
)

async def debug(self) -> None:
"""Run the debug workflow."""
from viztracer import VizTracer

await self.prepare()
tasks = await self.taskset.read_async(batch_size=1)
task = tasks[0]
self.logger.info(f"Read task: {task.task_id}, repeat_times: {task.repeat_times}")
with VizTracer(output_file=self.output_file):
status, exps = await self.run_task(task, task.repeat_times, 0)
self.logger.info(f"Start debugging task:\n{task.raw_task}")
if not self.enable_profiling:
status, exps = await self.run_task(task, 1, 0)
else:
from viztracer import VizTracer

with VizTracer(output_file=self.output_profiling_file):
status, exps = await self.run_task(task, 1, 0)
await self.sqlite_writer.write_async(exps)
if status.ok:
print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
for exp in exps:
Expand Down
Loading
Loading