Skip to content

Commit 45fad0e

Browse files
authored
fix: function shadows for higher-order Enzyme (#943)
1 parent 325f2b8 commit 45fad0e

File tree

6 files changed

+29
-29
lines changed

6 files changed

+29
-29
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515

1616
### Fixed
1717

18+
- Function shadows for higher-order Enzyme ([#943](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/943))
1819
- Improve wrong-mode pushforward/pullback ([#932](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/932), [#931](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/931))
1920
- Clean up CI ([#926](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/926), [#924](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/924))
2021

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ function DI.prepare_pushforward_nokwarg(
1515
contexts::Vararg{DI.Context, C}
1616
) where {F, C, B}
1717
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
18-
df = function_shadow(f, backend, Val(B))
1918
mode = forward_withprimal(backend)
19+
df = function_shadow(f, backend, mode, Val(B))
2020
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
2121
return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows)
2222
end
@@ -146,8 +146,8 @@ function DI.prepare_gradient_nokwarg(
146146
) where {F, C}
147147
_sig = DI.signature(f, backend, x, contexts...; strict)
148148
valB = to_val(DI.pick_batchsize(backend, x))
149-
df = function_shadow(f, backend, valB)
150149
mode = forward_withprimal(backend)
150+
df = function_shadow(f, backend, mode, valB)
151151
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
152152
basis_shadows = create_shadows(valB, x)
153153
return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows)
@@ -236,7 +236,7 @@ function DI.prepare_jacobian_nokwarg(
236236
y = f(x, map(DI.unwrap, contexts)...)
237237
valB = to_val(DI.pick_batchsize(backend, x))
238238
mode = forward_withprimal(backend)
239-
df = function_shadow(f, backend, valB)
239+
df = function_shadow(f, backend, mode, valB)
240240
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
241241
basis_shadows = create_shadows(valB, x)
242242
return EnzymeForwardOneArgJacobianPrep(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function DI.prepare_pushforward_nokwarg(
1616
contexts::Vararg{DI.Context, C}
1717
) where {F, B, C}
1818
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
19-
df! = function_shadow(f!, backend, Val(B))
2019
mode = forward_noprimal(backend)
20+
df! = function_shadow(f!, backend, mode, Val(B))
2121
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
2222
return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows)
2323
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ function DI.prepare_pullback_nokwarg(
6363
contexts::Vararg{DI.Context, C}
6464
) where {F, B, C}
6565
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
66-
df = function_shadow(f, backend, Val(B))
6766
mode = reverse_split_withprimal(backend)
67+
df = function_shadow(f, backend, mode, Val(B))
6868
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
6969
y = f(x, map(DI.unwrap, contexts)...)
7070
return EnzymeReverseOneArgPullbackPrep(_sig, df, context_shadows, y)
@@ -214,8 +214,8 @@ function DI.prepare_gradient_nokwarg(
214214
contexts::Vararg{DI.Context, C}
215215
) where {F, C}
216216
_sig = DI.signature(f, backend, x, contexts...; strict)
217-
df = function_shadow(f, backend, Val(1))
218217
mode = reverse_withprimal(backend)
218+
df = function_shadow(f, backend, mode, Val(1))
219219
context_shadows = make_context_shadows(backend, mode, Val(1), contexts...)
220220
return EnzymeGradientPrep(_sig, df, context_shadows)
221221
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ function DI.prepare_pullback_nokwarg(
1717
contexts::Vararg{DI.Context, C}
1818
) where {F, B, C}
1919
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
20-
df! = function_shadow(f!, backend, Val(B))
2120
mode = reverse_noprimal(backend)
21+
df! = function_shadow(f!, backend, mode, Val(B))
2222
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
2323
ty_copy = map(copy, ty)
2424
return EnzymeReverseTwoArgPullbackPrep(_sig, df!, context_shadows, ty_copy)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,33 @@ end
2828
function get_f_and_df_prepared!(
2929
df, f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}
3030
) where {F, M, B}
31-
#=
32-
It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`.
33-
- In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`.
34-
- In reverse mode, `df` gets incremented but it does not influence the input cotangent `dx`.
35-
=#
36-
if B == 1
37-
return Duplicated(f, df)
31+
if isnothing(df)
32+
return Const(f)
3833
else
39-
return BatchDuplicated(f, df)
34+
if B == 1
35+
return Duplicated(f, df)
36+
else
37+
return BatchDuplicated(f, df)
38+
end
4039
end
4140
end
4241

4342
function function_shadow(
44-
::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Val{B}
43+
::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Mode, ::Val{B}
4544
) where {M, B, F}
4645
return nothing
4746
end
4847

49-
function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}) where {F, M, B}
50-
if B == 1
51-
return make_zero(f)
48+
function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, mode::Mode, ::Val{B}) where {F, M, B}
49+
IA = guess_activity(F, mode)
50+
return if IA <: Const
51+
nothing
5252
else
53-
return ntuple(_ -> make_zero(f), Val(B))
53+
if B == 1
54+
return make_zero(f)
55+
else
56+
return ntuple(_ -> make_zero(f), Val(B))
57+
end
5458
end
5559
end
5660

@@ -87,13 +91,13 @@ function _shadow(
8791
end
8892

8993
function _shadow(
90-
backend::AutoEnzyme{M, <:Union{Const, Nothing}},
91-
::Mode,
94+
backend::AutoEnzyme,
95+
mode::Mode,
9296
::Val{B},
9397
c_wrapped::DI.FunctionContext,
94-
) where {M, B}
98+
) where {B}
9599
f = DI.unwrap(c_wrapped)
96-
return function_shadow(f, backend, Val(B))
100+
return function_shadow(f, backend, mode, Val(B))
97101
end
98102

99103
function make_context_shadows(
@@ -122,11 +126,6 @@ end
122126
function _translate_prepared!(
123127
dc, c_wrapped::Union{DI.ConstantOrCache, DI.FunctionContext}, ::Val{B}
124128
) where {B}
125-
#=
126-
It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts.
127-
- In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`.
128-
- In reverse mode, `dc` gets incremented but it does not influence the input cotangent `dx`.
129-
=#
130129
c = DI.unwrap(c_wrapped)
131130
if isnothing(dc)
132131
return Const(c)

0 commit comments

Comments
 (0)