|
| 1 | +# ------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 4 | +# or more contributor license agreements. See the NOTICE file |
| 5 | +# distributed with this work for additional information |
| 6 | +# regarding copyright ownership. The ASF licenses this file |
| 7 | +# to you under the Apache License, Version 2.0 (the |
| 8 | +# "License"); you may not use this file except in compliance |
| 9 | +# with the License. You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, |
| 14 | +# software distributed under the License is distributed on an |
| 15 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | +# KIND, either express or implied. See the License for the |
| 17 | +# specific language governing permissions and limitations |
| 18 | +# under the License. |
| 19 | +# |
| 20 | +# ------------------------------------------------------------- |
| 21 | + |
| 22 | +import sys |
| 23 | +import re |
| 24 | +import networkx as nx |
| 25 | +import matplotlib.pyplot as plt |
| 26 | + |
| 27 | +try: |
| 28 | + import pygraphviz |
| 29 | + from networkx.drawing.nx_agraph import graphviz_layout |
| 30 | + HAS_PYGRAPHVIZ = True |
| 31 | +except ImportError: |
| 32 | + HAS_PYGRAPHVIZ = False |
| 33 | + print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n" |
| 34 | + "If not installed, we will use an alternative layout (spring_layout).") |
| 35 | + |
| 36 | + |
| 37 | +def parse_line(line: str): |
| 38 | + """ |
| 39 | + Parse a single line from the trace file to extract: |
| 40 | + - Node ID |
| 41 | + - Operation (hop name) |
| 42 | + - Kind (e.g., FOUT, LOUT, NREF) |
| 43 | + - Total cost |
| 44 | + - Weight |
| 45 | + - Refs (list of IDs that this node depends on) |
| 46 | + """ |
| 47 | + |
| 48 | + # 1) Match a node ID in the form of "(R)" or "(<number>)" |
| 49 | + match_id = re.match(r'^\((R|\d+)\)', line) |
| 50 | + if not match_id: |
| 51 | + return None |
| 52 | + node_id = match_id.group(1) |
| 53 | + |
| 54 | + # 2) The remaining string after the node ID |
| 55 | + after_id = line[match_id.end():].strip() |
| 56 | + |
| 57 | + # Extract operation (hop name) before the first "[" |
| 58 | + match_label = re.search(r'^(.*?)\s*\[', after_id) |
| 59 | + if match_label: |
| 60 | + operation = match_label.group(1).strip() |
| 61 | + else: |
| 62 | + operation = after_id.strip() |
| 63 | + |
| 64 | + # 3) Extract the kind (content inside the first pair of brackets "[]") |
| 65 | + match_bracket = re.search(r'\[([^\]]+)\]', after_id) |
| 66 | + if match_bracket: |
| 67 | + kind = match_bracket.group(1).strip() |
| 68 | + else: |
| 69 | + kind = "" |
| 70 | + |
| 71 | + # 4) Extract total and weight from the content inside curly braces "{}" |
| 72 | + total = "" |
| 73 | + weight = "" |
| 74 | + match_curly = re.search(r'\{([^}]+)\}', line) |
| 75 | + if match_curly: |
| 76 | + curly_content = match_curly.group(1) |
| 77 | + m_total = re.search(r'Total:\s*([\d\.]+)', curly_content) |
| 78 | + m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content) |
| 79 | + if m_total: |
| 80 | + total = m_total.group(1) |
| 81 | + if m_weight: |
| 82 | + weight = m_weight.group(1) |
| 83 | + |
| 84 | + # 5) Extract reference nodes: look for the first parenthesis containing numbers after the hop name |
| 85 | + match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id) |
| 86 | + if match_refs: |
| 87 | + ref_str = match_refs.group(1) |
| 88 | + refs = [r.strip() for r in ref_str.split(',') if r.strip().isdigit()] |
| 89 | + else: |
| 90 | + refs = [] |
| 91 | + |
| 92 | + return { |
| 93 | + 'node_id': node_id, |
| 94 | + 'operation': operation, |
| 95 | + 'kind': kind, |
| 96 | + 'total': total, |
| 97 | + 'weight': weight, |
| 98 | + 'refs': refs |
| 99 | + } |
| 100 | + |
| 101 | + |
| 102 | +def build_dag_from_file(filename: str): |
| 103 | + """ |
| 104 | + Read a trace file line by line and build a directed acyclic graph (DAG) using NetworkX. |
| 105 | + """ |
| 106 | + G = nx.DiGraph() |
| 107 | + with open(filename, 'r', encoding='utf-8') as f: |
| 108 | + for line in f: |
| 109 | + line = line.strip() |
| 110 | + if not line: |
| 111 | + continue |
| 112 | + |
| 113 | + info = parse_line(line) |
| 114 | + if not info: |
| 115 | + continue |
| 116 | + |
| 117 | + node_id = info['node_id'] |
| 118 | + operation = info['operation'] |
| 119 | + kind = info['kind'] |
| 120 | + total = info['total'] |
| 121 | + weight = info['weight'] |
| 122 | + refs = info['refs'] |
| 123 | + |
| 124 | + # Add node with attributes |
| 125 | + G.add_node(node_id, label=operation, kind=kind, total=total, weight=weight) |
| 126 | + |
| 127 | + # Add edges from references to this node |
| 128 | + for r in refs: |
| 129 | + if r not in G: |
| 130 | + G.add_node(r, label=r, kind="", total="", weight="") |
| 131 | + G.add_edge(r, node_id) |
| 132 | + return G |
| 133 | + |
| 134 | + |
| 135 | +def main(): |
| 136 | + """ |
| 137 | + Main function that: |
| 138 | + - Reads a filename from command-line arguments |
| 139 | + - Builds a DAG from the file |
| 140 | + - Draws and displays the DAG using matplotlib |
| 141 | + """ |
| 142 | + |
| 143 | + # Get filename from command-line argument |
| 144 | + if len(sys.argv) < 2: |
| 145 | + print("[ERROR] No filename provided.\nUsage: python plot_federated_dag.py <filename>") |
| 146 | + sys.exit(1) |
| 147 | + filename = sys.argv[1] |
| 148 | + |
| 149 | + print(f"[INFO] Running with filename '{filename}'") |
| 150 | + |
| 151 | + # Build the DAG |
| 152 | + G = build_dag_from_file(filename) |
| 153 | + |
| 154 | + # Print debug info: nodes and edges |
| 155 | + print("Nodes:", G.nodes(data=True)) |
| 156 | + print("Edges:", list(G.edges())) |
| 157 | + |
| 158 | + # Decide on layout |
| 159 | + if HAS_PYGRAPHVIZ: |
| 160 | + # graphviz_layout with rankdir=BT (bottom to top), etc. |
| 161 | + pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8') |
| 162 | + else: |
| 163 | + # Fallback layout if pygraphviz is not installed |
| 164 | + pos = nx.spring_layout(G, seed=42) |
| 165 | + |
| 166 | + # Dynamically adjust figure size based on number of nodes |
| 167 | + node_count = len(G.nodes()) |
| 168 | + fig_width = 10 + node_count / 10.0 |
| 169 | + fig_height = 6 + node_count / 10.0 |
| 170 | + plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300) |
| 171 | + ax = plt.gca() |
| 172 | + ax.set_facecolor('white') |
| 173 | + |
| 174 | + # Generate labels for each node in the format: |
| 175 | + # node_id: operation_name |
| 176 | + # C<total> (W<weight>) |
| 177 | + labels = { |
| 178 | + n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})" |
| 179 | + for n in G.nodes() |
| 180 | + } |
| 181 | + |
| 182 | + # Function to determine color based on 'kind' |
| 183 | + def get_color(n): |
| 184 | + k = G.nodes[n].get('kind', '').lower() |
| 185 | + if k == 'fout': |
| 186 | + return 'tomato' |
| 187 | + elif k == 'lout': |
| 188 | + return 'dodgerblue' |
| 189 | + elif k == 'nref': |
| 190 | + return 'mediumpurple' |
| 191 | + else: |
| 192 | + return 'mediumseagreen' |
| 193 | + |
| 194 | + # Determine node shapes based on operation name: |
| 195 | + # - '^' (triangle) if the label contains "twrite" |
| 196 | + # - 's' (square) if the label contains "tread" |
| 197 | + # - 'o' (circle) otherwise |
| 198 | + triangle_nodes = [n for n in G.nodes() if 'twrite' in G.nodes[n].get('label', '').lower()] |
| 199 | + square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label', '').lower()] |
| 200 | + other_nodes = [ |
| 201 | + n for n in G.nodes() |
| 202 | + if 'twrite' not in G.nodes[n].get('label', '').lower() and |
| 203 | + 'tread' not in G.nodes[n].get('label', '').lower() |
| 204 | + ] |
| 205 | + |
| 206 | + # Colors for each group |
| 207 | + triangle_colors = [get_color(n) for n in triangle_nodes] |
| 208 | + square_colors = [get_color(n) for n in square_nodes] |
| 209 | + other_colors = [get_color(n) for n in other_nodes] |
| 210 | + |
| 211 | + # Draw nodes group-wise |
| 212 | + node_collection_triangle = nx.draw_networkx_nodes( |
| 213 | + G, pos, nodelist=triangle_nodes, node_size=800, |
| 214 | + node_color=triangle_colors, node_shape='^', ax=ax |
| 215 | + ) |
| 216 | + node_collection_square = nx.draw_networkx_nodes( |
| 217 | + G, pos, nodelist=square_nodes, node_size=800, |
| 218 | + node_color=square_colors, node_shape='s', ax=ax |
| 219 | + ) |
| 220 | + node_collection_other = nx.draw_networkx_nodes( |
| 221 | + G, pos, nodelist=other_nodes, node_size=800, |
| 222 | + node_color=other_colors, node_shape='o', ax=ax |
| 223 | + ) |
| 224 | + |
| 225 | + # Set z-order for nodes, edges, and labels |
| 226 | + node_collection_triangle.set_zorder(1) |
| 227 | + node_collection_square.set_zorder(1) |
| 228 | + node_collection_other.set_zorder(1) |
| 229 | + |
| 230 | + edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax) |
| 231 | + if isinstance(edge_collection, list): |
| 232 | + for ec in edge_collection: |
| 233 | + ec.set_zorder(2) |
| 234 | + else: |
| 235 | + edge_collection.set_zorder(2) |
| 236 | + |
| 237 | + label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax) |
| 238 | + for text in label_dict.values(): |
| 239 | + text.set_zorder(3) |
| 240 | + |
| 241 | + # Set the title |
| 242 | + plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold") |
| 243 | + |
| 244 | + # Provide a small legend on the top-right or top-left |
| 245 | + plt.text(1, 1, |
| 246 | + "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))", |
| 247 | + fontsize=12, ha='right', va='top', transform=ax.transAxes) |
| 248 | + |
| 249 | + # Example mini-legend for different 'kind' values |
| 250 | + plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes) |
| 251 | + plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes) |
| 252 | + plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes) |
| 253 | + |
| 254 | + plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes) |
| 255 | + plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes) |
| 256 | + plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes) |
| 257 | + |
| 258 | + plt.axis("off") |
| 259 | + |
| 260 | + # Save the plot to a file with the same name as the input file, but with a .png extension |
| 261 | + output_filename = f"{filename.rsplit('.', 1)[0]}.png" |
| 262 | + plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight') |
| 263 | + |
| 264 | + plt.show() |
| 265 | + |
| 266 | + |
| 267 | +if __name__ == '__main__': |
| 268 | + main() |
0 commit comments