Skip to content

Commit 2896511

Browse files
authored
feat: use JVPCache for FiniteDiff pushforwards (#705)
* feat: use `JVPCache` for FiniteDiff pushforwards * test both fdtypes * Up * adapt to recent release * Add benchmarks * define logging
1 parent 0ea7f1d commit 2896511

File tree

8 files changed

+290
-89
lines changed

8 files changed

+290
-89
lines changed

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.37"
4+
version = "0.6.38"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -56,7 +56,7 @@ Enzyme = "0.13.17"
5656
EnzymeCore = "0.8.8"
5757
ExplicitImports = "1.10.1"
5858
FastDifferentiation = "0.4.3"
59-
FiniteDiff = "2.23.1"
59+
FiniteDiff = "2.27.0"
6060
FiniteDifferences = "0.12.31"
6161
ForwardDiff = "0.10.36"
6262
GTPSA = "1.4.0"

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ using FiniteDiff:
77
GradientCache,
88
HessianCache,
99
JacobianCache,
10+
JVPCache,
1011
finite_difference_derivative,
1112
finite_difference_gradient,
1213
finite_difference_gradient!,
1314
finite_difference_hessian,
1415
finite_difference_hessian!,
1516
finite_difference_jacobian,
1617
finite_difference_jacobian!,
18+
finite_difference_jvp,
19+
finite_difference_jvp!,
1720
default_relstep
1821
using LinearAlgebra: dot, mul!
1922

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
## Pushforward
22

3-
struct FiniteDiffOneArgPushforwardPrep{R,A} <: DI.PushforwardPrep
3+
struct FiniteDiffOneArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
4+
cache::C
45
relstep::R
56
absstep::A
67
end
78

89
function DI.prepare_pushforward(
910
f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}
1011
) where {C}
12+
fc = DI.with_contexts(f, contexts...)
13+
y = fc(x)
14+
cache = if x isa Number || y isa Number
15+
nothing
16+
else
17+
JVPCache(similar(x), y, fdtype(backend))
18+
end
1119
relstep = if isnothing(backend.relstep)
1220
default_relstep(fdtype(backend), eltype(x))
1321
else
@@ -18,12 +26,12 @@ function DI.prepare_pushforward(
1826
else
1927
backend.relstep
2028
end
21-
return FiniteDiffOneArgPushforwardPrep(relstep, absstep)
29+
return FiniteDiffOneArgPushforwardPrep(cache, relstep, absstep)
2230
end
2331

2432
function DI.pushforward(
2533
f,
26-
prep::FiniteDiffOneArgPushforwardPrep,
34+
prep::FiniteDiffOneArgPushforwardPrep{Nothing},
2735
backend::AutoFiniteDiff,
2836
x,
2937
tx::NTuple,
@@ -41,7 +49,7 @@ end
4149

4250
function DI.value_and_pushforward(
4351
f,
44-
prep::FiniteDiffOneArgPushforwardPrep,
52+
prep::FiniteDiffOneArgPushforwardPrep{Nothing},
4553
backend::AutoFiniteDiff,
4654
x,
4755
tx::NTuple,
@@ -64,6 +72,39 @@ function DI.value_and_pushforward(
6472
return y, ty
6573
end
6674

75+
function DI.pushforward(
76+
f,
77+
prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache},
78+
::AutoFiniteDiff,
79+
x,
80+
tx::NTuple,
81+
contexts::Vararg{DI.Context,C},
82+
) where {C}
83+
(; relstep, absstep) = prep
84+
fc = DI.with_contexts(f, contexts...)
85+
ty = map(tx) do dx
86+
finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep)
87+
end
88+
return ty
89+
end
90+
91+
function DI.value_and_pushforward(
92+
f,
93+
prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache},
94+
::AutoFiniteDiff,
95+
x,
96+
tx::NTuple,
97+
contexts::Vararg{DI.Context,C},
98+
) where {C}
99+
(; relstep, absstep) = prep
100+
fc = DI.with_contexts(f, contexts...)
101+
y = fc(x)
102+
ty = map(tx) do dx
103+
finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep)
104+
end
105+
return y, ty
106+
end
107+
67108
## Derivative
68109

69110
struct FiniteDiffOneArgDerivativePrep{C,R,A} <: DI.DerivativePrep

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
## Pushforward
22

