Skip to content

Commit 0ea7f1d

Browse files
authored
fix: use Enzyme's native Jacobian in forward mode with constant contexts (#710)
* fix: use Enzyme's native Jacobian in forward mode with constant contexts * Add tests
1 parent bf47700 commit 0ea7f1d

File tree

5 files changed

+70
-27
lines changed

5 files changed

+70
-27
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.36"
4+
version = "0.6.37"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Moreover, each context type is supported by a specific subset of backends:
6161
| `AutoChainRules` |||
6262
| `AutoDiffractor` |||
6363
| `AutoEnzyme` (forward) |||
64-
| `AutoEnzyme` (reverse) || |
64+
| `AutoEnzyme` (reverse) || ❌ (soon) |
6565
| `AutoFastDifferentiation` |||
6666
| `AutoFiniteDiff` |||
6767
| `AutoFiniteDifferences` |||

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,11 @@ function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
119119
end
120120

121121
function DI.prepare_gradient(
122-
f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
123-
) where {F}
122+
f::F,
123+
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
124+
x,
125+
contexts::Vararg{DI.Constant,C},
126+
) where {F,C}
124127
valB = to_val(DI.pick_batchsize(backend, x))
125128
shadows = create_shadows(valB, x)
126129
return EnzymeForwardGradientPrep(valB, shadows)
@@ -131,23 +134,31 @@ function DI.gradient(
131134
prep::EnzymeForwardGradientPrep{B},
132135
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
133136
x,
134-
) where {F,B}
137+
contexts::Vararg{DI.Constant,C},
138+
) where {F,B,C}
135139
mode = forward_noprimal(backend)
136140
f_and_df = get_f_and_df(f, backend, mode)
137-
derivs = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
138-
return only(derivs)
141+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
142+
derivs = gradient(
143+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
144+
)
145+
return first(derivs)
139146
end
140147

141148
function DI.value_and_gradient(
142149
f::F,
143150
prep::EnzymeForwardGradientPrep{B},
144151
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
145152
x,
146-
) where {F,B}
153+
contexts::Vararg{DI.Constant,C},
154+
) where {F,B,C}
147155
mode = forward_withprimal(backend)
148156
f_and_df = get_f_and_df(f, backend, mode)
149-
(; derivs, val) = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
150-
return val, only(derivs)
157+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
158+
(; derivs, val) = gradient(
159+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
160+
)
161+
return val, first(derivs)
151162
end
152163

153164
function DI.gradient!(
@@ -156,8 +167,9 @@ function DI.gradient!(
156167
prep::EnzymeForwardGradientPrep{B},
157168
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
158169
x,
159-
) where {F,B}
160-
return copyto!(grad, DI.gradient(f, prep, backend, x))
170+
contexts::Vararg{DI.Constant,C},
171+
) where {F,B,C}
172+
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
161173
end
162174

163175
function DI.value_and_gradient!(
@@ -166,8 +178,9 @@ function DI.value_and_gradient!(
166178
prep::EnzymeForwardGradientPrep{B},
167179
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
168180
x,
169-
) where {F,B}
170-
y, new_grad = DI.value_and_gradient(f, prep, backend, x)
181+
contexts::Vararg{DI.Constant,C},
182+
) where {F,B,C}
183+
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
171184
return y, copyto!(grad, new_grad)
172185
end
173186

@@ -185,9 +198,12 @@ function EnzymeForwardOneArgJacobianPrep(
185198
end
186199

187200
function DI.prepare_jacobian(
188-
f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
189-
) where {F}
190-
y = f(x)
201+
f::F,
202+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
203+
x,
204+
contexts::Vararg{DI.Constant,C},
205+
) where {F,C}
206+
y = f(x, map(DI.unwrap, contexts)...)
191207
valB = to_val(DI.pick_batchsize(backend, x))
192208
shadows = create_shadows(valB, x)
193209
return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
@@ -198,11 +214,15 @@ function DI.jacobian(
198214
prep::EnzymeForwardOneArgJacobianPrep{B},
199215
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
200216
x,
201-
) where {F,B}
217+
contexts::Vararg{DI.Constant,C},
218+
) where {F,B,C}
202219
mode = forward_noprimal(backend)
203220
f_and_df = get_f_and_df(f, backend, mode)
204-
derivs = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
205-
jac_tensor = only(derivs)
221+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
222+
derivs = jacobian(
223+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
224+
)
225+
jac_tensor = first(derivs)
206226
return maybe_reshape(jac_tensor, prep.output_length, length(x))
207227
end
208228

@@ -211,11 +231,15 @@ function DI.value_and_jacobian(
211231
prep::EnzymeForwardOneArgJacobianPrep{B},
212232
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
213233
x,
214-
) where {F,B}
234+
contexts::Vararg{DI.Constant,C},
235+
) where {F,B,C}
215236
mode = forward_withprimal(backend)
216237
f_and_df = get_f_and_df(f, backend, mode)
217-
(; derivs, val) = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
218-
jac_tensor = only(derivs)
238+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
239+
(; derivs, val) = jacobian(
240+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
241+
)
242+
jac_tensor = first(derivs)
219243
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))
220244
end
221245

@@ -225,8 +249,9 @@ function DI.jacobian!(
225249
prep::EnzymeForwardOneArgJacobianPrep,
226250
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
227251
x,
228-
) where {F}
229-
return copyto!(jac, DI.jacobian(f, prep, backend, x))
252+
contexts::Vararg{DI.Constant,C},
253+
) where {F,C}
254+
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
230255
end
231256

232257
function DI.value_and_jacobian!(
@@ -235,7 +260,8 @@ function DI.value_and_jacobian!(
235260
prep::EnzymeForwardOneArgJacobianPrep,
236261
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
237262
x,
238-
) where {F}
239-
y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
263+
contexts::Vararg{DI.Constant,C},
264+
) where {F,C}
265+
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
240266
return y, copyto!(jac, new_jac)
241267
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ force_annotation(f::F) where {F} = Const(f)
5353
return Const(DI.unwrap(c))
5454
end
5555

56+
@inline function _translate(
57+
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
58+
) where {B}
59+
if B == 1
60+
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
61+
else
62+
return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B)))
63+
end
64+
end
65+
5666
@inline function _translate(
5767
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
5868
) where {B}

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ end;
5353
logging=LOGGING,
5454
)
5555

56+
test_differentiation(
57+
backends[2],
58+
default_scenarios(; include_normal=false, include_cachified=true);
59+
excluded=SECOND_ORDER,
60+
logging=LOGGING,
61+
)
62+
5663
test_differentiation(
5764
duplicated_backends,
5865
default_scenarios(; include_normal=false, include_closurified=true);

0 commit comments

Comments
 (0)