Skip to content

Commit a51b02b

Browse files
committed
Awesome graphs
1 parent e633e2b commit a51b02b

File tree

1 file changed

+111
-15
lines changed

1 file changed

+111
-15
lines changed

src/graph.py

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import argparse
2-
from typing import Dict
2+
from typing import Dict, Literal
33

44
import matplotlib.pyplot as plt
55
import pandas
6+
from matplotlib.colors import TABLEAU_COLORS
67

78
parser = argparse.ArgumentParser(
89
description="Generate a graph with the execution times of the MPI/OMP MinHash implementations."
@@ -27,13 +28,17 @@
2728
help="Path to the directory where the PNG file will be saved (dot for current directory)",
2829
)
2930

31+
colors = list(TABLEAU_COLORS.values())
32+
3033

3134
def draw_graph(
3235
dists: Dict[str, pandas.DataFrame],
3336
x_label: str,
3437
y_label: str,
3538
title: str = None,
36-
save_path: str = None
39+
save_path: str = None,
40+
peakline: Literal['min', 'max', None] = None,
41+
legend_loc: str = 'best',
3742
):
3843
"""
3944
Draws a graph from the given data.
@@ -52,14 +57,56 @@ def draw_graph(
5257
# Add Style
5358
plt.style.use('classic')
5459

55-
# Plot the data
56-
for name, df in dists.items():
57-
if len(df) > 1:
60+
# Iterate over non-empty distributions
61+
_dists = filter(lambda t: len(t[1]) > 0, dists.items())
62+
63+
for i, (name, df) in enumerate(_dists):
64+
len_df = len(df)
65+
66+
# Get dist color
67+
color = colors[i % len(colors)]
68+
69+
if len_df > 1:
5870
# Proper distribution, plot it
59-
plt.plot(df[x_label], df[y_label], label=name, marker='o', markersize=3)
60-
elif len(df) == 1:
61-
# Single point, plot it as a line
62-
plt.axline((0, df[y_label].values[0]), slope=0, label=name)
71+
plt.plot(
72+
df[x_label],
73+
df[y_label],
74+
label=name,
75+
marker='o',
76+
markersize=3,
77+
color=color,
78+
)
79+
80+
# Check if we need to plot a peakline
81+
if peakline:
82+
peakline = str(peakline).lower()
83+
84+
# Get peak value
85+
if peakline == 'min':
86+
y = df[y_label].min()
87+
elif peakline == 'max':
88+
y = df[y_label].max()
89+
else:
90+
raise ValueError(f"Invalid peakline value: {peakline}")
91+
92+
# Plot the mix/max line
93+
if len_df == 1:
94+
# Plot with label
95+
plt.axline(
96+
(0, y),
97+
slope=0,
98+
label=name,
99+
color=color,
100+
linestyle='dashed',
101+
)
102+
else:
103+
# Plot without label
104+
plt.axline(
105+
(0, y),
106+
slope=0,
107+
color=color,
108+
linestyle='dashed',
109+
)
63110

64111
# Add labels
65112
plt.title(title)
@@ -70,7 +117,7 @@ def draw_graph(
70117
plt.gcf().set_size_inches(10, 5)
71118

72119
# Add legend
73-
plt.legend()
120+
plt.legend(loc=legend_loc)
74121

75122
# Add Grid
76123
plt.grid()
@@ -113,23 +160,72 @@ def get_time_dataframes(csv_path: str) -> Dict[str, pandas.DataFrame]:
113160
return dists
114161

115162

163+
def get_speedup_dataframes(time_dists: Dict[str, pandas.DataFrame]) -> Dict[str, pandas.DataFrame]:
164+
"""
165+
Computes the speedup for each library and returns the results as a dict {lib:DataFrame}.
166+
167+
Args:
168+
time_df: DataFrame containing the time data.
169+
170+
Returns:
171+
DataFrame containing the speedup for each library.
172+
"""
173+
174+
# Copy the DataFrame
175+
speedup_df = time_dists.copy()
176+
177+
# Get baseline time
178+
if "NONE" in time_dists:
179+
baseline_time = time_dists["NONE"]["time_elapsed"].item()
180+
else:
181+
baseline_time = None
182+
for df in time_dists.values():
183+
if len(df) <= 0:
184+
continue
185+
t = df["time_elapsed"].min()
186+
if baseline_time is None or t < baseline_time:
187+
baseline_time = t
188+
189+
# Compute the speedup
190+
for lib, df in speedup_df.items():
191+
df["speedup"] = df["time_elapsed"].map(lambda time: baseline_time / time)
192+
193+
# Return the DataFrame
194+
return speedup_df
195+
196+
116197
def main():
117198
# Parse the command-line arguments
118199
args = parser.parse_args()
119200

120201
# Read the CSV files
121-
data = get_time_dataframes(f"{args.in_csv_path}/time_{args.dataset}.csv")
202+
exec_dists = get_time_dataframes(f"{args.in_csv_path}/time_{args.dataset}.csv")
203+
speedup_dists = get_speedup_dataframes(exec_dists)
122204

123205
# Compute save path
124-
save_path = f"{args.out_png_path}/time_{args.dataset}.svg" if args.out_png_path else None
206+
path_exec = f"{args.out_png_path}/exec_{args.dataset}.svg" if args.out_png_path else None
207+
path_speedup = f"{args.out_png_path}/speedup_{args.dataset}.svg" if args.out_png_path else None
125208

126-
# Draw a graph
209+
# Draw execution graph
127210
draw_graph(
128-
dists=data,
211+
dists=exec_dists,
129212
x_label="n_processes",
130213
y_label="time_elapsed",
131214
title=f"Execution time ({args.dataset})",
132-
save_path=save_path
215+
save_path=path_exec,
216+
peakline="min",
217+
legend_loc='upper right',
218+
)
219+
220+
# Draw speedup graph
221+
draw_graph(
222+
dists=speedup_dists,
223+
x_label="n_processes",
224+
y_label="speedup",
225+
title=f"Speedup ({args.dataset})",
226+
save_path=path_speedup,
227+
peakline="max",
228+
legend_loc='lower right',
133229
)
134230

135231

0 commit comments

Comments
 (0)