Skip to content

Commit a93b09f — "perf: grad" (1 parent: c51a517)

File tree

4 files changed: +202 / −19 lines

perf/common.jl

Lines changed: 146 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
11
using BenchmarkTools: @benchmark
22
using Reactant, Enzyme, PrettyTables, Statistics
33

4-
function simple_mse_loss(model, x, ps, st)
4+
function simple_mse_loss(model, x, z, ps, st)
55
y, _ = Lux.apply(model, x, ps, st)
6-
return sum(abs2, y)
6+
return MSELoss()(y, z)
7+
end
8+
9+
function simple_mse_loss_gradient(model, x, z, ps, st)
10+
return Enzyme.gradient(
11+
Reverse, simple_mse_loss, Const(model), Const(x), Const(z), ps, Const(st)
12+
)
713
end
814

915
function benchmark_nn_primal(
10-
model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
16+
model, x, z, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
1117
)
1218
results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
1319

1420
# Only XLA
1521
compiled_fwd_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss(
16-
model, x, ps, st
22+
model, x, z, ps, st
1723
)
18-
bench = @benchmark $compiled_fwd_xla($model, $x, $ps, $st)
24+
bench = @benchmark $compiled_fwd_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
1925
push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
2026
baseline = median(bench).time
2127

2228
# Default
23-
compiled_fwd = @compile sync = true simple_mse_loss(model, x, ps, st)
24-
bench = @benchmark $compiled_fwd($model, $x, $ps, $st)
29+
compiled_fwd = @compile sync = true simple_mse_loss(model, x, z, ps, st)
30+
bench = @benchmark $compiled_fwd($model, $x, $z, $ps, $st) setup = (GC.gc(true))
2531
push!(
2632
results,
2733
(
@@ -37,8 +43,10 @@ function benchmark_nn_primal(
3743
if disable_scatter_gather_bench
3844
compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions(;
3945
disable_scatter_gather_optimization_passes=true
40-
) simple_mse_loss(model, x, ps, st)
41-
bench = @benchmark $compiled_fwd_no_scatter($model, $x, $ps, $st)
46+
) simple_mse_loss(model, x, z, ps, st)
47+
bench = @benchmark $compiled_fwd_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
48+
true
49+
))
4250

4351
push!(
4452
results,
@@ -56,8 +64,10 @@ function benchmark_nn_primal(
5664
if disable_pad_bench
5765
compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions(;
5866
disable_pad_optimization_passes=true
59-
) simple_mse_loss(model, x, ps, st)
60-
bench = @benchmark $compiled_fwd_no_pad($model, $x, $ps, $st)
67+
) simple_mse_loss(model, x, z, ps, st)
68+
bench = @benchmark $compiled_fwd_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
69+
true
70+
))
6171

6272
push!(
6373
results,
@@ -76,8 +86,10 @@ function benchmark_nn_primal(
7686
compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions(;
7787
disable_scatter_gather_optimization_passes=true,
7888
disable_pad_optimization_passes=true,
79-
) simple_mse_loss(model, x, ps, st)
80-
bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $ps, $st)
89+
) simple_mse_loss(model, x, z, ps, st)
90+
bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
91+
true
92+
))
8193

