Skip to content

Commit 2fc1fd9

Browse files
authored
perf: run ablations for the paper (#1408)
* perf: run ablations for the paper * Update perf/common.jl * fix: import * perf: FNOs * perf: grad * fix: remove forced passes at end * fix: correct usage of sync * fix: enable more passes * feat: plotting * fix: missing save * perf: add plots
1 parent 19425b6 commit 2fc1fd9

File tree

5 files changed

+457
-0
lines changed

5 files changed

+457
-0
lines changed

perf/common.jl

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
using BenchmarkTools: @benchmark
2+
using Reactant, Enzyme, PrettyTables, Statistics
3+
using CairoMakie, AlgebraOfGraphics, CSV, DataFrames, Dates
4+
const AoG = AlgebraOfGraphics
5+
6+
AoG.set_aog_theme!()
7+
8+
"""
    simple_mse_loss(model, x, z, ps, st)

Run `model` on `x` with parameters `ps` and state `st` via `Lux.apply`, and return
`MSELoss()` evaluated on the model output and the target `z`. The second value returned
by `Lux.apply` is discarded.

NOTE(review): `Lux` and `MSELoss` are not imported in this file; they are presumably
brought into scope by the script that includes it -- confirm.
"""
function simple_mse_loss(model, x, z, ps, st)
    prediction = first(Lux.apply(model, x, ps, st))
    loss_fn = MSELoss()
    return loss_fn(prediction, z)
end
12+
13+
"""
    simple_mse_loss_gradient(model, x, z, ps, st)

Return the result of `Enzyme.gradient` in `Enzyme.Reverse` mode applied to
[`simple_mse_loss`](@ref).

`model`, `x`, `z` and `st` are wrapped in `Const`; `ps` is the only argument passed
without an annotation.
"""
function simple_mse_loss_gradient(model, x, z, ps, st)
    const_model, const_x = Const(model), Const(x)
    const_z, const_st = Const(z), Const(st)
    return Enzyme.gradient(
        Enzyme.Reverse, simple_mse_loss, const_model, const_x, const_z, ps, const_st
    )
end
18+
19+
# Compile `fn(model, x, z, ps, st)` with the compile options `opts`, benchmark the
# compiled callable, and return `(mean_time, std_time)` taken from the benchmark trial.
function _time_compiled_variant(fn, opts, model, x, z, ps, st)
    compiled = @compile compile_options = opts fn(model, x, z, ps, st)
    # `setup` forces a full garbage collection, exactly as every variant did before.
    bench = @benchmark $compiled($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    return mean(bench).time, std(bench).time
end

# Build the ordered list of `label => CompileOptions` ablation variants. Every variant
# shares `sync=true, no_nan=true, all_finite=true` plus `extra_options`; the two flags
# select which "passes disabled" variants are included.
function _ablation_variants(;
    disable_scatter_gather_bench, disable_pad_bench, extra_options...
)
    shared = (; sync=true, no_nan=true, all_finite=true, extra_options...)
    variants = Pair{String,Any}["All" => CompileOptions(; shared...)]

    if disable_scatter_gather_bench
        opts = CompileOptions(;
            disable_scatter_gather_optimization_passes=true, shared...
        )
        push!(variants, "No Scatter/Gather Optimizations" => opts)
    end

    if disable_pad_bench
        opts = CompileOptions(; disable_pad_optimization_passes=true, shared...)
        push!(variants, "No Pad Optimizations" => opts)
    end

    if disable_scatter_gather_bench && disable_pad_bench
        opts = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            disable_pad_optimization_passes=true,
            shared...,
        )
        push!(variants, "No Scatter/Gather/Pad Optimizations" => opts)
    end

    return variants
end

# Benchmark `fn` under the "Only XLA" baseline and every ablation variant. Returns rows
# `(label, variant, mean_time, std_time, mean_time / baseline_mean_time)` sorted by
# mean time.
function _run_ablation_suite(
    fn,
    label,
    model,
    x,
    z,
    ps,
    st;
    disable_scatter_gather_bench=true,
    disable_pad_bench=true,
    extra_options...,
)
    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()

    # The baseline is timed first; every other row is reported relative to its mean.
    xla_options = Reactant.DefaultXLACompileOptions(; sync=true)
    baseline, baseline_std = _time_compiled_variant(
        fn, xla_options, model, x, z, ps, st
    )
    push!(results, (label, "Only XLA", baseline, baseline_std, 1.0))

    variants = _ablation_variants(;
        disable_scatter_gather_bench, disable_pad_bench, extra_options...
    )
    for (variant_name, opts) in variants
        μ, σ = _time_compiled_variant(fn, opts, model, x, z, ps, st)
        push!(results, (label, variant_name, μ, σ, μ / baseline))
    end

    sort!(results; by=row -> row[3])
    return results
end

"""
    benchmark_nn_primal(model, x, z, ps, st;
                        disable_scatter_gather_bench=true, disable_pad_bench=true)

Benchmark the compiled forward pass [`simple_mse_loss`](@ref) under several compile
configurations.

The "Only XLA" configuration (`Reactant.DefaultXLACompileOptions(; sync=true)`) is the
baseline. The "All" configuration uses
`CompileOptions(; sync=true, no_nan=true, all_finite=true)`.

# Keywords
- `disable_scatter_gather_bench`: when `true`, also benchmark with
  `disable_scatter_gather_optimization_passes=true`.
- `disable_pad_bench`: when `true`, also benchmark with
  `disable_pad_optimization_passes=true`.

When both keywords are `true`, a variant with both pass groups disabled is benchmarked
as well.

# Returns
A `Vector` of `("Primal", variant, mean_time, std_time, relative_time)` tuples sorted by
mean time, where `relative_time` is the variant's mean time divided by the baseline's.
"""
function benchmark_nn_primal(
    model, x, z, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    return _run_ablation_suite(
        simple_mse_loss,
        "Primal",
        model,
        x,
        z,
        ps,
        st;
        disable_scatter_gather_bench,
        disable_pad_bench,
    )
end

"""
    benchmark_nn_gradient(model, x, z, ps, st; kwargs...)

Run [`benchmark_nn_gradient_internal`](@ref) once for each of the modes `:all`,
`:before_enzyme` and `:after_enzyme` (in that order), forwarding `kwargs`, and return the
concatenation of the per-mode result vectors.
"""
function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
    per_mode = [
        benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...) for
        mode in (:all, :before_enzyme, :after_enzyme)
    ]
    return reduce(vcat, per_mode)
end

"""
    benchmark_nn_gradient_internal(model, x, z, ps, st, mode;
                                   disable_scatter_gather_bench=true,
                                   disable_pad_bench=true)

Benchmark the compiled [`simple_mse_loss_gradient`](@ref) under the same ablation
variants as [`benchmark_nn_primal`](@ref). `mode` is passed as `optimization_passes` to
every `CompileOptions` variant; the "Only XLA" baseline does not receive it.

# Returns
A `Vector` of `("Gradient (\$mode)", variant, mean_time, std_time, relative_time)` tuples
sorted by mean time.
"""
function benchmark_nn_gradient_internal(
    model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    @info "Benchmarking gradient with mode: $(Meta.quot(mode))"

    return _run_ablation_suite(
        simple_mse_loss_gradient,
        "Gradient ($mode)",
        model,
        x,
        z,
        ps,
        st;
        disable_scatter_gather_bench,
        disable_pad_bench,
        optimization_passes=mode,
    )
end
235+
236+
"""
    pretty_print_table(results)

Print `results` as a formatted table with PrettyTables and return `nothing`.

`results` is indexed as a matrix with five columns: mode, optimization passes, mean time,
standard deviation, and relative timing. The input is copied, and columns 3 and 4 of the
copy are divided by `1e9` before printing (the header labels them in seconds). Rows with
a relative timing above `1.0` are highlighted bold red, rows below `1.0` bold green.
"""
function pretty_print_table(results)
    header = (
        ["Mode", "Optimization Passes", "Mean Time", "Std. Dev. Time", "Relative Timing"],
        ["", "", "s", "s", "Time / XLA Time"],
    )

    # Work on a copy so the caller's timings are left untouched.
    table = copy(results)
    table[:, 3] ./= 1e9
    table[:, 4] ./= 1e9

    slower = Highlighter((data, i, j) -> j == 5 && data[i, j] > 1.0, crayon"bold red")
    faster = Highlighter((data, i, j) -> j == 5 && data[i, j] < 1.0, crayon"bold green")

    # `pretty_table` prints the table itself and returns `nothing`; passing that result
    # to `display` would additionally display a stray `nothing`.
    pretty_table(
        table;
        header,
        header_crayon=crayon"yellow bold",
        highlighters=(slower, faster),
        tf=tf_unicode_rounded,
    )
    return nothing
end
259+
260+
"""
    save_benchmark_results(results::Matrix, tag; savedir, device_tag, plot_title="")

Write benchmark `results` to a CSV file and save a bar plot of the mean timings (with
mean ± standard deviation range bars) as a PDF. Returns `nothing`.

# Arguments
- `results`: matrix whose five columns are named "Mode", "Optimization Passes",
  "Mean Time", "Std. Dev. Time" and "Relative Timing" when converted to a `DataFrame`.
- `tag`: prefix of the output file names.

# Keywords
- `savedir`: output directory, created if missing. Defaults to a fresh
  `tempname(; cleanup=false)`.
- `device_tag`: middle part of the file names. Defaults to the lower-cased device kind
  of the first Reactant device, with spaces replaced by underscores.
- `plot_title`: title of the figure.

Both files are named `<tag>_<device_tag>_<timestamp>` with a `.csv` / `.pdf` extension.
The figure is also displayed when `Main.VSCodeServer` is defined.
"""
function save_benchmark_results(
    results::Matrix,
    tag;
    savedir=tempname(; cleanup=false),
    device_tag=lowercase(
        replace(Reactant.XLA.device_kind(Reactant.devices()[1]), " " => "_")
    ),
    plot_title="",
)
    in_vscode = isdefined(Main, :VSCodeServer)

    short_forms = Dict(
        "All" => "All",
        "Only XLA" => "Only XLA",
        "No Pad Optimizations" => "- Pad Opt",
        "No Scatter/Gather Optimizations" => "- S.G. Opt",
        "No Scatter/Gather/Pad Optimizations" => "- S.G. + Pad Opt",
        "No Scatter/Gather and Pad Optimizations" => "- S.G. + Pad Opt",
    )
    # Labels without a short form are shown unchanged rather than raising a `KeyError`
    # after the (expensive) benchmarks have already run.
    legend_label = name -> get(short_forms, name, name)

    mkpath(savedir)
    timestamp = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
    file_name_base = "$(tag)_$(device_tag)_$(timestamp)"

    column_names = [
        "Mode", "Optimization Passes", "Mean Time", "Std. Dev. Time", "Relative Timing"
    ]
    df = DataFrame(results, column_names)

    csv_results_file_name = joinpath(savedir, "$(file_name_base).csv")
    CSV.write(csv_results_file_name, df)

    @info "Saving timings to $(csv_results_file_name)"

    # Bounds for the range bars; added after the CSV is written so they stay out of it.
    df[!, "μ - σ"] = df[!, "Mean Time"] .- df[!, "Std. Dev. Time"]
    df[!, "μ + σ"] = df[!, "Mean Time"] .+ df[!, "Std. Dev. Time"]

    bars =
        data(df) *
        mapping(
            "Mode",
            "Mean Time";
            dodge="Optimization Passes" => "",
            color="Optimization Passes" => legend_label,
        ) *
        visual(BarPlot; strokewidth=2)

    range_bars =
        data(df) *
        mapping("Mode", "μ - σ", "μ + σ"; dodge_x="Optimization Passes" => "") *
        visual(Rangebars; linewidth=2, whiskerwidth=10)

    fig = draw(
        bars + range_bars,
        scales(; Color=(; palette=:tab10));
        figure=(; size=(1000, 500), title=plot_title, titlealign=:center),
        legend=(; position=:bottom),
    )

    in_vscode && display(fig)

    plots_file_name = joinpath(savedir, "$(file_name_base).pdf")
    save(plots_file_name, fig)

    @info "Saving plots to $(plots_file_name)"

    return nothing
end

perf/neuraloperators/Project.toml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
[deps]
AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
# `perf/common.jl` loads Statistics (for `mean` / `std`), so declare it explicitly.
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[sources]
Reactant = {path = "../../"}

[compat]
BenchmarkTools = "1.6"
CSV = "0.10.15"
Lux = "1.13.4"
NeuralOperators = "0.6"
PrettyTables = "2.4.0"
Random = "1.11"
julia = "1.11"
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)