1
1
using BenchmarkTools: @benchmark
2
2
using Reactant, Enzyme, PrettyTables, Statistics
3
3
4
"""
    simple_mse_loss(model, x, z, ps, st)

Run `model` on `x` with parameters `ps` and state `st` through `Lux.apply`, then
return the loss that `MSELoss()` reports between the model output and the target `z`.
The updated state returned by `Lux.apply` is discarded.
"""
function simple_mse_loss(model, x, z, ps, st)
    # `Lux.apply` returns a pair; only the first element (the output) is needed here.
    prediction, _ = Lux.apply(model, x, ps, st)
    loss_fn = MSELoss()
    return loss_fn(prediction, z)
end
8
+
9
"""
    simple_mse_loss_gradient(model, x, z, ps, st)

Differentiate `simple_mse_loss` in reverse mode with `Enzyme.gradient`.

`model`, `x`, `z` and `st` are wrapped in `Const`; `ps` is the only argument passed
unwrapped. The value returned is whatever `Enzyme.gradient` returns for that call.
"""
function simple_mse_loss_gradient(model, x, z, ps, st)
    # Everything except the parameters is marked constant for Enzyme.
    const_model = Const(model)
    const_x = Const(x)
    const_z = Const(z)
    const_st = Const(st)
    return Enzyme.gradient(
        Reverse, simple_mse_loss, const_model, const_x, const_z, ps, const_st
    )
