Skip to content

Commit bf47700

Browse files
authored
perf: reduce allocations with Enzyme for in-place functions (#707)
1 parent 8ef803c commit bf47700

File tree

9 files changed

+153
-102
lines changed

9 files changed

+153
-102
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: false # TODO: toggle
28+
fail-fast: true # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.35"
4+
version = "0.6.36"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Moreover, each context type is supported by a specific subset of backends:
6767
| `AutoFiniteDifferences` |||
6868
| `AutoForwardDiff` |||
6969
| `AutoGTPSA` |||
70-
| `AutoMooncake` || |
70+
| `AutoMooncake` || |
7171
| `AutoPolyesterForwardDiff` |||
7272
| `AutoReverseDiff` |||
7373
| `AutoSymbolics` |||

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,22 @@ end
6767
function DI.value_and_pushforward!(
6868
f!::F,
6969
y,
70-
ty::NTuple,
71-
prep::DI.NoPushforwardPrep,
70+
ty::NTuple{B},
71+
::DI.NoPushforwardPrep,
7272
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7373
x,
74-
tx::NTuple,
74+
tx::NTuple{B},
7575
contexts::Vararg{DI.Context,C},
76-
) where {F,C}
77-
y, new_ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
78-
foreach(copyto!, ty, new_ty)
76+
) where {F,B,C}
77+
mode = forward_noprimal(backend)
78+
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
79+
tx_sametype = map(Fix1(convert, typeof(x)), tx)
80+
ty_sametype = map(Fix1(convert, typeof(y)), ty)
81+
x_and_tx = BatchDuplicated(x, tx_sametype)
82+
y_and_ty = BatchDuplicated(y, ty_sametype)
83+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
84+
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
85+
foreach(copyto_if_different_addresses!, ty, ty_sametype)
7986
return y, ty
8087
end
8188

@@ -89,7 +96,6 @@ function DI.pushforward!(
8996
tx::NTuple,
9097
contexts::Vararg{DI.Context,C},
9198
) where {F,C}
92-
new_ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...)
93-
foreach(copyto!, ty, new_ty)
99+
DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
94100
return ty
95101
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 20 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,26 @@ end
4949

5050
## Pullback
5151

52+
struct EnzymeReverseOneArgPullbackPrep{Y} <: DI.PullbackPrep
53+
y_example::Y # useful to create return activity
54+
end
55+
5256
function DI.prepare_pullback(
5357
f::F,
5458
::AutoEnzyme{<:Union{ReverseMode,Nothing}},
5559
x,
5660
ty::NTuple,
5761
contexts::Vararg{DI.Context,C},
5862
) where {F,C}
59-
return DI.NoPullbackPrep()
63+
y = f(x, map(DI.unwrap, contexts)...)
64+
return EnzymeReverseOneArgPullbackPrep(y)
6065
end
6166

6267
### Out-of-place
6368

6469
function DI.value_and_pullback(
6570
f::F,
66-
::DI.NoPullbackPrep,
71+
prep::EnzymeReverseOneArgPullbackPrep,
6772
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
6873
x,
6974
ty::NTuple{1},
@@ -72,7 +77,7 @@ function DI.value_and_pullback(
7277
mode = reverse_split_withprimal(backend)
7378
f_and_df = force_annotation(get_f_and_df(f, backend, mode))
7479
IA = guess_activity(typeof(x), mode)
75-
RA = guess_activity(eltype(ty), mode)
80+
RA = guess_activity(typeof(prep.y_example), mode)
7681
dx = make_zero(x)
7782
annotated_contexts = translate(backend, mode, Val(1), contexts...)
7883
dinputs, result = seeded_autodiff_thunk(
@@ -88,7 +93,7 @@ end
8893

8994
function DI.value_and_pullback(
9095
f::F,
91-
::DI.NoPullbackPrep,
96+
prep::EnzymeReverseOneArgPullbackPrep,
9297
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
9398
x,
9499
ty::NTuple{B},
@@ -97,7 +102,7 @@ function DI.value_and_pullback(
97102
mode = reverse_split_withprimal(backend)
98103
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))
99104
IA = batchify_activity(guess_activity(typeof(x), mode), Val(B))
100-
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
105+
RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B))
101106
tx = ntuple(_ -> make_zero(x), Val(B))
102107
annotated_contexts = translate(backend, mode, Val(B), contexts...)
103108
dinputs, result = batch_seeded_autodiff_thunk(
@@ -113,7 +118,7 @@ end
113118

114119
function DI.pullback(
115120
f::F,
116-
prep::DI.NoPullbackPrep,
121+
prep::EnzymeReverseOneArgPullbackPrep,
117122
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
118123
x,
119124
ty::NTuple,
@@ -127,51 +132,51 @@ end
127132
function DI.value_and_pullback!(
128133
f::F,
129134
tx::NTuple{1},
130-
::DI.NoPullbackPrep,
135+
prep::EnzymeReverseOneArgPullbackPrep,
131136
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
132137
x,
133138
ty::NTuple{1},
134139
contexts::Vararg{DI.Context,C},
135140
) where {F,C}
136141
mode = reverse_split_withprimal(backend)
137142
f_and_df = force_annotation(get_f_and_df(f, backend, mode))
138-
RA = guess_activity(eltype(ty), mode)
143+
RA = guess_activity(typeof(prep.y_example), mode)
139144
dx_righttype = convert(typeof(x), only(tx))
140145
make_zero!(dx_righttype)
141146
annotated_contexts = translate(backend, mode, Val(1), contexts...)
142147
_, result = seeded_autodiff_thunk(
143148
mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts...
144149
)
145-
only(tx) === dx_righttype || copyto!(only(tx), dx_righttype)
150+
copyto_if_different_addresses!(only(tx), dx_righttype)
146151
return result, tx
147152
end
148153

149154
function DI.value_and_pullback!(
150155
f::F,
151156
tx::NTuple{B},
152-
::DI.NoPullbackPrep,
157+
prep::EnzymeReverseOneArgPullbackPrep,
153158
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
154159
x,
155160
ty::NTuple{B},
156161
contexts::Vararg{DI.Context,C},
157162
) where {F,B,C}
158163
mode = reverse_split_withprimal(backend)
159164
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))
160-
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
165+
RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B))
161166
tx_righttype = map(Fix1(convert, typeof(x)), tx)
162167
make_zero!(tx_righttype)
163168
annotated_contexts = translate(backend, mode, Val(B), contexts...)
164169
_, result = batch_seeded_autodiff_thunk(
165170
mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts...
166171
)
167-
foreach(copyto!, tx, tx_righttype)
172+
foreach(copyto_if_different_addresses!, tx, tx_righttype)
168173
return result, tx
169174
end
170175

