1
1
# # Pushforward
2
2
3
- struct FiniteDiffOneArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
3
+ struct FiniteDiffOneArgPushforwardPrep{C,R,A,D } <: DI.PushforwardPrep
4
4
cache:: C
5
5
relstep:: R
6
6
absstep:: A
7
+ dir:: D
7
8
end
8
9
9
10
function DI. prepare_pushforward (
@@ -26,7 +27,8 @@ function DI.prepare_pushforward(
26
27
else
27
28
backend. relstep
28
29
end
29
- return FiniteDiffOneArgPushforwardPrep (cache, relstep, absstep)
30
+ dir = backend. dir
31
+ return FiniteDiffOneArgPushforwardPrep (cache, relstep, absstep, dir)
30
32
end
31
33
32
34
function DI. pushforward (
@@ -37,11 +39,11 @@ function DI.pushforward(
37
39
tx:: NTuple ,
38
40
contexts:: Vararg{DI.Context,C} ,
39
41
) where {C}
40
- (; relstep, absstep) = prep
42
+ (; relstep, absstep, dir ) = prep
41
43
step (t:: Number , dx) = f (x .+ t .* dx, map (DI. unwrap, contexts)... )
42
44
ty = map (tx) do dx
43
45
finite_difference_derivative (
44
- Base. Fix2 (step, dx), zero (eltype (x)), fdtype (backend); relstep, absstep
46
+ Base. Fix2 (step, dx), zero (eltype (x)), fdtype (backend); relstep, absstep, dir
45
47
)
46
48
end
47
49
return ty
@@ -55,7 +57,7 @@ function DI.value_and_pushforward(
55
57
tx:: NTuple ,
56
58
contexts:: Vararg{DI.Context,C} ,
57
59
) where {C}
58
- (; relstep, absstep) = prep
60
+ (; relstep, absstep, dir ) = prep
59
61
step (t:: Number , dx) = f (x .+ t .* dx, map (DI. unwrap, contexts)... )
60
62
y = f (x, map (DI. unwrap, contexts)... )
61
63
ty = map (tx) do dx
@@ -67,6 +69,7 @@ function DI.value_and_pushforward(
67
69
y;
68
70
relstep,
69
71
absstep,
72
+ dir,
70
73
)
71
74
end
72
75
return y, ty
@@ -80,10 +83,10 @@ function DI.pushforward(
80
83
tx:: NTuple ,
81
84
contexts:: Vararg{DI.Context,C} ,
82
85
) where {C}
83
- (; relstep, absstep) = prep
86
+ (; relstep, absstep, dir ) = prep
84
87
fc = DI. with_contexts (f, contexts... )
85
88
ty = map (tx) do dx
86
- finite_difference_jvp (fc, x, dx, prep. cache; relstep, absstep)
89
+ finite_difference_jvp (fc, x, dx, prep. cache; relstep, absstep, dir )
87
90
end
88
91
return ty
89
92
end
@@ -96,21 +99,22 @@ function DI.value_and_pushforward(
96
99
tx:: NTuple ,
97
100
contexts:: Vararg{DI.Context,C} ,
98
101
) where {C}
99
- (; relstep, absstep) = prep
102
+ (; relstep, absstep, dir ) = prep
100
103
fc = DI. with_contexts (f, contexts... )
101
104
y = fc (x)
102
105
ty = map (tx) do dx
103
- finite_difference_jvp (fc, x, dx, prep. cache, y; relstep, absstep)
106
+ finite_difference_jvp (fc, x, dx, prep. cache, y; relstep, absstep, dir )
104
107
end
105
108
return y, ty
106
109
end
107
110
108
111
# # Derivative
109
112
110
- struct FiniteDiffOneArgDerivativePrep{C,R,A} <: DI.DerivativePrep
113
+ struct FiniteDiffOneArgDerivativePrep{C,R,A,D } <: DI.DerivativePrep
111
114
cache:: C
112
115
relstep:: R
113
116
absstep:: A
117
+ dir:: D
114
118
end
115
119
116
120
function DI. prepare_derivative (
@@ -134,7 +138,8 @@ function DI.prepare_derivative(
134
138
else
135
139
backend. relstep
136
140
end
137
- return FiniteDiffOneArgDerivativePrep (cache, relstep, absstep)
141
+ dir = backend. dir
142
+ return FiniteDiffOneArgDerivativePrep (cache, relstep, absstep, dir)
138
143
end
139
144
140
145
# ## Scalar to scalar
@@ -146,9 +151,9 @@ function DI.derivative(
146
151
x,
147
152
contexts:: Vararg{DI.Context,C} ,
148
153
) where {C}
149
- (; relstep, absstep) = prep
154
+ (; relstep, absstep, dir ) = prep
150
155
fc = DI. with_contexts (f, contexts... )
151
- return finite_difference_derivative (fc, x, fdtype (backend); relstep, absstep)
156
+ return finite_difference_derivative (fc, x, fdtype (backend); relstep, absstep, dir )
152
157
end
153
158
154
159
function DI. value_and_derivative (
@@ -158,13 +163,13 @@ function DI.value_and_derivative(
158
163
x,
159
164
contexts:: Vararg{DI.Context,C} ,
160
165
) where {C}
161
- (; relstep, absstep) = prep
166
+ (; relstep, absstep, dir ) = prep
162
167
fc = DI. with_contexts (f, contexts... )
163
168
y = fc (x)
164
169
return (
165
170
y,
166
171
finite_difference_derivative (
167
- fc, x, fdtype (backend), eltype (y), y; relstep, absstep
172
+ fc, x, fdtype (backend), eltype (y), y; relstep, absstep, dir
168
173
),
169
174
)
170
175
end
@@ -178,9 +183,9 @@ function DI.derivative(
178
183
x,
179
184
contexts:: Vararg{DI.Context,C} ,
180
185
) where {C}
181
- (; relstep, absstep) = prep
186
+ (; relstep, absstep, dir ) = prep
182
187
fc = DI. with_contexts (f, contexts... )
183
- return finite_difference_gradient (fc, x, prep. cache; relstep, absstep)
188
+ return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir )
184
189
end
185
190
186
191
function DI. derivative! (
@@ -191,9 +196,9 @@ function DI.derivative!(
191
196
x,
192
197
contexts:: Vararg{DI.Context,C} ,
193
198
) where {C}
194
- (; relstep, absstep) = prep
199
+ (; relstep, absstep, dir ) = prep
195
200
fc = DI. with_contexts (f, contexts... )
196
- return finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep)
201
+ return finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir )
197
202
end
198
203
199
204
function DI. value_and_derivative (
@@ -204,9 +209,9 @@ function DI.value_and_derivative(
204
209
contexts:: Vararg{DI.Context,C} ,
205
210
) where {C}
206
211
fc = DI. with_contexts (f, contexts... )
207
- (; relstep, absstep) = prep
212
+ (; relstep, absstep, dir ) = prep
208
213
y = fc (x)
209
- return (y, finite_difference_gradient (fc, x, prep. cache; relstep, absstep))
214
+ return (y, finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir ))
210
215
end
211
216
212
217
function DI. value_and_derivative! (
@@ -217,17 +222,20 @@ function DI.value_and_derivative!(
217
222
x,
218
223
contexts:: Vararg{DI.Context,C} ,
219
224
) where {C}
220
- (; relstep, absstep) = prep
225
+ (; relstep, absstep, dir ) = prep
221
226
fc = DI. with_contexts (f, contexts... )
222
- return (fc (x), finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep))
227
+ return (
228
+ fc (x), finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir)
229
+ )
223
230
end
224
231
225
232
# # Gradient
226
233
227
- struct FiniteDiffGradientPrep{C,R,A} <: DI.GradientPrep
234
+ struct FiniteDiffGradientPrep{C,R,A,D } <: DI.GradientPrep
228
235
cache:: C
229
236
relstep:: R
230
237
absstep:: A
238
+ dir:: D
231
239
end
232
240
233
241
function DI. prepare_gradient (
@@ -247,7 +255,8 @@ function DI.prepare_gradient(
247
255
else
248
256
backend. relstep
249
257
end
250
- return FiniteDiffGradientPrep (cache, relstep, absstep)
258
+ dir = backend. dir
259
+ return FiniteDiffGradientPrep (cache, relstep, absstep, dir)
251
260
end
252
261
253
262
function DI. gradient (
@@ -257,9 +266,9 @@ function DI.gradient(
257
266
x:: AbstractArray ,
258
267
contexts:: Vararg{DI.Context,C} ,
259
268
) where {C}
260
- (; relstep, absstep) = prep
269
+ (; relstep, absstep, dir ) = prep
261
270
fc = DI. with_contexts (f, contexts... )
262
- return finite_difference_gradient (fc, x, prep. cache; relstep, absstep)
271
+ return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir )
263
272
end
264
273
265
274
function DI. value_and_gradient (
@@ -269,9 +278,9 @@ function DI.value_and_gradient(
269
278
x:: AbstractArray ,
270
279
contexts:: Vararg{DI.Context,C} ,
271
280
) where {C}
272
- (; relstep, absstep) = prep
281
+ (; relstep, absstep, dir ) = prep
273
282
fc = DI. with_contexts (f, contexts... )
274
- return fc (x), finite_difference_gradient (fc, x, prep. cache; relstep, absstep)
283
+ return fc (x), finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir )
275
284
end
276
285
277
286
function DI. gradient! (
@@ -282,9 +291,9 @@ function DI.gradient!(
282
291
x:: AbstractArray ,
283
292
contexts:: Vararg{DI.Context,C} ,
284
293
) where {C}
285
- (; relstep, absstep) = prep
294
+ (; relstep, absstep, dir ) = prep
286
295
fc = DI. with_contexts (f, contexts... )
287
- return finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep)
296
+ return finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir )
288
297
end
289
298
290
299
function DI. value_and_gradient! (
@@ -295,17 +304,20 @@ function DI.value_and_gradient!(
295
304
x:: AbstractArray ,
296
305
contexts:: Vararg{DI.Context,C} ,
297
306
) where {C}
298
- (; relstep, absstep) = prep
307
+ (; relstep, absstep, dir ) = prep
299
308
fc = DI. with_contexts (f, contexts... )
300
- return (fc (x), finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep))
309
+ return (
310
+ fc (x), finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir)
311
+ )
301
312
end
302
313
303
314
# # Jacobian
304
315
305
- struct FiniteDiffOneArgJacobianPrep{C,R,A} <: DI.JacobianPrep
316
+ struct FiniteDiffOneArgJacobianPrep{C,R,A,D } <: DI.JacobianPrep
306
317
cache:: C
307
318
relstep:: R
308
319
absstep:: A
320
+ dir:: D
309
321
end
310
322
311
323
function DI. prepare_jacobian (
@@ -327,7 +339,8 @@ function DI.prepare_jacobian(
327
339
else
328
340
backend. relstep
329
341
end
330
- return FiniteDiffOneArgJacobianPrep (cache, relstep, absstep)
342
+ dir = backend. dir
343
+ return FiniteDiffOneArgJacobianPrep (cache, relstep, absstep, dir)
331
344
end
332
345
333
346
function DI. jacobian (
@@ -337,9 +350,9 @@ function DI.jacobian(
337
350
x,
338
351
contexts:: Vararg{DI.Context,C} ,
339
352
) where {C}
340
- (; relstep, absstep) = prep
353
+ (; relstep, absstep, dir ) = prep
341
354
fc = DI. with_contexts (f, contexts... )
342
- return finite_difference_jacobian (fc, x, prep. cache; relstep, absstep)
355
+ return finite_difference_jacobian (fc, x, prep. cache; relstep, absstep, dir )
343
356
end
344
357
345
358
function DI. value_and_jacobian (
@@ -350,9 +363,9 @@ function DI.value_and_jacobian(
350
363
contexts:: Vararg{DI.Context,C} ,
351
364
) where {C}
352
365
fc = DI. with_contexts (f, contexts... )
353
- (; relstep, absstep) = prep
366
+ (; relstep, absstep, dir ) = prep
354
367
y = fc (x)
355
- return (y, finite_difference_jacobian (fc, x, prep. cache, y; relstep, absstep))
368
+ return (y, finite_difference_jacobian (fc, x, prep. cache, y; relstep, absstep, dir ))
356
369
end
357
370
358
371
function DI. jacobian! (
@@ -363,11 +376,13 @@ function DI.jacobian!(
363
376
x,
364
377
contexts:: Vararg{DI.Context,C} ,
365
378
) where {C}
366
- (; relstep, absstep) = prep
379
+ (; relstep, absstep, dir ) = prep
367
380
fc = DI. with_contexts (f, contexts... )
368
381
return copyto! (
369
382
jac,
370
- finite_difference_jacobian (fc, x, prep. cache; jac_prototype= jac, relstep, absstep),
383
+ finite_difference_jacobian (
384
+ fc, x, prep. cache; jac_prototype= jac, relstep, absstep, dir
385
+ ),
371
386
)
372
387
end
373
388
@@ -379,15 +394,15 @@ function DI.value_and_jacobian!(
379
394
x,
380
395
contexts:: Vararg{DI.Context,C} ,
381
396
) where {C}
382
- (; relstep, absstep) = prep
397
+ (; relstep, absstep, dir ) = prep
383
398
fc = DI. with_contexts (f, contexts... )
384
399
y = fc (x)
385
400
return (
386
401
y,
387
402
copyto! (
388
403
jac,
389
404
finite_difference_jacobian (
390
- fc, x, prep. cache, y; jac_prototype= jac, relstep, absstep
405
+ fc, x, prep. cache, y; jac_prototype= jac, relstep, absstep, dir
391
406
),
392
407
),
393
408
)
0 commit comments