|
| 1 | +# Copyright 2025 Cisco Systems, Inc. and its affiliates |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# |
| 15 | +# SPDX-License-Identifier: Apache-2.0 |
| 16 | + |
| 17 | +"""visualize_tool.py.""" |
| 18 | + |
| 19 | + |
1 | 20 | import argparse |
2 | 21 | import os |
3 | 22 |
|
|
8 | 27 | from infscale.configs.job import JobConfig |
9 | 28 |
|
10 | 29 |
|
11 | | -def load_job_config(path: str) -> JobConfig: |
| 30 | +def get_job_data(path: str) -> tuple[JobConfig, str]: |
| 31 | + """Load job data composed by JobConfig and yaml file name.""" |
12 | 32 | with open(path) as f: |
13 | 33 | data = yaml.safe_load(f) |
| 34 | + file_name = os.path.splitext(os.path.basename(path))[0] |
14 | 35 |
|
15 | | - return JobConfig(**data) |
| 36 | + return JobConfig(**data), file_name |
16 | 37 |
|
17 | 38 |
|
18 | 39 | def build_graph(job: JobConfig) -> tuple[nx.DiGraph, dict[str, int]]: |
@@ -42,7 +63,7 @@ def build_graph(job: JobConfig) -> tuple[nx.DiGraph, dict[str, int]]: |
42 | 63 | def draw_graph( |
43 | 64 | graph: nx.DiGraph, |
44 | 65 | worker_stage: dict[str, int], |
45 | | - job_id: str, |
| 66 | + file_name: str, |
46 | 67 | output_path: str = "", |
47 | 68 | ) -> None: |
48 | 69 | """Draw graph where worker_stage maps node -> stage (start).""" |
@@ -151,32 +172,37 @@ def draw_graph( |
151 | 172 | ax.set_axis_off() |
152 | 173 | plt.tight_layout() |
153 | 174 | if output_path: |
154 | | - os.makedirs(output_path, exist_ok=True) |
155 | | - output_file = os.path.join(output_path, f"{job_id}.png") |
| 175 | + # save PNG in the same folder as this script |
| 176 | + script_dir = os.path.dirname(os.path.abspath(__file__)) |
| 177 | + output_file = os.path.join(script_dir, f"{file_name}.png") |
156 | 178 | plt.savefig(output_file, dpi=300, bbox_inches="tight") |
157 | 179 | print(f"Graph saved at: {output_file}") |
158 | | - else: |
159 | | - print("Graph opened in a new window.") |
160 | | - plt.show() |
| 180 | + |
| 181 | + print("Graph opened in a new window.") |
| 182 | + plt.show() |
161 | 183 |
|
162 | 184 |
|
163 | 185 | def main(): |
164 | 186 | parser = argparse.ArgumentParser(description="Visualize JobConfig flow graph") |
165 | 187 | parser.add_argument("config_path", help="Path to job YAML config") |
166 | 188 | parser.add_argument( |
167 | | - "-o", "--output", help="Directory to save output image (optional)", default=None |
| 189 | + "-s", |
| 190 | + "--save", |
| 191 | + action="store_true", |
| 192 | + help="Save the graph as a PNG file in the same directory as the script instead of displaying it", |
168 | 193 | ) |
169 | 194 | args = parser.parse_args() |
170 | 195 |
|
171 | 196 | try: |
172 | | - config = load_job_config(args.config_path) |
| 197 | + config, file_name = get_job_data(args.config_path) |
173 | 198 | except FileNotFoundError as e: |
174 | 199 | print(f"Error while loading file: {e}") |
175 | 200 | return |
176 | 201 |
|
177 | 202 | graph, worker_stage = build_graph(config) |
| 203 | + output_path = os.path.dirname(os.path.abspath(__file__)) if args.save else None |
178 | 204 | try: |
179 | | - draw_graph(graph, worker_stage, config.job_id, args.output) |
| 205 | + draw_graph(graph, worker_stage, file_name, output_path) |
180 | 206 | except nx.exception.NetworkXError as e: |
181 | 207 | print(f"Error while drawing graph: {e}") |
182 | 208 | except KeyboardInterrupt: |
|
0 commit comments