Skip to content

Commit 879dc48

Browse files
committed
program level fed planer
1 parent fd9479d commit 879dc48

File tree

8 files changed

+909
-394
lines changed

8 files changed

+909
-394
lines changed

graph.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import sys
2+
import re
3+
import networkx as nx
4+
import matplotlib.pyplot as plt
5+
6+
try:
7+
import pygraphviz
8+
from networkx.drawing.nx_agraph import graphviz_layout
9+
HAS_PYGRAPHVIZ = True
10+
except ImportError:
11+
HAS_PYGRAPHVIZ = False
12+
print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n"
13+
"If not installed, we will use an alternative layout (spring_layout).")
14+
15+
16+
def parse_line(line: str):
17+
"""
18+
Parse a single line from the trace file to extract:
19+
- Node ID
20+
- Operation (hop name)
21+
- Kind (e.g., FOUT, LOUT, NREF)
22+
- Total cost
23+
- Weight
24+
- Refs (list of IDs that this node depends on)
25+
"""
26+
27+
# 1) Match a node ID in the form of "(R)" or "(<number>)"
28+
match_id = re.match(r'^\((R|\d+)\)', line)
29+
if not match_id:
30+
return None
31+
node_id = match_id.group(1)
32+
33+
# 2) The remaining string after the node ID
34+
after_id = line[match_id.end():].strip()
35+
36+
# Extract operation (hop name) before the first "["
37+
match_label = re.search(r'^(.*?)\s*\[', after_id)
38+
if match_label:
39+
operation = match_label.group(1).strip()
40+
else:
41+
operation = after_id.strip()
42+
43+
# 3) Extract the kind (content inside the first pair of brackets "[]")
44+
match_bracket = re.search(r'\[([^\]]+)\]', after_id)
45+
if match_bracket:
46+
kind = match_bracket.group(1).strip()
47+
else:
48+
kind = ""
49+
50+
# 4) Extract total and weight from the content inside curly braces "{}"
51+
total = ""
52+
weight = ""
53+
match_curly = re.search(r'\{([^}]+)\}', line)
54+
if match_curly:
55+
curly_content = match_curly.group(1)
56+
m_total = re.search(r'Total:\s*([\d\.]+)', curly_content)
57+
m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content)
58+
if m_total:
59+
total = m_total.group(1)
60+
if m_weight:
61+
weight = m_weight.group(1)
62+
63+
# 5) Extract reference nodes: look for the first parenthesis containing numbers after the hop name
64+
match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id)
65+
if match_refs:
66+
ref_str = match_refs.group(1)
67+
refs = [r.strip() for r in ref_str.split(',') if r.strip().isdigit()]
68+
else:
69+
refs = []
70+
71+
return {
72+
'node_id': node_id,
73+
'operation': operation,
74+
'kind': kind,
75+
'total': total,
76+
'weight': weight,
77+
'refs': refs
78+
}
79+
80+
81+
def build_dag_from_file(filename: str):
82+
"""
83+
Read a trace file line by line and build a directed acyclic graph (DAG) using NetworkX.
84+
"""
85+
G = nx.DiGraph()
86+
with open(filename, 'r', encoding='utf-8') as f:
87+
for line in f:
88+
line = line.strip()
89+
if not line:
90+
continue
91+
92+
info = parse_line(line)
93+
if not info:
94+
continue
95+
96+
node_id = info['node_id']
97+
operation = info['operation']
98+
kind = info['kind']
99+
total = info['total']
100+
weight = info['weight']
101+
refs = info['refs']
102+
103+
# Add node with attributes
104+
G.add_node(node_id, label=operation, kind=kind, total=total, weight=weight)
105+
106+
# Add edges from references to this node
107+
for r in refs:
108+
if r not in G:
109+
G.add_node(r, label=r, kind="", total="", weight="")
110+
G.add_edge(r, node_id)
111+
return G
112+
113+
114+
def main():
115+
"""
116+
Main function that:
117+
- Reads a filename from command-line arguments
118+
- Builds a DAG from the file
119+
- Draws and displays the DAG using matplotlib
120+
"""
121+
122+
# Get filename from command-line argument
123+
if len(sys.argv) < 2:
124+
print("[ERROR] No filename provided.\nUsage: python plot_federated_dag.py <filename>")
125+
sys.exit(1)
126+
filename = sys.argv[1]
127+
128+
print(f"[INFO] Running with filename '{filename}'")
129+
130+
# Build the DAG
131+
G = build_dag_from_file(filename)
132+
133+
# Print debug info: nodes and edges
134+
print("Nodes:", G.nodes(data=True))
135+
print("Edges:", list(G.edges()))
136+
137+
# Decide on layout
138+
if HAS_PYGRAPHVIZ:
139+
# graphviz_layout with rankdir=BT (bottom to top), etc.
140+
pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8')
141+
else:
142+
# Fallback layout if pygraphviz is not installed
143+
pos = nx.spring_layout(G, seed=42)
144+
145+
# Dynamically adjust figure size based on number of nodes
146+
node_count = len(G.nodes())
147+
fig_width = 10 + node_count / 10.0
148+
fig_height = 6 + node_count / 10.0
149+
plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300)
150+
ax = plt.gca()
151+
ax.set_facecolor('white')
152+
153+
# Generate labels for each node in the format:
154+
# node_id: operation_name
155+
# C<total> (W<weight>)
156+
labels = {
157+
n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})"
158+
for n in G.nodes()
159+
}
160+
161+
# Function to determine color based on 'kind'
162+
def get_color(n):
163+
k = G.nodes[n].get('kind', '').lower()
164+
if k == 'fout':
165+
return 'tomato'
166+
elif k == 'lout':
167+
return 'dodgerblue'
168+
elif k == 'nref':
169+
return 'mediumpurple'
170+
else:
171+
return 'mediumseagreen'
172+
173+
# Determine node shapes based on operation name:
174+
# - '^' (triangle) if the label contains "twrite"
175+
# - 's' (square) if the label contains "tread"
176+
# - 'o' (circle) otherwise
177+
triangle_nodes = [n for n in G.nodes() if 'twrite' in G.nodes[n].get('label', '').lower()]
178+
square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label', '').lower()]
179+
other_nodes = [
180+
n for n in G.nodes()
181+
if 'twrite' not in G.nodes[n].get('label', '').lower() and
182+
'tread' not in G.nodes[n].get('label', '').lower()
183+
]
184+
185+
# Colors for each group
186+
triangle_colors = [get_color(n) for n in triangle_nodes]
187+
square_colors = [get_color(n) for n in square_nodes]
188+
other_colors = [get_color(n) for n in other_nodes]
189+
190+
# Draw nodes group-wise
191+
node_collection_triangle = nx.draw_networkx_nodes(
192+
G, pos, nodelist=triangle_nodes, node_size=800,
193+
node_color=triangle_colors, node_shape='^', ax=ax
194+
)
195+
node_collection_square = nx.draw_networkx_nodes(
196+
G, pos, nodelist=square_nodes, node_size=800,
197+
node_color=square_colors, node_shape='s', ax=ax
198+
)
199+
node_collection_other = nx.draw_networkx_nodes(
200+
G, pos, nodelist=other_nodes, node_size=800,
201+
node_color=other_colors, node_shape='o', ax=ax
202+
)
203+
204+
# Set z-order for nodes, edges, and labels
205+
node_collection_triangle.set_zorder(1)
206+
node_collection_square.set_zorder(1)
207+
node_collection_other.set_zorder(1)
208+
209+
edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax)
210+
if isinstance(edge_collection, list):
211+
for ec in edge_collection:
212+
ec.set_zorder(2)
213+
else:
214+
edge_collection.set_zorder(2)
215+
216+
label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax)
217+
for text in label_dict.values():
218+
text.set_zorder(3)
219+
220+
# Set the title
221+
plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold")
222+
223+
# Provide a small legend on the top-right or top-left
224+
plt.text(1, 1,
225+
"[LABEL]\n hopID: hopName\n C(Total) (W(Weight))",
226+
fontsize=12, ha='right', va='top', transform=ax.transAxes)
227+
228+
# Example mini-legend for different 'kind' values
229+
plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes)
230+
plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes)
231+
plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes)
232+
233+
plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes)
234+
plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes)
235+
plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes)
236+
237+
plt.axis("off")
238+
239+
# Save the plot to a file with the same name as the input file, but with a .png extension
240+
output_filename = f"{filename.rsplit('.', 1)[0]}.png"
241+
plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight')
242+
243+
plt.show()
244+
245+
246+
if __name__ == '__main__':
247+
main()

0 commit comments

Comments
 (0)