Skip to content

Commit 1a359ae

Browse files
authored
perf: HVP with in-place gradient + inner preparation (#743)
* perf: compute in-place HVP from in-place gradient
* Shuffled gradient without prep
* Inplace true by default
* Fix FromPrimitive tests
* Fix Zygote
* Fix
* Avoid prep
* Inner HVP preparation
* No fail fast
* Appease JET
* ForwardDiff test dep
* Fix PolyesterForwardDiff
* Improve coverage
* Codecov
* Fail fast toggle
1 parent 6ae9532 commit 1a359ae

File tree

30 files changed

+815
-251
lines changed

30 files changed

+815
-251
lines changed

DifferentiationInterface/Project.toml

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3939
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4040
DifferentiationInterfaceGTPSAExt = "GTPSA"
4141
DifferentiationInterfaceMooncakeExt = "Mooncake"
42-
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
42+
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
4343
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4444
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4545
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
@@ -109,4 +109,21 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
109109
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
110110

111111
[targets]
112-
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
112+
test = [
113+
"ADTypes",
114+
"Aqua",
115+
"ComponentArrays",
116+
"DataFrames",
117+
"ExplicitImports",
118+
"JET",
119+
"JLArrays",
120+
"JuliaFormatter",
121+
"Pkg",
122+
"Random",
123+
"SparseArrays",
124+
"SparseConnectivityTracer",
125+
"SparseMatrixColorings",
126+
"StableRNGs",
127+
"StaticArrays",
128+
"Test",
129+
]

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
66
end
77

88
function DI.prepare_pullback(
9-
f,
10-
::AutoReverseChainRules,
11-
x,
12-
ty::NTuple,
13-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
9+
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
1410
) where {C}
1511
return DI.NoPullbackPrep()
1612
end
@@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point(
2117
backend::AutoReverseChainRules,
2218
x,
2319
ty::NTuple,
24-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
20+
contexts::Vararg{DI.GeneralizedConstant,C},
2521
) where {C}
2622
rc = ruleconfig(backend)
2723
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -34,7 +30,7 @@ function DI.value_and_pullback(
3430
backend::AutoReverseChainRules,
3531
x,
3632
ty::NTuple,
37-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
33+
contexts::Vararg{DI.GeneralizedConstant,C},
3834
) where {C}
3935
rc = ruleconfig(backend)
4036
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -50,7 +46,7 @@ function DI.value_and_pullback(
5046
::AutoReverseChainRules,
5147
x,
5248
ty::NTuple,
53-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
49+
contexts::Vararg{DI.GeneralizedConstant,C},
5450
) where {C}
5551
(; y, pb) = prep
5652
tx = map(ty) do dy
@@ -65,7 +61,7 @@ function DI.pullback(
6561
::AutoReverseChainRules,
6662
x,
6763
ty::NTuple,
68-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
64+
contexts::Vararg{DI.GeneralizedConstant,C},
6965
) where {C}
7066
(; pb) = prep
7167
tx = map(ty) do dy

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,13 +48,13 @@ force_annotation(f::F) where {F<:Annotation} = f
4848
force_annotation(f::F) where {F} = Const(f)
4949

5050
@inline function _translate(
51-
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
51+
::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant
5252
) where {B}
5353
return Const(DI.unwrap(c))
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
57+
::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
5858
) where {B}
5959
if B == 1
6060
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ using FiniteDiff:
2121
using LinearAlgebra: dot, mul!
2222

2323
DI.check_available(::AutoFiniteDiff) = true
24+
DI.inner_preparation_behavior(::AutoFiniteDiff) = DI.PrepareInnerSimple()
2425

2526
# see https://github.com/SciML/ADTypes.jl/issues/33
2627

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ using LinearAlgebra: dot
77

88
DI.check_available(::AutoFiniteDifferences) = true
99
DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported()
10+
DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple()
1011

1112
## Pushforward
1213

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ using ForwardDiff:
2828
value
2929

3030
DI.check_available(::AutoForwardDiff) = true
31+
DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload()
3132

3233
include("utils.jl")
3334
include("onearg.jl")

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,22 @@
22
DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp)
33
DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp)
44