171176
function DI.pullback!(
172177
f::F,
173178
tx::NTuple,
174-
prep::DI.NoPullbackPrep,
179+
prep::EnzymeReverseOneArgPullbackPrep,
175180
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
176181
x,
177182
ty::NTuple,
@@ -265,7 +270,7 @@ function DI.gradient!(
265270
make_zero!(grad_righttype)
266271
annotated_contexts = translate(backend, mode, Val(1), contexts...)
267272
autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...)
268-
grad === grad_righttype || copyto!(grad, grad_righttype)
273+
copyto_if_different_addresses!(grad, grad_righttype)
269274
return grad
270275
end
271276

@@ -295,70 +300,6 @@ function DI.value_and_gradient!(
295300
_, y = autodiff(
296301
mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...
297302
)
298-
grad === grad_righttype || copyto!(grad, grad_righttype)
303+
copyto_if_different_addresses!(grad, grad_righttype)
299304
return y, grad
300305
end
301-
302-
## Jacobian
303-
304-
# TODO: does not support static arrays
305-
306-
#=
307-
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <:DI.JacobianPrep end
308-
309-
function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
310-
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
311-
end
312-
313-
function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
314-
y = f(x)
315-
Sy = size(y)
316-
valB = to_val(DI.pick_batchsize(backend, y))
317-
return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
318-
end
319-
320-
function DI.jacobian(
321-
f::F,
322-
::EnzymeReverseOneArgJacobianPrep{Sy,B},
323-
backend::AutoEnzyme{<:ReverseMode,Nothing},
324-
x,
325-
) where {F,Sy,B}
326-
derivs = jacobian(reverse_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
327-
jac_tensor = only(derivs)
328-
return maybe_reshape(jac_tensor, prod(Sy), length(x))
329-
end
330-
331-
function DI.value_and_jacobian(
332-
f::F,
333-
::EnzymeReverseOneArgJacobianPrep{Sy,B},
334-
backend::AutoEnzyme{<:ReverseMode,Nothing},
335-
x,
336-
) where {F,Sy,B}
337-
(; derivs, val) = jacobian(
338-
reverse_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
339-
)
340-
jac_tensor = only(derivs)
341-
return val, maybe_reshape(jac_tensor, prod(Sy), length(x))
342-
end
343-
344-
function DI.jacobian!(
345-
f::F,
346-
jac,
347-
prep::EnzymeReverseOneArgJacobianPrep,
348-
backend::AutoEnzyme{<:ReverseMode,Nothing},
349-
x,
350-
) where {F}
351-
return copyto!(jac, DI.jacobian(f, prep, backend, x))
352-
end
353-
354-
function DI.value_and_jacobian!(
355-
f::F,
356-
jac,
357-
prep::EnzymeReverseOneArgJacobianPrep,
358-
backend::AutoEnzyme{<:ReverseMode,Nothing},
359-
x,
360-
) where {F}
361-
y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
362-
return y, copyto!(jac, new_jac)
363-
end
364-
=#

0 commit comments

Comments
 (0)