diff --git a/devtools/visualization/__init__.py b/devtools/visualization/__init__.py index df1d74c7fae..8e91d7ffdb2 100644 --- a/devtools/visualization/__init__.py +++ b/devtools/visualization/__init__.py @@ -9,4 +9,5 @@ SingletonModelExplorerServer, visualize, visualize_graph, + visualize_model_explorer, ) diff --git a/devtools/visualization/visualization_utils.py b/devtools/visualization/visualization_utils.py index d21d11082a3..b21a953f4d2 100644 --- a/devtools/visualization/visualization_utils.py +++ b/devtools/visualization/visualization_utils.py @@ -108,7 +108,7 @@ def visualize( **kwargs, ): """Wraps the visualize_from_config call from model_explorer. - For convenicence, figures out how to find the exported_program + For convenience, figures out how to find the exported_program from EdgeProgramManager and ExecutorchProgramManager for you. See https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#visualize-pytorch-models @@ -123,13 +123,22 @@ def visualize( ) if reuse_server: cur_config.set_reuse_server() - visualize_from_config( - cur_config, + visualize_model_explorer( + config=kwargs.pop("config", cur_config), no_open_in_browser=no_open_in_browser, **kwargs, ) +def visualize_model_explorer( + **kwargs, +): + """Wraps the visualize_from_config call from model_explorer.""" + visualize_from_config( + **kwargs, + ) + + def visualize_graph( graph_module: GraphModule, exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,