5+
function DI.overloaded_input(
6+
::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B}
7+
) where {F,B}
8+
T = tag_type(f, backend, x)
9+
xdual = make_dual(T, x, tx)
10+
return xdual
11+
end
12+
13+
function DI.overloaded_input(
14+
::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}
15+
) where {F,B}
16+
T = tag_type(f!, backend, x)
17+
xdual = make_dual(T, x, tx)
18+
return xdual
19+
end
20+
521
## Derivative
622
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
723
return DI.overloaded_input_type(prep.pushforward_prep)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -272,7 +272,7 @@ function DI.value_and_gradient!(
272272
if (
273273
isnothing(chunksize) &&
274274
T === Nothing &&
275-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
275+
contexts isa NTuple{C,DI.GeneralizedConstant}
276276
)
277277
fc = DI.with_contexts(f, contexts...)
278278
result = DiffResult(zero(eltype(x)), (grad,))
@@ -292,7 +292,7 @@ function DI.value_and_gradient(
292292
if (
293293
isnothing(chunksize) &&
294294
T === Nothing &&
295-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
295+
contexts isa NTuple{C,DI.GeneralizedConstant}
296296
)
297297
fc = DI.with_contexts(f, contexts...)
298298
result = GradientResult(x)
@@ -310,7 +310,7 @@ function DI.gradient!(
310310
if (
311311
isnothing(chunksize) &&
312312
T === Nothing &&
313-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
313+
contexts isa NTuple{C,DI.GeneralizedConstant}
314314
)
315315
fc = DI.with_contexts(f, contexts...)
316316
return gradient!(grad, fc, x)
@@ -326,7 +326,7 @@ function DI.gradient(
326326
if (
327327
isnothing(chunksize) &&
328328
T === Nothing &&
329-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
329+
contexts isa NTuple{C,DI.GeneralizedConstant}
330330
)
331331
fc = DI.with_contexts(f, contexts...)
332332
return gradient(fc, x)
@@ -435,7 +435,7 @@ function DI.value_and_jacobian!(
435435
if (
436436
isnothing(chunksize) &&
437437
T === Nothing &&
438-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
438+
contexts isa NTuple{C,DI.GeneralizedConstant}
439439
)
440440
fc = DI.with_contexts(f, contexts...)
441441
y = fc(x)
@@ -456,7 +456,7 @@ function DI.value_and_jacobian(
456456
if (
457457
isnothing(chunksize) &&
458458
T === Nothing &&
459-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
459+
contexts isa NTuple{C,DI.GeneralizedConstant}
460460
)
461461
fc = DI.with_contexts(f, contexts...)
462462
return fc(x), jacobian(fc, x)
@@ -472,7 +472,7 @@ function DI.jacobian!(
472472
if (
473473
isnothing(chunksize) &&
474474
T === Nothing &&
475-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
475+
contexts isa NTuple{C,DI.GeneralizedConstant}
476476
)
477477
fc = DI.with_contexts(f, contexts...)
478478
return jacobian!(jac, fc, x)
@@ -488,7 +488,7 @@ function DI.jacobian(
488488
if (
489489
isnothing(chunksize) &&
490490
T === Nothing &&
491-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
491+
contexts isa NTuple{C,DI.GeneralizedConstant}
492492
)
493493
fc = DI.with_contexts(f, contexts...)
494494
return jacobian(fc, x)
@@ -738,7 +738,7 @@ function DI.hessian!(
738738
if (
739739
isnothing(chunksize) &&
740740
T === Nothing &&
741-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
741+
contexts isa NTuple{C,DI.GeneralizedConstant}
742742
)
743743
fc = DI.with_contexts(f, contexts...)
744744
return hessian!(hess, fc, x)
@@ -754,7 +754,7 @@ function DI.hessian(
754754
if (
755755
isnothing(chunksize) &&
756756
T === Nothing &&
757-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
757+
contexts isa NTuple{C,DI.GeneralizedConstant}
758758
)
759759
fc = DI.with_contexts(f, contexts...)
760760
return hessian(fc, x)
@@ -775,7 +775,7 @@ function DI.value_gradient_and_hessian!(
775775
if (
776776
isnothing(chunksize) &&
777777
T === Nothing &&
778-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
778+
contexts isa NTuple{C,DI.GeneralizedConstant}
779779
)
780780
fc = DI.with_contexts(f, contexts...)
781781
result = DiffResult(one(eltype(x)), (grad, hess))
@@ -796,7 +796,7 @@ function DI.value_gradient_and_hessian(
796796
if (
797797
isnothing(chunksize) &&
798798
T === Nothing &&
799-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
799+
contexts isa NTuple{C,DI.GeneralizedConstant}
800800
)
801801
fc = DI.with_contexts(f, contexts...)
802802
result = HessianResult(x)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ end
117117
function DI.value_and_derivative(
118118
f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
119119
) where {F,C,chunksize,T}
120-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
120+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
121121
fc! = DI.with_contexts(f!, contexts...)
122122
result = MutableDiffResult(y, (similar(y),))
123123
result = derivative!(result, fc!, y, x)
@@ -131,7 +131,7 @@ end
131131
function DI.value_and_derivative!(
132132
f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
133133
) where {F,C,chunksize,T}
134-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
134+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
135135
fc! = DI.with_contexts(f!, contexts...)
136136
result = MutableDiffResult(y, (der,))
137137
result = derivative!(result, fc!, y, x)
@@ -145,7 +145,7 @@ end
145145
function DI.derivative(
146146
f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
147147
) where {F,C,chunksize,T}
148-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
148+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
149149
fc! = DI.with_contexts(f!, contexts...)
150150
return derivative(fc!, y, x)
151151
else
@@ -157,7 +157,7 @@ end
157157
function DI.derivative!(
158158
f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
159159
) where {F,C,chunksize,T}
160-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
160+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
161161
fc! = DI.with_contexts(f!, contexts...)
162162
return derivative!(der, fc!, y, x)
163163
else
@@ -188,7 +188,7 @@ function DI.prepare!_derivative(
188188
old_prep::ForwardDiffTwoArgDerivativePrep,
189189
backend::AutoForwardDiff,
190190
x,
191-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
191+
contexts::Vararg{DI.GeneralizedConstant,C},
192192
) where {F,C}
193193
if y isa Vector
194194
(; config) = old_prep
@@ -283,7 +283,7 @@ function DI.value_and_jacobian(
283283
if (
284284
isnothing(chunksize) &&
285285
T === Nothing &&
286-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
286+
contexts isa NTuple{C,DI.GeneralizedConstant}
287287
)
288288
fc! = DI.with_contexts(f!, contexts...)
289289
jac = similar(y, length(y), length(x))
@@ -302,7 +302,7 @@ function DI.value_and_jacobian!(
302302
if (
303303
isnothing(chunksize) &&
304304
T === Nothing &&
305-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
305+
contexts isa NTuple{C,DI.GeneralizedConstant}
306306
)
307307
fc! = DI.with_contexts(f!, contexts...)
308308
result = MutableDiffResult(y, (jac,))
@@ -320,7 +320,7 @@ function DI.jacobian(
320320
if (
321321
isnothing(chunksize) &&
322322
T === Nothing &&
323-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
323+
contexts isa NTuple{C,DI.GeneralizedConstant}
324324
)
325325
fc! = DI.with_contexts(f!, contexts...)
326326
return jacobian(fc!, y, x)
@@ -336,7 +336,7 @@ function DI.jacobian!(
336336
if (
337337
isnothing(chunksize) &&
338338
T === Nothing &&
339-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
339+
contexts isa NTuple{C,DI.GeneralizedConstant}
340340
)
341341
fc! = DI.with_contexts(f!, contexts...)
342342
return jacobian!(jac, fc!, y, x)
@@ -369,7 +369,7 @@ function DI.prepare!_jacobian(
369369
old_prep::ForwardDiffTwoArgJacobianPrep,
370370
backend::AutoForwardDiff,
371371
x,
372-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
372+
contexts::Vararg{DI.GeneralizedConstant,C},
373373
) where {F,C}
374374
if x isa Vector && y isa Vector
375375
(; config) = old_prep

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -82,14 +82,11 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
8282
return ty
8383
end
8484

85-
# store preparation result with the right input eltype
86-
struct PrepContext{T<:DI.Prep} <: DI.Context
87-
data::T
85+
function _translate(
86+
::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext}
87+
) where {D<:Dual}
88+
return DI.unwrap(c)
8889
end
89-
90-
NotCache = Union{DI.ConstantOrFunctionOrBackend,PrepContext}
91-
92-
_translate(::Type{D}, c::NotCache) where {D<:Dual} = DI.unwrap(c)
9390
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
9491
c0 = DI.unwrap(c)
9592
return similar(c0, D)
@@ -102,7 +99,11 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
10299
return new_contexts
103100
end
104101

105-
_translate_toprep(::Type{D}, c::NotCache) where {D<:Dual} = nothing
102+
function _translate_toprep(
103+
::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext}
104+
) where {D<:Dual}
105+
return nothing
106+
end
106107
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
107108
c0 = DI.unwrap(c)
108109
return similar(c0, D)
@@ -115,7 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D
115116
return new_contexts
116117
end
117118

118-
_translate_prepared(c::NotCache, _pc) = DI.unwrap(c)
119+
_translate_prepared(c::Union{DI.GeneralizedConstant,DI.PrepContext}, _pc) = DI.unwrap(c)
119120
_translate_prepared(_c::DI.Cache, pc) = pc
120121

121122
function translate_prepared(

0 commit comments

Comments
 (0)