Skip to content

Commit 1df5621

Browse files
authored
perf: (Enzyme) remove tangent conversion, use native gradient when possible (#730)
* perf: remove tangent conversion with Enzyme, use native gradient when possible
* Drop tests
* bump
1 parent 3417ff5 commit 1df5621

File tree

8 files changed

+136
-172
lines changed

8 files changed

+136
-172
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.41"
4+
version = "0.6.42"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@ function DI.value_and_pushforward(
2020
) where {F,C}
2121
mode = forward_withprimal(backend)
2222
f_and_df = get_f_and_df(f, backend, mode)
23-
dx_sametype = convert(typeof(x), only(tx))
24-
x_and_dx = Duplicated(x, dx_sametype)
23+
dx = only(tx)
24+
x_and_dx = Duplicated(x, dx)
2525
annotated_contexts = translate(backend, mode, Val(1), contexts...)
2626
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
2727
return y, (dy,)
@@ -37,8 +37,7 @@ function DI.value_and_pushforward(
3737
) where {F,B,C}
3838
mode = forward_withprimal(backend)
3939
f_and_df = get_f_and_df(f, backend, mode, Val(B))
40-
tx_sametype = map(Fix1(convert, typeof(x)), tx)
41-
x_and_tx = BatchDuplicated(x, tx_sametype)
40+
x_and_tx = BatchDuplicated(x, tx)
4241
annotated_contexts = translate(backend, mode, Val(B), contexts...)
4342
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
4443
return y, values(ty)
@@ -54,8 +53,8 @@ function DI.pushforward(
5453
) where {F,C}
5554
mode = forward_noprimal(backend)
5655
f_and_df = get_f_and_df(f, backend, mode)
57-
dx_sametype = convert(typeof(x), only(tx))
58-
x_and_dx = Duplicated(x, dx_sametype)
56+
dx = only(tx)
57+
x_and_dx = Duplicated(x, dx)
5958
annotated_contexts = translate(backend, mode, Val(1), contexts...)
6059
dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...))
6160
return (dy,)
@@ -71,8 +70,7 @@ function DI.pushforward(
7170
) where {F,B,C}
7271
mode = forward_noprimal(backend)
7372
f_and_df = get_f_and_df(f, backend, mode, Val(B))
74-
tx_sametype = map(Fix1(convert, typeof(x)), tx)
75-
x_and_tx = BatchDuplicated(x, tx_sametype)
73+
x_and_tx = BatchDuplicated(x, tx)
7674
annotated_contexts = translate(backend, mode, Val(B), contexts...)
7775
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
7876
return values(ty)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -22,13 +22,13 @@ function DI.value_and_pushforward(
2222
) where {F,C}
2323
mode = forward_noprimal(backend)
2424
f!_and_df! = get_f_and_df(f!, backend, mode)
25-
dx_sametype = convert(typeof(x), only(tx))
26-
dy_sametype = make_zero(y)
27-
x_and_dx = Duplicated(x, dx_sametype)
28-
y_and_dy = Duplicated(y, dy_sametype)
25+
dx = only(tx)
26+
dy = make_zero(y)
27+
x_and_dx = Duplicated(x, dx)
28+
y_and_dy = Duplicated(y, dy)
2929
annotated_contexts = translate(backend, mode, Val(1), contexts...)
3030
autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...)
31-
return y, (dy_sametype,)
31+
return y, (dy,)
3232
end
3333