8294
push!(
8395
results,
@@ -95,6 +107,127 @@ function benchmark_nn_primal(
95107
return results
96108
end
97109

110+
function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
111+
return vcat(
112+
[
113+
benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...) for
114+
mode in [:all, :before_enzyme, :after_enzyme]
115+
]...,
116+
)
117+
end
118+
119+
function benchmark_nn_gradient_internal(
120+
model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
121+
)
122+
@info "Benchmarking gradient with mode: $(Meta.quot(mode))"
123+
124+
results = Vector{Tuple{String,String,Float64,Float64,Float64}}()
125+
126+
# Only XLA
127+
compiled_grad_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss_gradient(
128+
model, x, z, ps, st
129+
)
130+
bench = @benchmark $compiled_grad_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
131+
push!(
132+
results, ("Gradient ($mode)", "Only XLA", median(bench).time, std(bench).time, 1.0)
133+
)
134+
baseline = median(bench).time
135+
136+
display(results[end])
137+
138+
# Default
139+
compiled_grad = @compile sync = true optimize = mode simple_mse_loss_gradient(
140+
model, x, z, ps, st
141+
)
142+
bench = @benchmark $compiled_grad($model, $x, $z, $ps, $st) setup = (GC.gc(true))
143+
push!(
144+
results,
145+
(
146+
"Gradient ($mode)",
147+
"All",
148+
median(bench).time,
149+
std(bench).time,
150+
median(bench).time / baseline,
151+
),
152+
)
153+
154+
display(results[end])
155+
156+
# Disable Scatter
157+
if disable_scatter_gather_bench
158+
compiled_grad_no_scatter = @compile sync = true compile_options = CompileOptions(;
159+
disable_scatter_gather_optimization_passes=true, optimization_passes=mode
160+
) simple_mse_loss_gradient(model, x, z, ps, st)
161+
bench = @benchmark $compiled_grad_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
162+
true
163+
))
164+
165+
push!(
166+
results,
167+
(
168+
"Gradient ($mode)",
169+
"No Scatter/Gather Optimizations",
170+
median(bench).time,
171+
std(bench).time,
172+
median(bench).time / baseline,
173+
),
174+
)
175+
176+
display(results[end])
177+
end
178+
179+
# Disable Pad
180+
if disable_pad_bench
181+
compiled_grad_no_pad = @compile sync = true compile_options = CompileOptions(;
182+
disable_pad_optimization_passes=true, optimization_passes=mode
183+
) simple_mse_loss_gradient(model, x, z, ps, st)
184+
bench = @benchmark $compiled_grad_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
185+
true
186+
))
187+
188+
push!(
189+
results,
190+
(
191+
"Gradient ($mode)",
192+
"No Pad Optimizations",
193+
median(bench).time,
194+
std(bench).time,
195+
median(bench).time / baseline,
196+
),
197+
)
198+
199+
display(results[end])
200+
end
201+
202+
# Disable Pad and Scatter
203+
if disable_scatter_gather_bench && disable_pad_bench
204+
compiled_grad_no_scatter_no_pad = @compile sync = true compile_options = CompileOptions(;
205+
disable_scatter_gather_optimization_passes=true,
206+
disable_pad_optimization_passes=true,
207+
optimization_passes=mode,
208+
) simple_mse_loss_gradient(model, x, z, ps, st)
209+
bench = @benchmark $compiled_grad_no_scatter_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
210+
true
211+
))
212+
213+
push!(
214+
results,
215+
(
216+
"Gradient ($mode)",
217+
"No Scatter/Gather/Pad Optimizations",
218+
median(bench).time,
219+
std(bench).time,
220+
median(bench).time / baseline,
221+
),
222+
)
223+
224+
display(results[end])
225+
end
226+
227+
sort!(results; by=x -> x[3])
228+
return results
229+
end
230+
98231
function pretty_print_table(results)
99232
header = (
100233
["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],

perf/neuraloperators/main.jl

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ function run_deeponet_benchmarks()
88
@info "Running DeepONet benchmarks"
99

1010
model = DeepONet(;
11-
branch=(64, ntuple(Returns(256), 5)..., 16),
12-
trunk=(1, ntuple(Returns(256), 5)..., 16),
11+
branch=(64, ntuple(Returns(256), 4)..., 16),
12+
trunk=(1, ntuple(Returns(256), 4)..., 16),
1313
branch_activation=gelu,
1414
trunk_activation=gelu,
1515
)
1616
ps, st = xdev(Lux.setup(Random.default_rng(), model))
1717
u = xdev(rand(Float32, 64, 1024))
1818
y = xdev(rand(Float32, 1, 128))
19+
z = xdev(rand(Float32, 128, 1024))
1920

2021
primal_timings = Reactant.with_config(;
2122
dot_general_precision=PrecisionConfig.HIGH,
@@ -24,14 +25,31 @@ function run_deeponet_benchmarks()
2425
benchmark_nn_primal(
2526
model,
2627
(u, y),
28+
z,
2729
ps,
2830
st;
2931
disable_scatter_gather_bench=true,
3032
disable_pad_bench=true,
3133
)
3234
end
3335

34-
pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
36+
gradient_timings = Reactant.with_config(;
37+
dot_general_precision=PrecisionConfig.HIGH,
38+
convolution_precision=PrecisionConfig.HIGH,
39+
) do
40+
benchmark_nn_gradient(
41+
model,
42+
(u, y),
43+
z,
44+
ps,
45+
st;
46+
disable_scatter_gather_bench=true,
47+
disable_pad_bench=true,
48+
)
49+
end
50+
51+
timings = vcat(primal_timings, gradient_timings)
52+
pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))
3553

3654
return nothing
3755
end
@@ -42,20 +60,43 @@ function run_fno_benchmarks()
4260
model = FourierNeuralOperator((16, 16), 3, 8, 64)
4361
ps, st = xdev(Lux.setup(Random.default_rng(), model))
4462
x = xdev(rand(Float32, 64, 64, 1, 256))
63+
z = xdev(rand(Float32, 64, 64, 8, 256))
4564

4665
primal_timings = Reactant.with_config(;
4766
dot_general_precision=PrecisionConfig.HIGH,
4867
convolution_precision=PrecisionConfig.HIGH,
4968
) do
5069
benchmark_nn_primal(
51-
model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
70+
model,
71+
x,
72+
z,
73+
ps,
74+
st;
75+
disable_scatter_gather_bench=true,
76+
disable_pad_bench=true,
5277
)
5378
end
5479

55-
pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
80+
gradient_timings = Reactant.with_config(;
81+
dot_general_precision=PrecisionConfig.HIGH,
82+
convolution_precision=PrecisionConfig.HIGH,
83+
) do
84+
benchmark_nn_gradient(
85+
model,
86+
x,
87+
z,
88+
ps,
89+
st;
90+
disable_scatter_gather_bench=true,
91+
disable_pad_bench=true,
92+
)
93+
end
94+
95+
timings = vcat(primal_timings, gradient_timings)
96+
pretty_print_table(permutedims(hcat([[t...] for t in timings]...), (2, 1)))
5697

5798
return nothing
5899
end
59100

60-
run_fno_benchmarks()
61101
run_deeponet_benchmarks()
102+
run_fno_benchmarks()

src/CompileOptions.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ Fine-grained control over the compilation options for the Reactant compiler.
138138
- `assert_nonallocating`: If `true`, we make sure that no new buffers are
139139
returned by the function. Any buffer returned must be donated from the inputs. Defaults
140140
to `false`.
141+
- `sync`: Reactant computations are asynchronous by default. If `true`, the computation
142+
will be executed synchronously, blocking till the computation is complete. This is
143+
recommended when benchmarking.
141144
142145
# Extended Help
143146
@@ -175,6 +178,7 @@ struct CompileOptions
175178
# julia codegen options
176179
assert_nonallocating::Bool
177180
donated_args::Symbol
181+
sync::Bool
178182
## private options for ablation studies
179183
disable_scatter_gather_optimization_passes::Bool
180184
disable_pad_optimization_passes::Bool
@@ -197,6 +201,7 @@ function CompileOptions(;
197201
optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
198202
assert_nonallocating::Bool=false,
199203
donated_args::Symbol=:auto,
204+
sync::Bool=false,
200205
disable_scatter_gather_optimization_passes::Bool=false,
201206
disable_pad_optimization_passes::Bool=false,
202207
)
@@ -243,6 +248,7 @@ function CompileOptions(;
243248
optimize_communications,
244249
assert_nonallocating,
245250
donated_args,
251+
sync,
246252
disable_scatter_gather_optimization_passes,
247253
disable_pad_optimization_passes,
248254
)
@@ -282,6 +288,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
282288
compile_options.optimize_communications,
283289
compile_options.assert_nonallocating,
284290
compile_options.donated_args,
291+
compile_options.sync,
285292
compile_options.disable_scatter_gather_optimization_passes,
286293
compile_options.disable_pad_optimization_passes,
287294
)

src/Compiler.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,7 @@ function __get_compile_options_and_kwargs(;
12601260
optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
12611261
assert_nonallocating::Bool=false,
12621262
donated_args::Symbol=:auto,
1263+
sync::Bool=false,
12631264
kwargs...,
12641265
)
12651266
return (
@@ -1281,6 +1282,7 @@ function __get_compile_options_and_kwargs(;
12811282
optimize_communications,
12821283
assert_nonallocating,
12831284
donated_args,
1285+
sync,
12841286
),
12851287
kwargs,
12861288
)

0 commit comments

Comments (0)