Skip to content

Commit 58605dc

Browse files
authored
feat: use dir backend field for FiniteDiff (#727)
1 parent 1a95c7c commit 58605dc

File tree

4 files changed

+98
-75
lines changed

4 files changed

+98
-75
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
4848
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
4949

5050
[compat]
51-
ADTypes = "1.12.1"
51+
ADTypes = "1.13.0"
5252
ChainRulesCore = "1.23.0"
5353
DiffResults = "1.1.0"
5454
Diffractor = "=0.2.6"

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
## Pushforward
22

3-
struct FiniteDiffOneArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
3+
struct FiniteDiffOneArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep
44
cache::C
55
relstep::R
66
absstep::A
7+
dir::D
78
end
89

910
function DI.prepare_pushforward(
@@ -26,7 +27,8 @@ function DI.prepare_pushforward(
2627
else
2728
backend.relstep
2829
end
29-
return FiniteDiffOneArgPushforwardPrep(cache, relstep, absstep)
30+
dir = backend.dir
31+
return FiniteDiffOneArgPushforwardPrep(cache, relstep, absstep, dir)
3032
end
3133

3234
function DI.pushforward(
@@ -37,11 +39,11 @@ function DI.pushforward(
3739
tx::NTuple,
3840
contexts::Vararg{DI.Context,C},
3941
) where {C}
40-
(; relstep, absstep) = prep
42+
(; relstep, absstep, dir) = prep
4143
step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...)
4244
ty = map(tx) do dx
4345
finite_difference_derivative(
44-
Base.Fix2(step, dx), zero(eltype(x)), fdtype(backend); relstep, absstep
46+
Base.Fix2(step, dx), zero(eltype(x)), fdtype(backend); relstep, absstep, dir
4547
)
4648
end
4749
return ty
@@ -55,7 +57,7 @@ function DI.value_and_pushforward(
5557
tx::NTuple,
5658
contexts::Vararg{DI.Context,C},
5759
) where {C}
58-
(; relstep, absstep) = prep
60+
(; relstep, absstep, dir) = prep
5961
step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...)
6062
y = f(x, map(DI.unwrap, contexts)...)
6163
ty = map(tx) do dx
@@ -67,6 +69,7 @@ function DI.value_and_pushforward(
6769
y;
6870
relstep,
6971
absstep,
72+
dir,
7073
)
7174
end
7275
return y, ty
@@ -80,10 +83,10 @@ function DI.pushforward(
8083
tx::NTuple,
8184
contexts::Vararg{DI.Context,C},
8285
) where {C}
83-
(; relstep, absstep) = prep
86+
(; relstep, absstep, dir) = prep
8487
fc = DI.with_contexts(f, contexts...)
8588
ty = map(tx) do dx
86-
finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep)
89+
finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep, dir)
8790
end
8891
return ty
8992
end
@@ -96,21 +99,22 @@ function DI.value_and_pushforward(
9699
tx::NTuple,
97100
contexts::Vararg{DI.Context,C},
98101
) where {C}
99-
(; relstep, absstep) = prep
102+
(; relstep, absstep, dir) = prep
100103
fc = DI.with_contexts(f, contexts...)
101104
y = fc(x)
102105
ty = map(tx) do dx
103-
finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep)
106+
finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep, dir)
104107
end
105108
return y, ty
106109
end
107110

108111
## Derivative
109112

110-
struct FiniteDiffOneArgDerivativePrep{C,R,A} <: DI.DerivativePrep
113+
struct FiniteDiffOneArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep
111114
cache::C
112115
relstep::R
113116
absstep::A
117+
dir::D
114118
end
115119

116120
function DI.prepare_derivative(
@@ -134,7 +138,8 @@ function DI.prepare_derivative(
134138
else
135139
backend.relstep
136140
end
137-
return FiniteDiffOneArgDerivativePrep(cache, relstep, absstep)
141+
dir = backend.dir
142+
return FiniteDiffOneArgDerivativePrep(cache, relstep, absstep, dir)
138143
end
139144

140145
### Scalar to scalar
@@ -146,9 +151,9 @@ function DI.derivative(
146151
x,
147152
contexts::Vararg{DI.Context,C},
148153
) where {C}
149-
(; relstep, absstep) = prep
154+
(; relstep, absstep, dir) = prep
150155
fc = DI.with_contexts(f, contexts...)
151-
return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep)
156+
return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep, dir)
152157
end
153158

