Commit 54c2d72: Enhance debug mode (#421)
1 parent b2fd301

12 files changed: +240 additions, -55 deletions

docs/sphinx_doc/source/tutorial/develop_workflow.md
Lines changed: 4 additions & 3 deletions

@@ -513,13 +513,14 @@ Here, `<config_file_path>` is the path to a YAML configuration file, which shoul
 Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:
 
 ```bash
-trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
 ```
 
 - `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
-- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
+- `<output_dir>`: Directory in which to save the debug output. If not specified, the output is saved to the `debug_output` directory in the current working directory.
 - `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
+- `--enable-profiling` (optional): Enable performance profiling using [viztracer](https://github.com/gaogaotiantian/viztracer).
 
-During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.
+During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection, and the output experiences will be saved to the `<output_dir>/experiences.db` file.
 
 When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.

docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
Lines changed: 6 additions & 5 deletions

@@ -509,13 +509,14 @@ trinity debug --config <config_file_path> --module inference_model
 Once started, the model keeps running and waits for debug instructions; it does not exit automatically. You can then run the following command in another terminal to debug your workflow:
 
 ```bash
-trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
 ```
 
-- `config_file_path`: Path to the YAML configuration file, usually the same as the one used to start the inference model.
-- `output_file_path`: Output path for the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow run and saves the results as an HTML file for viewing in a browser.
-- `plugin_dir` (optional): Path to the plugin directory. If modules such as your workflow or reward functions are not built into Trinity-RFT, you can use this parameter to load custom modules.
+- `<config_file_path>`: Path to the YAML configuration file, usually the same as the one used to start the inference model.
+- `<output_dir>`: Directory in which the debug output is saved. If not specified, it is saved to the `debug_output` directory under the current working directory.
+- `<plugin_dir>` (optional): Path to the plugin directory. If modules such as your workflow or reward functions are not built into Trinity-RFT, you can use this parameter to load custom modules.
+- `--enable-profiling` (optional): Enable performance profiling; uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow run.
 
-During debugging, the `buffer.explorer_input.taskset` field in the config file is loaded to initialize the task dataset and instance required by the workflow. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value is automatically formatted and printed in the terminal for inspection.
+During debugging, the `buffer.explorer_input.taskset` field in the config file is loaded to initialize the task dataset and instance required by the workflow. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value is automatically formatted and printed in the terminal for inspection, and the produced experiences are saved to the `<output_dir>/experiences.db` database.
 
 When debugging is complete, press `Ctrl+C` in the inference model's terminal to stop the model.

tests/cli/launcher_test.py
Lines changed: 74 additions & 25 deletions

@@ -41,6 +41,7 @@ def setUp(self):
 
     def tearDown(self):
        sys.argv = self._orig_argv
+        shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
 
     @mock.patch("trinity.cli.launcher.serve")
     @mock.patch("trinity.cli.launcher.explore")
@@ -254,31 +255,79 @@ def test_multi_stage_run(
     @mock.patch("trinity.cli.launcher.load_config")
     def test_debug_mode(self, mock_load):
         process = multiprocessing.Process(target=debug_inference_model_process)
-        process.start()
-        time.sleep(15)  # wait for the model to be created
-        for _ in range(10):
-            try:
-                get_debug_inference_model(self.config)
-                break
-            except Exception:
-                time.sleep(3)
-        output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
-        self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
-        mock_load.return_value = self.config
-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(
-                command="debug",
-                config="dummy.yaml",
-                module="workflow",
-                output_file=output_file,
-                plugin_dir="",
-            ),
-        ):
-            launcher.main()
-        process.join(timeout=10)
-        process.terminate()
-        self.assertTrue(os.path.exists(output_file))
+        try:
+            process.start()
+            time.sleep(15)  # wait for the model to be created
+            for _ in range(10):
+                try:
+                    get_debug_inference_model(self.config)
+                    break
+                except Exception:
+                    time.sleep(3)
+            output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
+            output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
+            self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
+            mock_load.return_value = self.config
+            with mock.patch(
+                "argparse.ArgumentParser.parse_args",
+                return_value=mock.Mock(
+                    command="debug",
+                    config="dummy.yaml",
+                    module="workflow",
+                    enable_profiling=True,
+                    output_dir=output_dir,
+                    output_file=output_file,
+                    plugin_dir="",
+                ),
+            ):
+                launcher.main()
+
+            self.assertFalse(os.path.exists(output_file))
+            self.assertTrue(os.path.exists(output_dir))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))
+            # add a dummy file to test overwrite behavior
+            with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
+                f.write("not empty")
+
+            with mock.patch(
+                "argparse.ArgumentParser.parse_args",
+                return_value=mock.Mock(
+                    command="debug",
+                    config="dummy.yaml",
+                    module="workflow",
+                    enable_profiling=False,
+                    output_dir=output_dir,
+                    output_file=output_file,
+                    plugin_dir="",
+                ),
+            ):
+                launcher.main()
+
+            self.assertFalse(os.path.exists(output_file))
+            # test the original files are not overwritten
+            self.assertTrue(os.path.exists(output_dir))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "dummy.txt")))
+            dirs = os.listdir(self.config.checkpoint_job_dir)
+            target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
+            self.assertEqual(len(target_output_dir), 1)
+            self.assertFalse(
+                os.path.exists(
+                    os.path.join(
+                        self.config.checkpoint_job_dir, target_output_dir[0], "profiling.html"
+                    )
+                )
+            )
+            self.assertTrue(
+                os.path.exists(
+                    os.path.join(
+                        self.config.checkpoint_job_dir, target_output_dir[0], "experiences.db"
+                    )
+                )
+            )
+        finally:
+            process.join(timeout=10)
+            process.terminate()
 
 
 def debug_inference_model_process():
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+"""A file contains some dependencies."""
+
+DEPENDENCY_VALUE = 0
+
+
+def dependency_func():
+    return "0"

tests/utils/plugins/main.py
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+from tests.utils.plugins.dependencies import DEPENDENCY_VALUE, dependency_func
+from trinity.common.workflows.workflow import Workflow
+
+
+class MainDummyWorkflow(Workflow):
+    def __init__(self, *, task, model, auxiliary_models=None):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+
+    @property
+    def repeatable(self):
+        return True
+
+    def set_repeat_times(self, repeat_times, run_id_base):
+        pass
+
+    def run(self) -> list:
+        return [DEPENDENCY_VALUE, dependency_func()]

tests/utils/registry_test.py
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+import unittest
+
+import ray
+
+
+class ImportUtils:
+    def run(self):
+        from trinity.common.workflows import WORKFLOWS, Workflow
+
+        workflow_cls = WORKFLOWS.get("tests.utils.plugins.main.MainDummyWorkflow")
+        assert issubclass(workflow_cls, Workflow)
+        workflow = workflow_cls(task=None, model=None)
+        res = workflow.run()
+        assert res[0] == 0
+        assert res[1] == "0"
+
+
+class TestRegistry(unittest.TestCase):
+    def setUp(self):
+        ray.init(ignore_reinit_error=True)
+
+    def tearDown(self):
+        ray.shutdown()
+
+    def test_dynamic_import(self):
+        # test local import
+        ImportUtils().run()
+        # test remote import
+        ray.get(ray.remote(ImportUtils).remote().run.remote())

trinity/buffer/schema/formatter.py
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 from trinity.common.experience import Experience
 from trinity.common.models.utils import get_action_mask_method
 from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.workflows import WORKFLOWS, Task
+from trinity.common.workflows.workflow import WORKFLOWS, Task
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry

trinity/buffer/writer/sql_writer.py
Lines changed: 3 additions & 3 deletions

@@ -25,19 +25,19 @@ def write(self, data: list) -> None:
 
     async def write_async(self, data):
         if self.wrap_in_ray:
-            ray.get(self.db_wrapper.write.remote(data))
+            await self.db_wrapper.write.remote(data)
         else:
             self.db_wrapper.write(data)
 
     async def acquire(self) -> int:
         if self.wrap_in_ray:
-            return ray.get(self.db_wrapper.acquire.remote())
+            return await self.db_wrapper.acquire.remote()
         else:
             return 0
 
     async def release(self) -> int:
         if self.wrap_in_ray:
-            return ray.get(self.db_wrapper.release.remote())
+            return await self.db_wrapper.release.remote()
         else:
             self.db_wrapper.release()
             return 0
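This hunk replaces `ray.get(...)` with `await` on the actor call's ObjectRef. The distinction matters inside `async def`: `ray.get` blocks the entire thread, stalling the event loop these coroutines run on, while `await` suspends only the current coroutine. A stand-in demonstration using plain `asyncio` (no Ray dependency; the function names are illustrative):

```python
import asyncio


async def awaitable_call(delay):
    # Stands in for `await actor.method.remote(...)`: suspends only this
    # coroutine, so other tasks keep running on the event loop.
    await asyncio.sleep(delay)
    return "written"


async def heartbeat(ticks):
    # Counts loop iterations while the "write" is in flight; with a
    # blocking ray.get-style call these ticks could not interleave.
    count = 0
    for _ in range(ticks):
        await asyncio.sleep(0.01)
        count += 1
    return count


async def main():
    # Both tasks make progress concurrently because neither blocks.
    return await asyncio.gather(awaitable_call(0.1), heartbeat(5))


result, ticks = asyncio.run(main())
```

Had `awaitable_call` instead called a blocking `time.sleep`-style function, the heartbeat would be frozen until it returned, which is exactly the hazard the commit removes from `write_async`, `acquire`, and `release`.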

trinity/cli/launcher.py
Lines changed: 18 additions & 5 deletions

@@ -238,7 +238,8 @@ def studio(port: int = 8501):
 def debug(
     config_path: str,
     module: str,
-    output_file: str = "debug_workflow_runner.html",
+    output_dir: str = "debug_output",
+    enable_viztracer: bool = False,
     plugin_dir: str = None,
 ):
     """Debug a module."""
@@ -247,6 +248,7 @@ def debug(
     load_plugins()
     config = load_config(config_path)
     config.check_and_update()
+    sys.path.insert(0, os.getcwd())
     config.ray_namespace = DEBUG_NAMESPACE
     ray.init(
         namespace=config.ray_namespace,
@@ -261,7 +263,7 @@ def debug(
     elif module == "workflow":
         from trinity.explorer.workflow_runner import DebugWorkflowRunner
 
-        runner = DebugWorkflowRunner(config, output_file)
+        runner = DebugWorkflowRunner(config, output_dir, enable_viztracer)
         asyncio.run(runner.debug())
     else:
         raise ValueError(
@@ -308,11 +310,22 @@ def main() -> None:
         default=None,
         help="Path to the directory containing plugin modules.",
     )
+    debug_parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="debug_output",
+        help="The output directory for debug files.",
+    )
+    debug_parser.add_argument(
+        "--enable-profiling",
+        action="store_true",
+        help="Whether to use viztracer for workflow profiling.",
+    )
     debug_parser.add_argument(
         "--output-file",
         type=str,
-        default="debug_workflow_runner.html",
-        help="The output file for viztracer.",
+        default=None,
+        help="[DEPRECATED] Please use --output-dir instead.",
     )
 
     args = parser.parse_args()
@@ -322,7 +335,7 @@ def main() -> None:
     elif args.command == "studio":
         studio(args.port)
     elif args.command == "debug":
-        debug(args.config, args.module, args.output_dir, args.enable_profiling, args.plugin_dir)
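The CLI surface added in the hunks above (a new `--output-dir`, a `store_true` `--enable-profiling`, and `--output-file` kept only as a deprecated default-`None` flag) can be reproduced with a small `argparse` sketch; the parser below mirrors the arguments from this commit but is otherwise standalone:

```python
import argparse

# Build a subcommand parser shaped like the launcher's "debug" command.
parser = argparse.ArgumentParser(prog="trinity")
subparsers = parser.add_subparsers(dest="command")
debug_parser = subparsers.add_parser("debug")
debug_parser.add_argument(
    "--output-dir",
    type=str,
    default="debug_output",
    help="The output directory for debug files.",
)
debug_parser.add_argument(
    "--enable-profiling",
    action="store_true",
    help="Whether to use viztracer for workflow profiling.",
)
debug_parser.add_argument(
    "--output-file",
    type=str,
    default=None,
    help="[DEPRECATED] Please use --output-dir instead.",
)

# Parse a sample invocation; --output-file is accepted but unused.
args = parser.parse_args(["debug", "--output-dir", "out", "--enable-profiling"])
```

Defaulting the deprecated flag to `None` lets the launcher accept old invocations without breaking them while routing all output through `--output-dir`, which is why the updated test asserts that the old `debug.html` path is never created.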

trinity/explorer/explorer.py
Lines changed: 1 addition & 1 deletion

@@ -359,7 +359,7 @@ async def sync_weight(self) -> None:
 
     async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
         for step in range(start_step, end_step + 1):
-            self.logger.info(f"Log metrics of step {step}")
+            self.logger.info(f"Waiting for step {step}")
             await self._finish_explore_step(step=step, model_version=model_version)
             await self._finish_eval_step(step=step)