3434
function DI.value_and_pushforward(
@@ -42,13 +42,12 @@ function DI.value_and_pushforward(
4242
) where {F,B,C}
4343
mode = forward_noprimal(backend)
4444
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
45-
tx_sametype = map(Fix1(convert, typeof(x)), tx)
46-
ty_sametype = ntuple(_ -> make_zero(y), Val(B))
47-
x_and_tx = BatchDuplicated(x, tx_sametype)
48-
y_and_ty = BatchDuplicated(y, ty_sametype)
45+
ty = ntuple(_ -> make_zero(y), Val(B))
46+
x_and_tx = BatchDuplicated(x, tx)
47+
y_and_ty = BatchDuplicated(y, ty)
4948
annotated_contexts = translate(backend, mode, Val(B), contexts...)
5049
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
51-
return y, ty_sametype
50+
return y, ty
5251
end
5352

5453
function DI.pushforward(
@@ -76,13 +75,10 @@ function DI.value_and_pushforward!(
7675
) where {F,B,C}
7776
mode = forward_noprimal(backend)
7877
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
79-
tx_sametype = map(Fix1(convert, typeof(x)), tx)
80-
ty_sametype = map(Fix1(convert, typeof(y)), ty)
81-
x_and_tx = BatchDuplicated(x, tx_sametype)
82-
y_and_ty = BatchDuplicated(y, ty_sametype)
78+
x_and_tx = BatchDuplicated(x, tx)
79+
y_and_ty = BatchDuplicated(y, ty)
8380
annotated_contexts = translate(backend, mode, Val(B), contexts...)
8481
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
85-
foreach(copyto_if_different_addresses!, ty, ty_sametype)
8682
return y, ty
8783
end
8884

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 78 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,7 @@ function seeded_autodiff_thunk(
88
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
99
tape, result, shadow_result = forward(f, args...)
1010
if RA <: Active
11-
dresult_righttype = convert(typeof(result), dresult)
12-
dinputs = only(reverse(f, args..., dresult_righttype, tape))
11+
dinputs = only(reverse(f, args..., dresult, tape))
1312
else
1413
shadow_result .+= dresult # TODO: generalize beyond arrays
1514
dinputs = only(reverse(f, args..., tape))
@@ -32,8 +31,7 @@ function batch_seeded_autodiff_thunk(
3231
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
3332
tape, result, shadow_results = forward(f, args...)
3433
if RA <: Active
35-
dresults_righttype = map(Fix1(convert, typeof(result)), dresults)
36-
dinputs = only(reverse(f, args..., dresults_righttype, tape))
34+
dinputs = only(reverse(f, args..., dresults, tape))
3735
else
3836
foreach(shadow_results, dresults) do d0, d
3937
d0 .+= d # use recursive_add here?
@@ -141,13 +139,12 @@ function DI.value_and_pullback!(
141139
mode = reverse_split_withprimal(backend)
142140
f_and_df = force_annotation(get_f_and_df(f, backend, mode))
143141
RA = guess_activity(typeof(prep.y_example), mode)
144-
dx_righttype = convert(typeof(x), only(tx))
145-
make_zero!(dx_righttype)
142+
dx = only(tx)
143+
make_zero!(dx)
146144
annotated_contexts = translate(backend, mode, Val(1), contexts...)
147145
_, result = seeded_autodiff_thunk(
148-
mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts...
146+
mode, only(ty), f_and_df, RA, Duplicated(x, dx), annotated_contexts...
149147
)
150-
copyto_if_different_addresses!(only(tx), dx_righttype)
151148
return result, tx
152149
end
153150

@@ -163,13 +160,11 @@ function DI.value_and_pullback!(
163160
mode = reverse_split_withprimal(backend)
164161
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))
165162
RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B))
166-
tx_righttype = map(Fix1(convert, typeof(x)), tx)
167-
make_zero!(tx_righttype)
163+
make_zero!(tx)
168164
annotated_contexts = translate(backend, mode, Val(B), contexts...)
169165
_, result = batch_seeded_autodiff_thunk(
170-
mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts...
166+
mode, ty, f_and_df, RA, BatchDuplicated(x, tx), annotated_contexts...
171167
)
172-
foreach(copyto_if_different_addresses!, tx, tx_righttype)
173168
return result, tx
174169
end
175170

@@ -187,10 +182,73 @@ end
187182

188183
## Gradient
189184

190-
### Without preparation
185+
function DI.prepare_gradient(
186+
f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}
187+
) where {F,C}
188+
return DI.NoGradientPrep()
189+
end
190+
191+
### Enzyme gradient API (only constants)
192+
193+
function DI.gradient(
194+
f::F,
195+
::DI.NoGradientPrep,
196+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
197+
x,
198+
contexts::Vararg{DI.Constant,C},
199+
) where {F,C}
200+
mode = reverse_noprimal(backend)
201+
f_and_df = get_f_and_df(f, backend, mode)
202+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
203+
grads = gradient(mode, f_and_df, x, annotated_contexts...)
204+
return first(grads)
205+
end
206+
207+
function DI.value_and_gradient(
208+
f::F,
209+
::DI.NoGradientPrep,
210+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
211+
x,
212+
contexts::Vararg{DI.Constant,C},
213+
) where {F,C}
214+
mode = reverse_withprimal(backend)
215+
f_and_df = get_f_and_df(f, backend, mode)
216+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
217+
grads, result = gradient(mode, f_and_df, x, annotated_contexts...)
218+
return result, first(grads)
219+
end
220+
221+
function DI.gradient!(
222+
f::F,
223+
grad,
224+
::DI.NoGradientPrep,
225+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
226+
x,
227+
) where {F}
228+
mode = reverse_noprimal(backend)
229+
f_and_df = get_f_and_df(f, backend, mode)
230+
gradient!(mode, grad, f_and_df, x)
231+
return grad
232+
end
233+
234+
function DI.value_and_gradient!(
235+
f::F,
236+
grad,
237+
::DI.NoGradientPrep,
238+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
239+
x,
240+
) where {F}
241+
mode = reverse_withprimal(backend)
242+
f_and_df = get_f_and_df(f, backend, mode)
243+
_, result = gradient!(mode, grad, f_and_df, x)
244+
return result, grad
245+
end
246+
247+
### Generic
191248

192249
function DI.gradient(
193250
f::F,
251+
::DI.NoGradientPrep,
194252
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
195253
x,
196254
contexts::Vararg{DI.Context,C},
@@ -213,6 +271,7 @@ end
213271

214272
function DI.value_and_gradient(
215273
f::F,
274+
::DI.NoGradientPrep,
216275
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
217276
x,
218277
contexts::Vararg{DI.Context,C},
@@ -233,73 +292,34 @@ function DI.value_and_gradient(
233292
end
234293
end
235294

236-
### With preparation
237-
238-
struct EnzymeGradientPrep{G} <: DI.GradientPrep
239-
grad_righttype::G
240-
end
241-
242-
function DI.prepare_gradient(
243-
f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}
244-
) where {F,C}
245-
grad_righttype = make_zero(x)
246-
return EnzymeGradientPrep(grad_righttype)
247-
end
248-
249-
function DI.gradient(
250-
f::F,
251-
::EnzymeGradientPrep,
252-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
253-
x,
254-
contexts::Vararg{DI.Context,C},
255-
) where {F,C}
256-
return DI.gradient(f, backend, x, contexts...)
257-
end
258-
259295
function DI.gradient!(
260296
f::F,
261297
grad,
262-
prep::EnzymeGradientPrep,
298+
::DI.NoGradientPrep,
263299
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
264300
x,
265301
contexts::Vararg{DI.Context,C},
266302
) where {F,C}
267303
mode = reverse_noprimal(backend)
268304
f_and_df = get_f_and_df(f, backend, mode)
269-
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
270-
make_zero!(grad_righttype)
271305
annotated_contexts = translate(backend, mode, Val(1), contexts...)
272-
autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...)
273-
copyto_if_different_addresses!(grad, grad_righttype)
306+
make_zero!(grad)
307+
autodiff(mode, f_and_df, Active, Duplicated(x, grad), annotated_contexts...)
274308
return grad
275309
end
276310

277-
function DI.value_and_gradient(
278-
f::F,
279-
::EnzymeGradientPrep,
280-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
281-
x,
282-
contexts::Vararg{DI.Context,C},
283-
) where {F,C}
284-
return DI.value_and_gradient(f, backend, x, contexts...)
285-
end
286-
287311
function DI.value_and_gradient!(
288312
f::F,
289313
grad,
290-
prep::EnzymeGradientPrep,
314+
::DI.NoGradientPrep,
291315
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
292316
x,
293317
contexts::Vararg{DI.Context,C},
294318
) where {F,C}
295319
mode = reverse_withprimal(backend)
296320
f_and_df = get_f_and_df(f, backend, mode)
297-
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
298-
make_zero!(grad_righttype)
299321
annotated_contexts = translate(backend, mode, Val(1), contexts...)
300-
_, y = autodiff(
301-
mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...
302-
)
303-
copyto_if_different_addresses!(grad, grad_righttype)
322+
make_zero!(grad)
323+
_, y = autodiff(mode, f_and_df, Active, Duplicated(x, grad), annotated_contexts...)
304324
return y, grad
305325
end

0 commit comments

Comments
 (0)