Skip to content

Commit 693ef52

Browse files
min-guk authored and mboehm7 committed
[SYSTEMDS-3790] Extended optimizer for federated execution plans
Closes #2238.
1 parent 2e68ad3 commit 693ef52

14 files changed

+1842
-656
lines changed
Lines changed: 268 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,268 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
22+
import sys
23+
import re
24+
import networkx as nx
25+
import matplotlib.pyplot as plt
26+
27+
# Optional dependency probe: pygraphviz enables the hierarchical "dot" layout.
# HAS_PYGRAPHVIZ records whether both imports succeeded so that the plotting
# code can fall back to a pure-networkx layout otherwise.
try:
    import pygraphviz
    from networkx.drawing.nx_agraph import graphviz_layout
except ImportError:
    HAS_PYGRAPHVIZ = False
    print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n"
          "If not installed, we will use an alternative layout (spring_layout).")
else:
    HAS_PYGRAPHVIZ = True
35+
36+
37+
# Decimal number: optional sign, digits with an optional fraction, and an
# optional exponent (e.g. "12", "3.5", "1.5E7", "-0.25").
_NUMBER = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

# Patterns are compiled once because parse_line runs for every trace line.
_RE_NODE_ID = re.compile(r'^\((R|\d+)\)')
_RE_LABEL = re.compile(r'^(.*?)\s*\[')
_RE_KIND = re.compile(r'\[([^\]]+)\]')
_RE_CURLY = re.compile(r'\{([^}]+)\}')
_RE_TOTAL = re.compile(r'Total:\s*(' + _NUMBER + r')')
_RE_WEIGHT = re.compile(r'Weight:\s*(' + _NUMBER + r')')
_RE_REFS = re.compile(r'\(\s*(\d+(?:\s*,\s*\d+)*)\s*\)')


def parse_line(line: str):
    """
    Parse a single line from the trace file.

    The line must start with a node ID written as "(R)" or "(<number>)".
    The remainder is scanned for:
    - Operation (hop name): the text before the first "["
    - Kind (e.g., FOUT, LOUT, NREF): the content of the first "[...]"
    - Total cost and Weight: the "Total:" / "Weight:" numbers inside "{...}"
    - Refs: the first parenthesised, comma-separated list of numeric IDs

    Args:
        line: One (already stripped) line of the trace file.

    Returns:
        None if the line does not start with a node ID; otherwise a dict with
        the string-valued keys 'node_id', 'operation', 'kind', 'total' and
        'weight' (empty string when absent) and the key 'refs' holding a list
        of ID strings (empty when the node has no references).
    """
    # 1) Match a node ID in the form of "(R)" or "(<number>)"
    match_id = _RE_NODE_ID.match(line)
    if not match_id:
        return None
    node_id = match_id.group(1)

    # 2) The remaining string after the node ID
    after_id = line[match_id.end():].strip()

    # Extract operation (hop name) before the first "["
    match_label = _RE_LABEL.search(after_id)
    operation = match_label.group(1).strip() if match_label else after_id.strip()

    # 3) Extract the kind (content inside the first pair of brackets "[]")
    match_bracket = _RE_KIND.search(after_id)
    kind = match_bracket.group(1).strip() if match_bracket else ""

    # 4) Extract total and weight from the content inside curly braces "{}"
    total = ""
    weight = ""
    match_curly = _RE_CURLY.search(line)
    if match_curly:
        curly_content = match_curly.group(1)
        m_total = _RE_TOTAL.search(curly_content)
        m_weight = _RE_WEIGHT.search(curly_content)
        if m_total:
            total = m_total.group(1)
        if m_weight:
            weight = m_weight.group(1)

    # 5) Extract reference nodes. Start after the kind bracket when there is
    #    one, so that a numeric parenthesis inside the hop name itself
    #    (e.g. "rix(1)") is not mistaken for the reference list.
    search_from = match_bracket.end() if match_bracket else 0
    match_refs = _RE_REFS.search(after_id, search_from)
    if match_refs:
        # The pattern tolerates blanks around commas, hence the strip().
        refs = [r.strip() for r in match_refs.group(1).split(',')]
    else:
        refs = []

    return {
        'node_id': node_id,
        'operation': operation,
        'kind': kind,
        'total': total,
        'weight': weight,
        'refs': refs
    }
100+
101+
102+
def build_dag_from_file(filename: str):
    """
    Build a NetworkX directed graph from a trace file.

    Every non-blank line is handed to parse_line; lines it rejects are
    skipped. Each parsed record becomes a node keyed by its node ID, carrying
    the attributes label, kind, total and weight, and each of its references
    contributes an edge pointing from the referenced node to the record's node.

    Args:
        filename: Path of the UTF-8 encoded trace file.

    Returns:
        The populated nx.DiGraph.
    """
    graph = nx.DiGraph()
    with open(filename, 'r', encoding='utf-8') as trace:
        # Lazily strip each line, drop blank ones, and parse the rest.
        records = (parse_line(text) for text in map(str.strip, trace) if text)
        for record in records:
            if not record:
                continue

            target = record['node_id']
            graph.add_node(
                target,
                label=record['operation'],
                kind=record['kind'],
                total=record['total'],
                weight=record['weight'],
            )

            for source in record['refs']:
                # A reference that has not been seen yet gets a placeholder
                # node labelled with its own ID.
                if source not in graph:
                    graph.add_node(source, label=source, kind="", total="", weight="")
                graph.add_edge(source, target)
    return graph