end
8
14
9
15
function benchmark_nn_primal (
10
- model, x, ps, st; disable_scatter_gather_bench= true , disable_pad_bench= true
16
+ model, x, z, ps, st; disable_scatter_gather_bench= true , disable_pad_bench= true
11
17
)
12
18
results = Vector {Tuple{String,String,Float64,Float64,Float64}} ()
13
19
14
20
# Only XLA
15
21
compiled_fwd_xla = @compile sync = true compile_options = Reactant. DefaultXLACompileOptions () simple_mse_loss (
16
- model, x, ps, st
22
+ model, x, z, ps, st
17
23
)
18
- bench = @benchmark $ compiled_fwd_xla ($ model, $ x, $ ps, $ st)
24
+ bench = @benchmark $ compiled_fwd_xla ($ model, $ x, $ z, $ ps, $ st) setup = (GC . gc ( true ) )
19
25
push! (results, (" Primal" , " Only XLA" , median (bench). time, std (bench). time, 1.0 ))
20
26
baseline = median (bench). time
21
27
22
28
# Default
23
- compiled_fwd = @compile sync = true simple_mse_loss (model, x, ps, st)
24
- bench = @benchmark $ compiled_fwd ($ model, $ x, $ ps, $ st)
29
+ compiled_fwd = @compile sync = true simple_mse_loss (model, x, z, ps, st)
30
+ bench = @benchmark $ compiled_fwd ($ model, $ x, $ z, $ ps, $ st) setup = (GC . gc ( true ) )
25
31
push! (
26
32
results,
27
33
(
@@ -37,8 +43,10 @@ function benchmark_nn_primal(
37
43
if disable_scatter_gather_bench
38
44
compiled_fwd_no_scatter = @compile sync = true compile_options = CompileOptions (;
39
45
disable_scatter_gather_optimization_passes= true
40
- ) simple_mse_loss (model, x, ps, st)
41
- bench = @benchmark $ compiled_fwd_no_scatter ($ model, $ x, $ ps, $ st)
46
+ ) simple_mse_loss (model, x, z, ps, st)
47
+ bench = @benchmark $ compiled_fwd_no_scatter ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
48
+ true
49
+ ))
42
50
43
51
push! (
44
52
results,
@@ -56,8 +64,10 @@ function benchmark_nn_primal(
56
64
if disable_pad_bench
57
65
compiled_fwd_no_pad = @compile sync = true compile_options = CompileOptions (;
58
66
disable_pad_optimization_passes= true
59
- ) simple_mse_loss (model, x, ps, st)
60
- bench = @benchmark $ compiled_fwd_no_pad ($ model, $ x, $ ps, $ st)
67
+ ) simple_mse_loss (model, x, z, ps, st)
68
+ bench = @benchmark $ compiled_fwd_no_pad ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
69
+ true
70
+ ))
61
71
62
72
push! (
63
73
results,
@@ -76,8 +86,10 @@ function benchmark_nn_primal(
76
86
compiled_fwd_no_scatter_pad = @compile sync = true compile_options = CompileOptions (;
77
87
disable_scatter_gather_optimization_passes= true ,
78
88
disable_pad_optimization_passes= true ,
79
- ) simple_mse_loss (model, x, ps, st)
80
- bench = @benchmark $ compiled_fwd_no_scatter_pad ($ model, $ x, $ ps, $ st)
89
+ ) simple_mse_loss (model, x, z, ps, st)
90
+ bench = @benchmark $ compiled_fwd_no_scatter_pad ($ model, $ x, $ z, $ ps, $ st) setup = (GC. gc (
91
+ true
92
+ ))
81
93
82
94
push! (
83
95
results,
@@ -95,6 +107,127 @@ function benchmark_nn_primal(
95
107
return results
96
108
end
97
109
110
"""
    benchmark_nn_gradient(model, x, z, ps, st; kwargs...)

Run `benchmark_nn_gradient_internal` once for each of the modes `:all`,
`:before_enzyme` and `:after_enzyme` (in that order) and concatenate the per-mode
result vectors with `vcat`. All keyword arguments are forwarded unchanged.
"""
function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
    modes = (:all, :before_enzyme, :after_enzyme)
    per_mode_results = map(modes) do mode
        benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...)
    end
    return vcat(per_mode_results...)
end
118
+
119
# Benchmark one compiled gradient function, append its timing row to `results`,
# display that row, and return the median time. When `baseline` is `nothing` the
# row is treated as the baseline itself and its relative timing is recorded as 1.0.
function _record_gradient_benchmark!(
    results, compiled, label, mode, baseline, model, x, z, ps, st
)
    bench = @benchmark $compiled($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    median_time = median(bench).time
    relative = baseline === nothing ? 1.0 : median_time / baseline
    push!(results, ("Gradient ($mode)", label, median_time, std(bench).time, relative))
    display(results[end])
    return median_time
end

"""
    benchmark_nn_gradient_internal(
        model, x, z, ps, st, mode;
        disable_scatter_gather_bench=true, disable_pad_bench=true,
    )

Compile `simple_mse_loss_gradient` under several compile configurations, benchmark
each compiled function, and collect one timing row per configuration.

# Arguments
- `model, x, z, ps, st`: forwarded unchanged to `simple_mse_loss_gradient`.
- `mode`: passed as `optimize = mode` for the default compilation and as
  `optimization_passes = mode` for the `CompileOptions`-based variants; it is also
  interpolated into the first column of every row.

# Keywords
- `disable_scatter_gather_bench`: when `true`, also benchmark a compilation with
  `disable_scatter_gather_optimization_passes = true`.
- `disable_pad_bench`: when `true`, also benchmark a compilation with
  `disable_pad_optimization_passes = true`.

When both keywords are `true`, a variant with both pass groups disabled is
benchmarked as well.

# Returns
A `Vector{Tuple{String,String,Float64,Float64,Float64}}` whose rows are
`(mode label, configuration label, median time, std of time, median time relative to
the "Only XLA" run)`, sorted in ascending order of median time.
"""
function benchmark_nn_gradient_internal(
    model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    @info "Benchmarking gradient with mode: $(Meta.quot(mode))"

    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()

    # Only XLA: this run defines the baseline every other row is compared against.
    compiled_grad_xla = @compile sync = true compile_options = Reactant.DefaultXLACompileOptions() simple_mse_loss_gradient(
        model, x, z, ps, st
    )
    baseline = _record_gradient_benchmark!(
        results, compiled_grad_xla, "Only XLA", mode, nothing, model, x, z, ps, st
    )

    # Default
    compiled_grad = @compile sync = true optimize = mode simple_mse_loss_gradient(
        model, x, z, ps, st
    )
    _record_gradient_benchmark!(
        results, compiled_grad, "All", mode, baseline, model, x, z, ps, st
    )

    # Disable Scatter
    if disable_scatter_gather_bench
        compiled_grad_no_scatter = @compile sync = true compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true, optimization_passes=mode
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        _record_gradient_benchmark!(
            results,
            compiled_grad_no_scatter,
            "No Scatter/Gather Optimizations",
            mode,
            baseline,
            model,
            x,
            z,
            ps,
            st,
        )
    end

    # Disable Pad
    if disable_pad_bench
        compiled_grad_no_pad = @compile sync = true compile_options = CompileOptions(;
            disable_pad_optimization_passes=true, optimization_passes=mode
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        _record_gradient_benchmark!(
            results,
            compiled_grad_no_pad,
            "No Pad Optimizations",
            mode,
            baseline,
            model,
            x,
            z,
            ps,
            st,
        )
    end

    # Disable Pad and Scatter
    if disable_scatter_gather_bench && disable_pad_bench
        compiled_grad_no_scatter_no_pad = @compile sync = true compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            disable_pad_optimization_passes=true,
            optimization_passes=mode,
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        _record_gradient_benchmark!(
            results,
            compiled_grad_no_scatter_no_pad,
            "No Scatter/Gather/Pad Optimizations",
            mode,
            baseline,
            model,
            x,
            z,
            ps,
            st,
        )
    end

    # Fastest configuration first (third tuple slot is the median time).
    sort!(results; by=row -> row[3])
    return results
end
230
+
98
231
function pretty_print_table (results)
99
232
header = (
100
233
[" Mode" , " Optimization Passes" , " Median Time" , " Std. Dev. Time" , " Relative Timing" ],
0 commit comments