154159
function DI.value_and_derivative(
@@ -158,13 +163,13 @@ function DI.value_and_derivative(
158163
x,
159164
contexts::Vararg{DI.Context,C},
160165
) where {C}
161-
(; relstep, absstep) = prep
166+
(; relstep, absstep, dir) = prep
162167
fc = DI.with_contexts(f, contexts...)
163168
y = fc(x)
164169
return (
165170
y,
166171
finite_difference_derivative(
167-
fc, x, fdtype(backend), eltype(y), y; relstep, absstep
172+
fc, x, fdtype(backend), eltype(y), y; relstep, absstep, dir
168173
),
169174
)
170175
end
@@ -178,9 +183,9 @@ function DI.derivative(
178183
x,
179184
contexts::Vararg{DI.Context,C},
180185
) where {C}
181-
(; relstep, absstep) = prep
186+
(; relstep, absstep, dir) = prep
182187
fc = DI.with_contexts(f, contexts...)
183-
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep)
188+
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
184189
end
185190

186191
function DI.derivative!(
@@ -191,9 +196,9 @@ function DI.derivative!(
191196
x,
192197
contexts::Vararg{DI.Context,C},
193198
) where {C}
194-
(; relstep, absstep) = prep
199+
(; relstep, absstep, dir) = prep
195200
fc = DI.with_contexts(f, contexts...)
196-
return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep)
201+
return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir)
197202
end
198203

199204
function DI.value_and_derivative(
@@ -204,9 +209,9 @@ function DI.value_and_derivative(
204209
contexts::Vararg{DI.Context,C},
205210
) where {C}
206211
fc = DI.with_contexts(f, contexts...)
207-
(; relstep, absstep) = prep
212+
(; relstep, absstep, dir) = prep
208213
y = fc(x)
209-
return (y, finite_difference_gradient(fc, x, prep.cache; relstep, absstep))
214+
return (y, finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir))
210215
end
211216

212217
function DI.value_and_derivative!(
@@ -217,17 +222,20 @@ function DI.value_and_derivative!(
217222
x,
218223
contexts::Vararg{DI.Context,C},
219224
) where {C}
220-
(; relstep, absstep) = prep
225+
(; relstep, absstep, dir) = prep
221226
fc = DI.with_contexts(f, contexts...)
222-
return (fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep))
227+
return (
228+
fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir)
229+
)
223230
end
224231

225232
## Gradient
226233

227-
struct FiniteDiffGradientPrep{C,R,A} <: DI.GradientPrep
234+
struct FiniteDiffGradientPrep{C,R,A,D} <: DI.GradientPrep
228235
cache::C
229236
relstep::R
230237
absstep::A
238+
dir::D
231239
end
232240

233241
function DI.prepare_gradient(
@@ -247,7 +255,8 @@ function DI.prepare_gradient(
247255
else
248256
backend.relstep
249257
end
250-
return FiniteDiffGradientPrep(cache, relstep, absstep)
258+
dir = backend.dir
259+
return FiniteDiffGradientPrep(cache, relstep, absstep, dir)
251260
end
252261

253262
function DI.gradient(
@@ -257,9 +266,9 @@ function DI.gradient(
257266
x::AbstractArray,
258267
contexts::Vararg{DI.Context,C},
259268
) where {C}
260-
(; relstep, absstep) = prep
269+
(; relstep, absstep, dir) = prep
261270
fc = DI.with_contexts(f, contexts...)
262-
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep)
271+
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
263272
end
264273

265274
function DI.value_and_gradient(
@@ -269,9 +278,9 @@ function DI.value_and_gradient(
269278
x::AbstractArray,
270279
contexts::Vararg{DI.Context,C},
271280
) where {C}
272-
(; relstep, absstep) = prep
281+
(; relstep, absstep, dir) = prep
273282
fc = DI.with_contexts(f, contexts...)
274-
return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep)
283+
return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
275284
end
276285

277286
function DI.gradient!(
@@ -282,9 +291,9 @@ function DI.gradient!(
282291
x::AbstractArray,
283292
contexts::Vararg{DI.Context,C},
284293
) where {C}
285-
(; relstep, absstep) = prep
294+
(; relstep, absstep, dir) = prep
286295
fc = DI.with_contexts(f, contexts...)
287-
return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep)
296+
return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir)
288297
end
289298

290299
function DI.value_and_gradient!(
@@ -295,17 +304,20 @@ function DI.value_and_gradient!(
295304
x::AbstractArray,
296305
contexts::Vararg{DI.Context,C},
297306
) where {C}
298-
(; relstep, absstep) = prep
307+
(; relstep, absstep, dir) = prep
299308
fc = DI.with_contexts(f, contexts...)
300-
return (fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep))
309+
return (
310+
fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir)
311+
)
301312
end
302313