133+
134+
135+
# Fill colour per node 'kind' (compared case-insensitively); any other kind
# uses the default colour.
_KIND_COLORS = {'fout': 'tomato', 'lout': 'dodgerblue', 'nref': 'mediumpurple'}
_DEFAULT_COLOR = 'mediumseagreen'


def _node_color(kind: str) -> str:
    """Return the plot colour for a node of the given kind."""
    return _KIND_COLORS.get(kind.lower(), _DEFAULT_COLOR)


def _node_shape(label: str) -> str:
    """Return the marker for a hop label: '^' for twrite, 's' for tread, 'o' otherwise."""
    lowered = label.lower()
    if 'twrite' in lowered:
        return '^'
    if 'tread' in lowered:
        return 's'
    return 'o'


def _output_path(filename: str) -> str:
    """Return the input path with its file extension (if any) replaced by '.png'."""
    import os

    # splitext only strips an extension of the final path component, so dots
    # in directory names or a leading "./" are left intact.
    return os.path.splitext(filename)[0] + '.png'


def main():
    """
    Main function that:
    - Reads a filename from command-line arguments
    - Builds a DAG from the file
    - Draws the DAG using matplotlib, saves it as a PNG next to the input
      file, and displays it

    Exits with status 1 when no filename is given or the file cannot be read.
    """

    # Get filename from command-line argument
    if len(sys.argv) < 2:
        print("[ERROR] No filename provided.\nUsage: python plot_federated_dag.py <filename>")
        sys.exit(1)
    filename = sys.argv[1]

    print(f"[INFO] Running with filename '{filename}'")

    # Build the DAG; report an unreadable file as a normal CLI error instead
    # of a traceback.
    try:
        G = build_dag_from_file(filename)
    except OSError as err:
        print(f"[ERROR] Could not read '{filename}': {err}")
        sys.exit(1)

    # Print debug info: nodes and edges
    print("Nodes:", G.nodes(data=True))
    print("Edges:", list(G.edges()))

    # Decide on layout
    if HAS_PYGRAPHVIZ:
        # graphviz_layout with rankdir=BT (bottom to top), etc.
        pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8')
    else:
        # Fallback layout if pygraphviz is not installed
        pos = nx.spring_layout(G, seed=42)

    # Dynamically adjust figure size based on number of nodes
    node_count = len(G.nodes())
    fig_width = 10 + node_count / 10.0
    fig_height = 6 + node_count / 10.0
    plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300)
    ax = plt.gca()
    ax.set_facecolor('white')

    # Generate labels for each node in the format:
    #   node_id: operation_name
    #   C<total> (W<weight>)
    labels = {
        n: f"{n}: {data.get('label', n)}\n C{data.get('total', '')} (W{data.get('weight', '')})"
        for n, data in G.nodes(data=True)
    }

    # Assign every node to exactly one marker group so that no node is drawn
    # twice: '^' (triangle) for twrite, 's' (square) for tread, 'o' otherwise.
    nodes_by_shape = {'^': [], 's': [], 'o': []}
    for n, data in G.nodes(data=True):
        nodes_by_shape[_node_shape(data.get('label', ''))].append(n)

    # Draw nodes group-wise (z-order 1, below edges and labels)
    for shape, members in nodes_by_shape.items():
        if not members:
            continue
        collection = nx.draw_networkx_nodes(
            G, pos, nodelist=members, node_size=800,
            node_color=[_node_color(G.nodes[n].get('kind', '')) for n in members],
            node_shape=shape, ax=ax
        )
        # NOTE(review): some networkx versions appear to return None when
        # nothing is drawn - guard rather than assume a collection.
        if collection is not None:
            collection.set_zorder(1)

    # Draw edges (z-order 2); the result may be a list of patches or a
    # single collection depending on how networkx draws the edges.
    edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax)
    if edge_collection is not None:
        if not isinstance(edge_collection, list):
            edge_collection = [edge_collection]
        for ec in edge_collection:
            ec.set_zorder(2)

    # Draw labels on top (z-order 3)
    label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax)
    for text in label_dict.values():
        text.set_zorder(3)

    # Set the title
    plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold")

    # Explain the label format in the top-right corner (axes coordinates)
    plt.text(1, 1,
             "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))",
             fontsize=12, ha='right', va='top', transform=ax.transAxes)

    # Mini-legend for the different 'kind' values (axes coordinates)
    plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes)
    plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes)
    plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes)

    plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes)
    plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes)
    plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes)

    plt.axis("off")

    # Save the plot next to the input file, with a .png extension
    output_filename = _output_path(filename)
    plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight')

    plt.show()
265+
266+
267+
# Run the plotting workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)