Skip to content

Commit 79e302e

Browse files
committed
perf: run ablations for the paper
1 parent 23a8d1a commit 79e302e

File tree

5 files changed

+180
-9
lines changed

5 files changed

+180
-9
lines changed

perf/common.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
using BenchmarkTools: @belapsed
2+
using Reactant, Enzyme, PrettyTables, Statistics
3+
4+
function simple_mse_loss(model, x, ps, st)
5+
y, _ = Lux.apply(model, x, ps, st)
6+
return sum(abs2, y)
7+
end
8+
9+
function benchmark_nn_primal(
10+
model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
11+
)
12+
results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
13+
14+
# Only XLA
15+
compiled_fwd_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss(
16+
model, x, ps, st
17+
)
18+
bench = @benchmark $compiled_fwd_xla($model, $x, $ps, $st)
19+
push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
20+
baseline = median(bench).time
21+
22+
# Default
23+
compiled_fwd = @compile sync = true simple_mse_loss(model, x, ps, st)
24+
bench = @benchmark $compiled_fwd($model, $x, $ps, $st)
25+
push!(
26+
results,
27+
(
28+
"Primal",
29+
"All",
30+
median(bench).time,
31+
std(bench).time,
32+
median(bench).time / baseline,
33+
),
34+
)
35+
36+
# Disable Scatter
37+
if disable_scatter_gather_bench
38+
compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions(;
39+
disable_scatter_gather_optimization_passes=true
40+
) simple_mse_loss(model, x, ps, st)
41+
bench = @benchmark $compiled_fwd_no_scatter($model, $x, $ps, $st)
42+
43+
push!(
44+
results,
45+
(
46+
"Primal",
47+
"No Scatter/Gather Optimizations",
48+
median(bench).time,
49+
std(bench).time,
50+
median(bench).time / baseline,
51+
),
52+
)
53+
end
54+
55+
# Disable Pad
56+
if disable_pad_bench
57+
compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions(;
58+
disable_pad_optimization_passes=true
59+
) simple_mse_loss(model, x, ps, st)
60+
bench = @benchmark $compiled_fwd_no_pad($model, $x, $ps, $st)
61+
62+
push!(
63+
results,
64+
(
65+
"Primal",
66+
"No Pad Optimizations",
67+
median(bench).time,
68+
std(bench).time,
69+
median(bench).time / baseline,
70+
),
71+
)
72+
end
73+
74+
# Disable Scatter and Pad
75+
if disable_scatter_gather_bench && disable_pad_bench
76+
compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions(;
77+
disable_scatter_gather_optimization_passes=true,
78+
disable_pad_optimization_passes=true,
79+
) simple_mse_loss(model, x, ps, st)
80+
bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $ps, $st)
81+
82+
push!(
83+
results,
84+
(
85+
"Primal",
86+
"No Scatter/Gather and Pad Optimizations",
87+
median(bench).time,
88+
std(bench).time,
89+
median(bench).time / baseline,
90+
),
91+
)
92+
end
93+
94+
sort!(results, by=x -> x[3])
95+
return results
96+
end
97+
98+
function pretty_print_table(results)
99+
header = (
100+
["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],
101+
["", "", "s", "s", "Time / XLA Time"],
102+
)
103+
104+
results = copy(results)
105+
results[:, 3] ./= 1e9
106+
results[:, 4] ./= 1e9
107+
108+
hl_r = Highlighter((data, i, j) -> j == 5 && data[i, j] > 1.0, crayon"bold red")
109+
hl_g = Highlighter((data, i, j) -> j == 5 && data[i, j] < 1.0, crayon"bold green")
110+
display(
111+
pretty_table(
112+
results;
113+
header,
114+
header_crayon=crayon"yellow bold",
115+
highlighters=(hl_r, hl_g),
116+
tf=tf_unicode_rounded,
117+
),
118+
)
119+
return nothing
120+
end

perf/neuraloperators/Project.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
4+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
6+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
7+
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
10+
11+
[sources]
12+
Reactant = {path = "../../"}
13+
14+
[compat]
15+
BenchmarkTools = "1.6"
16+
CSV = "0.10.15"
17+
Lux = "1.13.4"
18+
NeuralOperators = "0.6"
19+
PrettyTables = "2.4.0"
20+
Random = "1.11"
21+
julia = "1.11"

perf/neuraloperators/main.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using NeuralOperators, Lux, Random
2+
3+
include("../common.jl")
4+
5+
const xdev = reactant_device()
6+
7+
function run_deeponet_benchmarks()
8+
@info "Running DeepONet benchmarks"
9+
10+
model = DeepONet(;
11+
branch=(64, ntuple(Returns(256), 5)..., 16),
12+
trunk=(1, ntuple(Returns(256), 5)..., 16),
13+
branch_activation=gelu,
14+
trunk_activation=gelu,
15+
)
16+
ps, st = xdev(Lux.setup(Random.default_rng(), model))
17+
u = xdev(rand(Float32, 64, 1024))
18+
y = xdev(rand(Float32, 1, 128))
19+
20+
primal_timings = Reactant.with_config(;
21+
dot_general_precision=PrecisionConfig.HIGH,
22+
convolution_precision=PrecisionConfig.HIGH,
23+
) do
24+
benchmark_nn_primal(
25+
model,
26+
(u, y),
27+
ps,
28+
st;
29+
disable_scatter_gather_bench=true,
30+
disable_pad_bench=true,
31+
)
32+
end
33+
34+
pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
35+
36+
return nothing
37+
end
38+
39+
run_deeponet_benchmarks()

src/CompileOptions.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,6 @@ Fine-grained control over the compilation options for the Reactant compiler.
138138
- `assert_nonallocating`: If `true`, we make sure that no new buffers are
139139
returned by the function. Any buffer returned must be donated from the inputs. Defaults
140140
to `false`.
141-
- `sync`: Reactant computations are asynchronous by default. If `true`, the computation
142-
will be executed synchronously, blocking till the computation is complete. This is
143-
recommended when benchmarking.
144141
145142
# Extended Help
146143
@@ -178,7 +175,6 @@ struct CompileOptions
178175
# julia codegen options
179176
assert_nonallocating::Bool
180177
donated_args::Symbol
181-
sync::Bool
182178
## private options for ablation studies
183179
disable_scatter_gather_optimization_passes::Bool
184180
disable_pad_optimization_passes::Bool
@@ -201,7 +197,6 @@ function CompileOptions(;
201197
optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
202198
assert_nonallocating::Bool=false,
203199
donated_args::Symbol=:auto,
204-
sync::Bool=false,
205200
disable_scatter_gather_optimization_passes::Bool=false,
206201
disable_pad_optimization_passes::Bool=false,
207202
)
@@ -248,7 +243,6 @@ function CompileOptions(;
248243
optimize_communications,
249244
assert_nonallocating,
250245
donated_args,
251-
sync,
252246
disable_scatter_gather_optimization_passes,
253247
disable_pad_optimization_passes,
254248
)
@@ -288,7 +282,6 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
288282
compile_options.optimize_communications,
289283
compile_options.assert_nonallocating,
290284
compile_options.donated_args,
291-
compile_options.sync,
292285
compile_options.disable_scatter_gather_optimization_passes,
293286
compile_options.disable_pad_optimization_passes,
294287
)

src/Compiler.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,6 @@ function __get_compile_options_and_kwargs(;
12601260
optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
12611261
assert_nonallocating::Bool=false,
12621262
donated_args::Symbol=:auto,
1263-
sync::Bool=false,
12641263
kwargs...,
12651264
)
12661265
return (
@@ -1282,7 +1281,6 @@ function __get_compile_options_and_kwargs(;
12821281
optimize_communications,
12831282
assert_nonallocating,
12841283
donated_args,
1285-
sync,
12861284
),
12871285
kwargs,
12881286
)

0 commit comments

Comments
 (0)