303314
## Jacobian
304315

305-
struct FiniteDiffOneArgJacobianPrep{C,R,A} <: DI.JacobianPrep
316+
struct FiniteDiffOneArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep
306317
cache::C
307318
relstep::R
308319
absstep::A
320+
dir::D
309321
end
310322

311323
function DI.prepare_jacobian(
@@ -327,7 +339,8 @@ function DI.prepare_jacobian(
327339
else
328340
backend.relstep
329341
end
330-
return FiniteDiffOneArgJacobianPrep(cache, relstep, absstep)
342+
dir = backend.dir
343+
return FiniteDiffOneArgJacobianPrep(cache, relstep, absstep, dir)
331344
end
332345

333346
function DI.jacobian(
@@ -337,9 +350,9 @@ function DI.jacobian(
337350
x,
338351
contexts::Vararg{DI.Context,C},
339352
) where {C}
340-
(; relstep, absstep) = prep
353+
(; relstep, absstep, dir) = prep
341354
fc = DI.with_contexts(f, contexts...)
342-
return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep)
355+
return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep, dir)
343356
end
344357

345358
function DI.value_and_jacobian(
@@ -350,9 +363,9 @@ function DI.value_and_jacobian(
350363
contexts::Vararg{DI.Context,C},
351364
) where {C}
352365
fc = DI.with_contexts(f, contexts...)
353-
(; relstep, absstep) = prep
366+
(; relstep, absstep, dir) = prep
354367
y = fc(x)
355-
return (y, finite_difference_jacobian(fc, x, prep.cache, y; relstep, absstep))
368+
return (y, finite_difference_jacobian(fc, x, prep.cache, y; relstep, absstep, dir))
356369
end
357370

358371
function DI.jacobian!(
@@ -363,11 +376,13 @@ function DI.jacobian!(
363376
x,
364377
contexts::Vararg{DI.Context,C},
365378
) where {C}
366-
(; relstep, absstep) = prep
379+
(; relstep, absstep, dir) = prep
367380
fc = DI.with_contexts(f, contexts...)
368381
return copyto!(
369382
jac,
370-
finite_difference_jacobian(fc, x, prep.cache; jac_prototype=jac, relstep, absstep),
383+
finite_difference_jacobian(
384+
fc, x, prep.cache; jac_prototype=jac, relstep, absstep, dir
385+
),
371386
)
372387
end
373388

@@ -379,15 +394,15 @@ function DI.value_and_jacobian!(
379394
x,
380395
contexts::Vararg{DI.Context,C},
381396
) where {C}
382-
(; relstep, absstep) = prep
397+
(; relstep, absstep, dir) = prep
383398
fc = DI.with_contexts(f, contexts...)
384399
y = fc(x)
385400
return (
386401
y,
387402
copyto!(
388403
jac,
389404
finite_difference_jacobian(
390-
fc, x, prep.cache, y; jac_prototype=jac, relstep, absstep
405+
fc, x, prep.cache, y; jac_prototype=jac, relstep, absstep, dir
391406
),
392407
),
393408
)

0 commit comments

Comments
 (0)