diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md
index 355e86ce0b..68a6ab77a2 100644
--- a/docs/sphinx_doc/source/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -499,7 +499,8 @@ During Workflow development, repeatedly launching the full training process for
 ```{mermaid}
 flowchart LR
     A[Start Inference Model] --> B[Debug Workflow]
-    B --> B
+    B --> C[Check Experiences]
+    C --> B
 ```
 
 To start the inference model, use the following command:
@@ -513,14 +514,22 @@ Here, `<config_path>` is the path to a YAML configuration file, which shoul
 Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:
 
 ```bash
-trinity debug --config <config_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
+trinity debug --config <config_path> --module workflow --output-dir <output_dir> [--plugin-dir <plugin_dir>] [--enable-profiling] [--disable-overwrite]
 ```
 
 - `<config_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
 - `<output_dir>`: Directory to save the debug output. If not specified, the output will be saved to the `debug_output` in the current working directory.
 - `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
 - `--enable-profiling` (optional): Enable performance profiling using [viztracer](https://github.com/gaogaotiantian/viztracer).
+- `--disable-overwrite` (optional): Do not overwrite a non-empty output directory. If the specified directory already contains files, the debug run automatically switches to a new directory with a timestamp suffix (e.g., `debug_output_20251203211200`) so that existing data is preserved.
 
-During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection and the output experiences will be saved to the `<output_dir>/experiences.db` file.
+During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the experiences returned by the workflow will be written to the `experiences.db` file in the specified output directory, and the metrics recorded during the run will be printed in the terminal for easy inspection. To browse the generated experiences visually, launch the Experience Viewer:
+
+```bash
+trinity debug --config <config_path> --module viewer --output-dir <output_dir> --port 8502
+```
+
+This command starts the Experience Viewer at `http://localhost:8502`, where you can inspect the generated experiences in a user-friendly interface.
+Note that the viewer reads experiences from the `experiences.db` file in the specified output directory, so make sure you have successfully run the workflow debug command beforehand and use the same output directory.
 
 When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
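For a quick sanity check without the viewer, the debug output can also be opened directly: `experiences.db` is a plain SQLite file. Below is a minimal sketch (not part of Trinity-RFT), assuming the default `debug_output` directory and the `debug_buffer` table name that the CLI passes to the viewer; if the table name differs, the `sqlite_master` query will show what is actually there.

```python
# Minimal sketch: peek at the debug output with the standard library only.
# Assumes the default output directory and the "debug_buffer" table name
# that `trinity debug` wires into the Experience Viewer.
import sqlite3

con = sqlite3.connect("debug_output/experiences.db")
tables = con.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
print("tables:", tables)  # expect the experience table, e.g. ('debug_buffer',)
print("rows:", con.execute("SELECT COUNT(*) FROM debug_buffer").fetchone()[0])
con.close()
```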
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
index 364c425b09..8e27d69e51 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -495,7 +495,8 @@ class MyWorkflow(Workflow):
 ```{mermaid}
 flowchart LR
     A[启动推理模型] --> B[调试 Workflow]
-    B --> B
+    B --> C[检查 Experience]
+    C --> B
 ```
 
 启动推理模型的命令如下:
@@ -509,14 +510,21 @@ trinity debug --config <config_path> --module inference_model
 模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试:
 
 ```bash
-trinity debug --config <config_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
+trinity debug --config <config_path> --module workflow --output-dir <output_dir> [--plugin-dir <plugin_dir>] [--enable-profiling] [--disable-overwrite]
 ```
 
 - `<config_path>`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
 - `<output_dir>`:调试输出保存目录。如果未指定,调试输出将保存在当前工作目录下的 `debug_output` 目录中。
 - `<plugin_dir>`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。
 - `--enable-profiling`(可选):启用性能分析,使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析。
+- `--disable-overwrite`(可选):禁止覆盖输出目录。如果指定的目录非空,程序将自动改用一个带时间戳后缀的新目录(例如 `debug_output_20251203211200`),以避免覆盖现有数据。
 
-调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端以供观察和查看,同时产出的 Experience 会保存到 `<output_dir>/experiences.db` 数据库中。
+调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 返回的 Experience 会被写入指定输出目录下的 `experiences.db` 文件中,而运行过程中记录的指标会打印在终端以便检查。如需以可视化方式查看生成的 Experience,可启动 Experience Viewer:
+
+```bash
+trinity debug --config <config_path> --module viewer --output-dir <output_dir> --port 8502
+```
+
+该命令会在 `http://localhost:8502` 启动 Experience Viewer,你可以在用户友好的界面中检查调试过程中生成的 Experience。需注意,Viewer 会从指定输出目录下的 `experiences.db` 文件中读取 Experience,因此请确保此前已成功运行过 Workflow 调试命令,并将 `<output_dir>` 替换为相同的实际输出目录。
 
 调试完成后,可在推理模型终端输入 `Ctrl+C` 以终止模型运行。
diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py
index c28852e621..17ef5f8f5c 100644
--- a/tests/cli/launcher_test.py
+++ b/tests/cli/launcher_test.py
@@ -275,6 +275,7 @@ def test_debug_mode(self, mock_load):
                 config="dummy.yaml",
                 module="workflow",
                 enable_profiling=True,
+                disable_overwrite=False,
                 output_dir=output_dir,
                 output_file=output_file,
                 plugin_dir="",
@@ -297,6 +298,26 @@ def test_debug_mode(self, mock_load):
                 config="dummy.yaml",
                 module="workflow",
                 enable_profiling=False,
+                disable_overwrite=False,
+                output_dir=output_dir,
+                output_file=output_file,
+                plugin_dir="",
+            ),
+        ):
+            launcher.main()
+
+        dirs = os.listdir(self.config.checkpoint_job_dir)
+        target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
+        self.assertEqual(len(target_output_dir), 0)
+
+        with mock.patch(
+            "argparse.ArgumentParser.parse_args",
+            return_value=mock.Mock(
+                command="debug",
+                config="dummy.yaml",
+                module="workflow",
+                enable_profiling=False,
+                disable_overwrite=True,
                 output_dir=output_dir,
                 output_file=output_file,
                 plugin_dir="",
diff --git a/trinity/buffer/viewer.py b/trinity/buffer/viewer.py
new file mode 100644
index 0000000000..32b647ddd3
--- /dev/null
+++ b/trinity/buffer/viewer.py
@@ -0,0 +1,371 @@
+import argparse
+from typing import List
+
+import streamlit as st
+import streamlit.components.v1 as components
+from sqlalchemy.orm import sessionmaker
+from transformers import AutoTokenizer
+
+from trinity.buffer.schema import init_engine
+from trinity.common.config import StorageConfig
+from trinity.common.experience import Experience
+from trinity.utils.log import get_logger
+
+
+class SQLExperienceViewer:
+    def __init__(self, config: StorageConfig) -> None:
+        self.logger = get_logger(f"sql_{config.name}", in_ray_actor=True)
+        if not config.path:
+            raise ValueError("`path` is required for SQL storage type.")
+        self.engine, self.table_model_cls = init_engine(
+            db_url=config.path,
+            table_name=config.name,
+            schema_type=config.schema_type,
+        )
+        self.session = sessionmaker(bind=self.engine)
+
+    def get_experiences(self, offset: int, limit: int = 10) -> List[Experience]:
+        self.logger.info(f"Viewing experiences from offset {offset} with limit {limit}.")
+        with self.session() as session:
+            query = session.query(self.table_model_cls).offset(offset).limit(limit)
+            results = query.all()
+            exps = [self.table_model_cls.to_experience(row) for row in results]
+            return exps
+
+    def total_experiences(self) -> int:
+        with self.session() as session:
+            count = session.query(self.table_model_cls).count()
+            return count
+
+
+st.set_page_config(page_title="Trinity-RFT Experience Visualizer", layout="wide")
+
+
+def get_color_for_action_mask(action_mask_value: int) -> str:
+    """Return color based on action_mask value"""
+    if action_mask_value == 1:
+        return "#c8e6c9"
+    else:
+        return "#ffcdd2"
+
+
+def render_experience(exp: Experience, exp_index: int, tokenizer):
+    """Render a single experience sequence in Streamlit."""
+    token_ids = exp.tokens
+    logprobs = exp.logprobs
+    action_mask = exp.action_mask
+
+    prompt_length = exp.prompt_length
+
+    prompt_token_ids = token_ids[:prompt_length]  # type: ignore [index]
+    response_token_ids = token_ids[prompt_length:]  # type: ignore [index]
+
+    # Decode tokens
+    prompt_text = (
+        tokenizer.decode(prompt_token_ids)
+        if hasattr(tokenizer, "decode")
+        else "".join([str(tid) for tid in prompt_token_ids])
+    )
+    response_text = (
+        tokenizer.decode(response_token_ids)
+        if hasattr(tokenizer, "decode")
+        else "".join([str(tid) for tid in response_token_ids])
+    )
+
+    # Get each response token text
+    response_tokens = []
+    for tid in response_token_ids:
+        if hasattr(tokenizer, "decode"):
+            token_text = tokenizer.decode([tid])
+        else:
+            token_text = f"[{tid}]"
+        response_tokens.append(token_text)
+
+    # HTML escape function
+    def html_escape(text):
+        return (
+            text.replace("&", "&amp;")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace('"', "&quot;")
+            .replace("'", "&#39;")
+        )
+
+    # Build full HTML (with CSS)
+    html = f"""
+    <html>
+    <head>
+    <style>
+        body {{ font-family: monospace; }}
+        .header {{ font-size: 18px; font-weight: bold; margin-bottom: 8px; }}
+        .label {{ font-weight: bold; margin-top: 12px; }}
+        .text-box {{ background: #f5f5f5; border-radius: 4px; padding: 8px; white-space: pre-wrap; }}
+        .token-container {{ display: flex; flex-wrap: wrap; gap: 4px; margin-top: 8px; }}
+        .token-box {{ border-radius: 4px; padding: 2px 4px; text-align: center; transition: transform 0.1s; }}
+        .token-box:hover {{ transform: scale(1.5); }}
+        .token-logprob {{ font-size: 10px; color: #555; }}
+    </style>
+    </head>
+    <body>
+    <div class="header">Experience {exp_index + 1}</div>
+
+    <div class="label">📝 Prompt:</div>
+    <div class="text-box">{html_escape(prompt_text)}</div>
+
+    <div class="label">💬 Response:</div>
+    <div class="text-box">{html_escape(response_text)}</div>
+
+    <hr/>
+
+    <div class="label">🔍 Response Tokens Detail:</div>
+    <div class="token-container">
+ """ + + # Add each response token + for i, (token_text, logprob, mask) in enumerate(zip(response_tokens, logprobs, action_mask)): # type: ignore [arg-type] + bg_color = get_color_for_action_mask(mask) + + # Handle special character display + token_display = token_text.replace(" ", "␣").replace("\n", "↵").replace("\t", "⇥") + token_display = html_escape(token_display) + + html += f""" +
+
{token_display}
+
{logprob:.3f}
+
+ """ + + html += """ +
+
+ + + """ + + # Use components.html instead of st.markdown + components.html(html, height=1200, scrolling=True) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Experience Visualizer") + parser.add_argument( + "--db-url", + type=str, + help="Path to the experience database.", + ) + parser.add_argument( + "--table", + type=str, + help="Name of the experience table.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Path to the tokenizer.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Initialize SQLExperienceViewer + config = StorageConfig( + name=args.table, + path=args.db_url, + schema_type="experience", + storage_type="sql", + ) + viewer = SQLExperienceViewer(config) + + st.title("🎯 Trinity-RFT Experience Visualizer") + if "page" not in st.session_state: + st.session_state.page = 1 + + # Add instructions + with st.expander("ℹ️ Instructions"): + st.markdown( + """ + - **Green background**: action_mask = 1 + - **Red background**: action_mask = 0 + - **Top**: Token text (special characters: space=␣, newline=↵, tab=⇥) + - **Bottom**: Logprob value of the token + - Hover over token to zoom in + """ + ) + + # Get total sequence number + total_seq_num = viewer.total_experiences() + + # Sidebar configuration + st.sidebar.header("⚙️ Settings") + + # Pagination settings + experiences_per_page = st.sidebar.slider( + "Experiences per page", min_value=1, max_value=20, value=5 + ) + + # Calculate total pages + total_pages = (total_seq_num + experiences_per_page - 1) // experiences_per_page + + # Page selection (sidebar) + current_page = st.sidebar.number_input( + "Select page", + min_value=1, + max_value=max(1, total_pages), + step=1, + value=st.session_state.page, + ) + if current_page != st.session_state.page: + st.session_state.page = current_page + st.rerun() + + # Show statistics + st.sidebar.markdown("---") + st.sidebar.metric("Total experiences", total_seq_num) + st.sidebar.metric("Total pages", total_pages) + st.sidebar.metric("Current page", f"{st.session_state.page}/{total_pages}") + + # Calculate offset + offset = (st.session_state.page - 1) * experiences_per_page + + # Get experiences for current page + experiences = viewer.get_experiences(offset, experiences_per_page) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + + # Render experiences + if experiences: + for i, exp in enumerate(experiences): + render_experience(exp, offset + i, tokenizer) + else: + st.warning("No experience data found") + + # Pagination navigation + st.markdown("---") + col1, col2, col3 = st.columns([1, 2, 1]) + + with col1: + if st.session_state.page > 1: + if st.button("⬅️ Previous Page"): + st.session_state.page = st.session_state.page - 1 + st.rerun() + + with col2: + st.markdown( + f"
Page {st.session_state.page} / {total_pages}
", unsafe_allow_html=True + ) + with col3: + if st.session_state.page < total_pages: + if st.button("Next Page ➡️"): + st.session_state.page = st.session_state.page + 1 + st.rerun() + + +if __name__ == "__main__": + main() diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index e7572c2ca0..dd3181c490 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -239,7 +239,9 @@ def debug( config_path: str, module: str, output_dir: str = "debug_output", - enable_viztracer: bool = False, + disable_overwrite: bool = False, + enable_profiling: bool = False, + port: int = 8502, plugin_dir: str = None, ): """Debug a module.""" @@ -263,8 +265,34 @@ def debug( elif module == "workflow": from trinity.explorer.workflow_runner import DebugWorkflowRunner - runner = DebugWorkflowRunner(config, output_dir, enable_viztracer) + runner = DebugWorkflowRunner(config, output_dir, enable_profiling, disable_overwrite) asyncio.run(runner.debug()) + elif module == "viewer": + from streamlit.web import cli as stcli + + current_dir = Path(__file__).resolve().parent.parent + viewer_path = os.path.join(current_dir, "buffer", "viewer.py") + output_dir_abs = os.path.abspath(output_dir) + if output_dir_abs.endswith("/"): + output_dir_abs = output_dir_abs[:-1] + print(f"sqlite:///{output_dir_abs}/experiences.db") + sys.argv = [ + "streamlit", + "run", + viewer_path, + "--server.port", + str(port), + "--server.fileWatcherType", + "none", + "--", + "--db-url", + f"sqlite:///{output_dir_abs}/experiences.db", + "--table", + "debug_buffer", + "--tokenizer", + config.model.model_path, + ] + sys.exit(stcli.main()) else: raise ValueError( f"Only support 'inference_model' and 'workflow' for debugging, got {module}" @@ -301,8 +329,8 @@ def main() -> None: debug_parser.add_argument( "--module", type=str, - choices=["inference_model", "workflow"], - help="The module to start debugging, only support 'inference_model' and 'workflow' for now.", + choices=["inference_model", "workflow", "viewer"], + help="The module to start debugging, only support 'inference_model', 'workflow' and 'viewer' for now.", ) debug_parser.add_argument( "--plugin-dir", @@ -316,6 +344,9 @@ def main() -> None: default="debug_output", help="The output directory for debug files.", ) + debug_parser.add_argument( + "--disable-overwrite", action="store_true", help="Disable overwriting the output directory." 
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index e7572c2ca0..dd3181c490 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -239,7 +239,9 @@ def debug(
     config_path: str,
     module: str,
     output_dir: str = "debug_output",
-    enable_viztracer: bool = False,
+    disable_overwrite: bool = False,
+    enable_profiling: bool = False,
+    port: int = 8502,
     plugin_dir: str = None,
 ):
     """Debug a module."""
@@ -263,8 +265,34 @@ def debug(
     elif module == "workflow":
         from trinity.explorer.workflow_runner import DebugWorkflowRunner
 
-        runner = DebugWorkflowRunner(config, output_dir, enable_viztracer)
+        runner = DebugWorkflowRunner(config, output_dir, enable_profiling, disable_overwrite)
         asyncio.run(runner.debug())
+    elif module == "viewer":
+        from streamlit.web import cli as stcli
+
+        current_dir = Path(__file__).resolve().parent.parent
+        viewer_path = os.path.join(current_dir, "buffer", "viewer.py")
+        output_dir_abs = os.path.abspath(output_dir)
+        if output_dir_abs.endswith("/"):
+            output_dir_abs = output_dir_abs[:-1]
+        print(f"sqlite:///{output_dir_abs}/experiences.db")
+        sys.argv = [
+            "streamlit",
+            "run",
+            viewer_path,
+            "--server.port",
+            str(port),
+            "--server.fileWatcherType",
+            "none",
+            "--",
+            "--db-url",
+            f"sqlite:///{output_dir_abs}/experiences.db",
+            "--table",
+            "debug_buffer",
+            "--tokenizer",
+            config.model.model_path,
+        ]
+        sys.exit(stcli.main())
     else:
         raise ValueError(
-            f"Only support 'inference_model' and 'workflow' for debugging, got {module}"
+            f"Only 'inference_model', 'workflow' and 'viewer' are supported for debugging, got {module}"
         )
@@ -301,8 +329,8 @@ def main() -> None:
     debug_parser.add_argument(
         "--module",
         type=str,
-        choices=["inference_model", "workflow"],
-        help="The module to start debugging, only support 'inference_model' and 'workflow' for now.",
+        choices=["inference_model", "workflow", "viewer"],
+        help="The module to start debugging; only 'inference_model', 'workflow' and 'viewer' are supported for now.",
     )
     debug_parser.add_argument(
         "--plugin-dir",
@@ -316,6 +344,9 @@ def main() -> None:
         default="debug_output",
         help="The output directory for debug files.",
     )
+    debug_parser.add_argument(
+        "--disable-overwrite", action="store_true", help="Disable overwriting the output directory."
+    )
     debug_parser.add_argument(
         "--enable-profiling",
         action="store_true",
@@ -327,6 +358,12 @@ def main() -> None:
         default=None,
         help="[DEPRECATED] Please use --output-dir instead.",
     )
+    debug_parser.add_argument(
+        "--port",
+        type=int,
+        default=8502,
+        help="The port for the Experience Viewer.",
+    )
 
     args = parser.parse_args()
     if args.command == "run":
@@ -335,7 +372,15 @@ def main() -> None:
     elif args.command == "studio":
         studio(args.port)
     elif args.command == "debug":
-        debug(args.config, args.module, args.output_dir, args.enable_profiling, args.plugin_dir)
+        debug(
+            args.config,
+            args.module,
+            args.output_dir,
+            args.disable_overwrite,
+            args.enable_profiling,
+            args.port,
+            args.plugin_dir,
+        )
 
 
 if __name__ == "__main__":
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index de64865fee..c5c8b01eb1 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -223,18 +223,20 @@ def __init__(
         config: Config,
         output_dir: str = "debug_output",
         enable_profiling: bool = False,
+        disable_overwrite: bool = False,
     ) -> None:
         model, auxiliary_models = get_debug_inference_model(config)
         super().__init__(config, model, auxiliary_models, 0)
         self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
         self.output_dir = output_dir
         self.enable_profiling = enable_profiling
-        # if output dir is not empty, change to a new dir with datetime suffix
-        if os.path.isdir(self.output_dir) and os.listdir(self.output_dir):
-            suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
-            new_output_dir = f"{self.output_dir}_{suffix}"
-            self.output_dir = new_output_dir
-            self.logger.info(f"Debug output directory: {self.output_dir}")
+        if disable_overwrite:
+            # if the output dir is not empty, switch to a new dir with a datetime suffix
+            if os.path.isdir(self.output_dir) and os.listdir(self.output_dir):
+                suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
+                new_output_dir = f"{self.output_dir}_{suffix}"
+                self.output_dir = new_output_dir
+                self.logger.info(f"Debug output directory: {self.output_dir}")
         os.makedirs(self.output_dir, exist_ok=True)
         self.output_profiling_file = os.path.join(
             self.output_dir,
@@ -270,8 +272,6 @@ async def debug(self) -> None:
             await self.sqlite_writer.write_async(exps)
         if status.ok:
             print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
-            for exp in exps:
-                print(f"Generated experience:\n{exp}")
         else:
             self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
         self.logger.info("Debugging completed.")
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 8ea9446061..4ddfadcf9d 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -212,8 +212,10 @@ def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int)
 
     def log(self, data: dict, step: int, commit: bool = False) -> None:
         """Log metrics."""
-        mlflow.log_metrics(metrics=data, step=step)
         self.console_logger.info(f"Step {step}: {data}")
+        # Replace all '@' in keys with '_at_', as MLflow does not support '@' in metric names
+        data = {k.replace("@", "_at_"): v for k, v in data.items()}
+        mlflow.log_metrics(metrics=data, step=step)
 
     def close(self) -> None:
         mlflow.end_run()
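One detail worth calling out in the `monitor.py` change: as the added comment notes, MLflow does not accept `@` in metric names, which is why keys are rewritten before `mlflow.log_metrics` is called. A tiny self-contained check of the renaming behavior; the `pass@1`-style key is a hypothetical metric name used only for illustration:

```python
# The same key rewriting monitor.py applies before mlflow.log_metrics:
# '@' is not a legal character in MLflow metric names, so "pass@1"-style
# keys (a hypothetical example) are renamed to "pass_at_1".
data = {"eval/pass@1": 0.42, "loss": 1.3}
data = {k.replace("@", "_at_"): v for k, v in data.items()}
assert data == {"eval/pass_at_1": 0.42, "loss": 1.3}
```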