3-
struct FiniteDiffTwoArgPushforwardPrep{R,A} <: DI.PushforwardPrep
3+
struct FiniteDiffTwoArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
4+
cache::C
45
relstep::R
56
absstep::A
67
end
78

89
function DI.prepare_pushforward(
910
f!, y, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}
1011
) where {C}
12+
cache = if x isa Number
13+
nothing
14+
else
15+
JVPCache(similar(x), similar(y), fdtype(backend))
16+
end
1117
relstep = if isnothing(backend.relstep)
1218
default_relstep(fdtype(backend), eltype(x))
1319
else
@@ -18,14 +24,13 @@ function DI.prepare_pushforward(
1824
else
1925
backend.relstep
2026
end
21-
return FiniteDiffTwoArgPushforwardPrep(relstep, absstep)
22-
return DI.NoPushforwardPrep()
27+
return FiniteDiffTwoArgPushforwardPrep(cache, relstep, absstep)
2328
end
2429

2530
function DI.value_and_pushforward(
2631
f!,
2732
y,
28-
prep::FiniteDiffTwoArgPushforwardPrep,
33+
prep::FiniteDiffTwoArgPushforwardPrep{Nothing},
2934
backend::AutoFiniteDiff,
3035
x,
3136
tx::NTuple,
@@ -52,6 +57,84 @@ function DI.value_and_pushforward(
5257
return y, ty
5358
end
5459

60+
function DI.pushforward(
61+
f!,
62+
y,
63+
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
64+
::AutoFiniteDiff,
65+
x,
66+
tx::NTuple,
67+
contexts::Vararg{DI.Context,C},
68+
) where {C}
69+
(; relstep, absstep) = prep
70+
fc! = DI.with_contexts(f!, contexts...)
71+
ty = map(tx) do dx
72+
dy = similar(y)
73+
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
74+
dy
75+
end
76+
return ty
77+
end
78+
79+
function DI.value_and_pushforward(
80+
f!,
81+
y,
82+
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
83+
::AutoFiniteDiff,
84+
x,
85+
tx::NTuple,
86+
contexts::Vararg{DI.Context,C},
87+
) where {C}
88+
(; relstep, absstep) = prep
89+
fc! = DI.with_contexts(f!, contexts...)
90+
ty = map(tx) do dx
91+
dy = similar(y)
92+
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
93+
dy
94+
end
95+
fc!(y, x)
96+
return y, ty
97+
end
98+
99+
function DI.pushforward!(
100+
f!,
101+
y,
102+
ty::NTuple,
103+
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
104+
::AutoFiniteDiff,
105+
x,
106+
tx::NTuple,
107+
contexts::Vararg{DI.Context,C},
108+
) where {C}
109+
(; relstep, absstep) = prep
110+
fc! = DI.with_contexts(f!, contexts...)
111+
for b in eachindex(tx, ty)
112+
dx, dy = tx[b], ty[b]
113+
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
114+
end
115+
return ty
116+
end
117+
118+
function DI.value_and_pushforward!(
119+
f!,
120+
y,
121+
ty::NTuple,
122+
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
123+
::AutoFiniteDiff,
124+
x,
125+
tx::NTuple,
126+
contexts::Vararg{DI.Context,C},
127+
) where {C}
128+
(; relstep, absstep) = prep
129+
fc! = DI.with_contexts(f!, contexts...)
130+
for b in eachindex(tx, ty)
131+
dx, dy = tx[b], ty[b]
132+
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
133+
end
134+
fc!(y, x)
135+
return y, ty
136+
end
137+
55138
## Derivative
56139

57140
struct FiniteDiffTwoArgDerivativePrep{C,R,A} <: DI.DerivativePrep
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using Pkg
2+
Pkg.add("FiniteDiff")
3+
4+
using ADTypes: ADTypes
5+
using DifferentiationInterface, DifferentiationInterfaceTest
6+
import DifferentiationInterface as DI
7+
import DifferentiationInterfaceTest as DIT
8+
using FiniteDiff: FiniteDiff
9+
using Test
10+
11+
LOGGING = get(ENV, "CI", "false") == "false"
12+
13+
@testset "Benchmarking sparse" begin
14+
filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes=[])) do scen
15+
DIT.function_place(scen) == :in &&
16+
DIT.operator_place(scen) == :in &&
17+
scen.x isa AbstractVector &&
18+
scen.y isa AbstractVector
19+
end
20+
21+
data = benchmark_differentiation(
22+
MyAutoSparse(AutoFiniteDiff()),
23+
filtered_sparse_scenarios;
24+
benchmark=:prepared,
25+
excluded=SECOND_ORDER,
26+
logging=LOGGING,
27+
)
28+
@testset "Analyzing benchmark results" begin
29+
@testset "$(row[:scenario])" for row in eachrow(data)
30+
@test row[:allocs] == 0
31+
end
32+
end
33+
end

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,32 @@ for backend in [AutoFiniteDiff()]
1717
@test check_inplace(backend)
1818
end
1919

20-
test_differentiation(
21-
AutoFiniteDiff(),
22-
default_scenarios(; include_constantified=true, include_cachified=true);
23-
excluded=[:second_derivative, :hvp],
24-
logging=LOGGING,
25-
);
26-
27-
test_differentiation(
28-
[
29-
AutoFiniteDiff(; relstep=cbrt(eps(Float64))),
30-
AutoFiniteDiff(; relstep=cbrt(eps(Float64)), absstep=cbrt(eps(Float64))),
31-
];
32-
excluded=[:second_derivative, :hvp],
33-
logging=LOGGING,
34-
);
20+
@testset "Dense" begin
21+
test_differentiation(
22+
AutoFiniteDiff(),
23+
default_scenarios(; include_constantified=true, include_cachified=true);
24+
excluded=[:second_derivative, :hvp],
25+
logging=LOGGING,
26+
)
27+
28+
test_differentiation(
29+
[
30+
AutoFiniteDiff(; relstep=cbrt(eps(Float64))),
31+
AutoFiniteDiff(; relstep=cbrt(eps(Float64)), absstep=cbrt(eps(Float64))),
32+
];
33+
excluded=[:second_derivative, :hvp],
34+
logging=LOGGING,
35+
)
36+
end
37+
38+
@testset "Sparse" begin
39+
test_differentiation(
40+
MyAutoSparse(AutoFiniteDiff()),
41+
sparse_scenarios();
42+
excluded=SECOND_ORDER,
43+
logging=LOGGING,
44+
)
45+
end
3546

3647
@testset "Complex" begin
3748
test_differentiation(AutoFiniteDiff(), complex_scenarios(); logging=LOGGING)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using Pkg
2+
Pkg.add("ForwardDiff")
3+
4+
using ADTypes: ADTypes
5+
using DifferentiationInterface, DifferentiationInterfaceTest
6+
import DifferentiationInterface as DI
7+
import DifferentiationInterfaceTest as DIT
8+
using ForwardDiff: ForwardDiff
9+
using StaticArrays: StaticArrays, @SVector
10+
using Test
11+
12+
LOGGING = get(ENV, "CI", "false") == "false"
13+
14+
@testset verbose = true "Benchmarking static" begin
15+
filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen
16+
DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out
17+
end
18+
data = benchmark_differentiation(
19+
AutoForwardDiff(),
20+
filtered_static_scenarios;
21+
benchmark=:prepared,
22+
excluded=[:hessian, :pullback], # TODO: figure this out
23+
logging=LOGGING,
24+
)
25+
@testset "Analyzing benchmark results" begin
26+
@testset "$(row[:scenario])" for row in eachrow(data)
27+
@test row[:allocs] == 0
28+
end
29+
end
30+
end
31+
32+
@testset "Benchmarking sparse" begin
33+
filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes=[])) do scen
34+
DIT.function_place(scen) == :in &&
35+
DIT.operator_place(scen) == :in &&
36+
scen.x isa AbstractVector &&
37+
scen.y isa AbstractVector
38+
end
39+
40+
data = benchmark_differentiation(
41+
MyAutoSparse(AutoForwardDiff()),
42+
filtered_sparse_scenarios;
43+
benchmark=:prepared,
44+
excluded=SECOND_ORDER,
45+
logging=LOGGING,
46+
)
47+
@testset "Analyzing benchmark results" begin
48+
@testset "$(row[:scenario])" for row in eachrow(data)
49+
@test row[:allocs] == 0
50+
end
51+
end
52+
end

0 commit comments

Comments
 (0)