Skip to content

Commit e1d171f

Browse files
authored
fix: clarify dispatch for preparation (#746)
* fix: separe `prepare` from the hidden `prepare_nokwarg` * DOcs * Typing * Fix * Toggle fail fast
1 parent 435dea6 commit e1d171f

File tree

43 files changed

+426
-254
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+426
-254
lines changed

DifferentiationInterface/docs/src/dev_guide.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ Most operators have 4 variants, which look like this in the first order: `operat
2323
To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above).
2424
For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).
2525

26-
The method `prepare_operator` must output a `prep` object of the correct type.
27-
For instance, `prepare_gradient(f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
28-
Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep`.
29-
Otherwise, define a custom struct like `MyGradientPrep <: DifferentiationInterface.GradientPrep` and put the necessary storage in there.
26+
The method `prepare_operator_nokwarg` must output a `prep` object of the correct type.
27+
For instance, `prepare_gradient(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
28+
Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`.
29+
Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there.
3030

3131
## New backend
3232

@@ -75,4 +75,4 @@ GROUP = get(ENV, "JULIA_DI_TEST_GROUP", "Back/SuperDiff")
7575

7676
but don't forget to switch it back before pushing.
7777

78-
Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`).
78+
Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`).

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
22
(; f, backend) = dw
33
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,))
4+
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,))
55
function pullbackfunc(dy)
66
tx = DI.pullback(f, prep_same, backend, x, (dy,))
77
return (NoTangent(), only(tx))

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG}
66
pb::PB
77
end
88

9-
function DI.prepare_pullback(
9+
function DI.prepare_pullback_nokwarg(
1010
strict::Val,
1111
f,
1212
backend::AutoReverseChainRules,

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1010

1111
## Pushforward
1212

13-
function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple)
13+
function DI.prepare_pushforward_nokwarg(
14+
strict::Val, f, backend::AutoDiffractor, x, tx::NTuple
15+
)
1416
_sig = DI.signature(f, backend, x, tx; strict)
1517
return DI.NoPushforwardPrep(_sig)
1618
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
function DI.prepare_pushforward(
3+
function DI.prepare_pushforward_nokwarg(
44
strict::Val,
55
f::F,
66
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
@@ -122,7 +122,7 @@ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
122122
shadows::O
123123
end
124124

125-
function DI.prepare_gradient(
125+
function DI.prepare_gradient_nokwarg(
126126
strict::Val,
127127
f::F,
128128
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
@@ -203,7 +203,7 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
203203
output_length::Int
204204
end
205205

206-
function DI.prepare_jacobian(
206+
function DI.prepare_jacobian_nokwarg(
207207
strict::Val,
208208
f::F,
209209
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
function DI.prepare_pushforward(
3+
function DI.prepare_pushforward_nokwarg(
44
strict::Val,
55
f!::F,
66
y,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG}
5252
y_example::Y # useful to create return activity
5353
end
5454

55-
function DI.prepare_pullback(
55+
function DI.prepare_pullback_nokwarg(
5656
strict::Val,
5757
f::F,
5858
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
@@ -191,7 +191,7 @@ end
191191

192192
## Gradient
193193

194-
function DI.prepare_gradient(
194+
function DI.prepare_gradient_nokwarg(
195195
strict::Val,
196196
f::F,
197197
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG}
55
ty_copy::TY
66
end
77

8-
function DI.prepare_pullback(
8+
function DI.prepare_pullback_nokwarg(
99
strict::Val,
1010
f!::F,
1111
y,

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardP
77
jvp_exe!::E1!
88
end
99

10-
function DI.prepare_pushforward(
10+
function DI.prepare_pushforward_nokwarg(
1111
strict::Val,
1212
f,
1313
backend::AutoFastDifferentiation,
@@ -105,7 +105,7 @@ struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG}
105105
vjp_exe!::E1!
106106
end
107107

108-
function DI.prepare_pullback(
108+
function DI.prepare_pullback_nokwarg(
109109
strict::Val,
110110
f,
111111
backend::AutoFastDifferentiation,
@@ -204,7 +204,7 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre
204204
der_exe!::E1!
205205
end
206206

207-
function DI.prepare_derivative(
207+
function DI.prepare_derivative_nokwarg(
208208
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
209209
) where {C}
210210
_sig = DI.signature(f, backend, x, contexts...; strict)
@@ -284,7 +284,7 @@ struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG}
284284
jac_exe!::E1!
285285
end
286286

287-
function DI.prepare_gradient(
287+
function DI.prepare_gradient_nokwarg(
288288
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
289289
) where {C}
290290
_sig = DI.signature(f, backend, x, contexts...; strict)
@@ -360,7 +360,7 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SI
360360
jac_exe!::E1!
361361
end
362362

363-
function DI.prepare_jacobian(
363+
function DI.prepare_jacobian_nokwarg(
364364
strict::Val,
365365
f,
366366
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
@@ -445,7 +445,7 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <:
445445
der2_exe!::E2!
446446
end
447447

448-
function DI.prepare_second_derivative(
448+
function DI.prepare_second_derivative_nokwarg(
449449
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
450450
) where {C}
451451
_sig = DI.signature(f, backend, x, contexts...; strict)
@@ -462,7 +462,7 @@ function DI.prepare_second_derivative(
462462
der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=false)
463463
der2_exe! = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=true)
464464

465-
derivative_prep = DI.prepare_derivative(f, backend, x, contexts...)
465+
derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...)
466466
return FastDifferentiationAllocatingSecondDerivativePrep(
467467
_sig, y_prototype, derivative_prep, der2_exe, der2_exe!
468468
)
@@ -534,7 +534,7 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG}
534534
gradient_prep::E1
535535
end
536536

537-
function DI.prepare_hvp(
537+
function DI.prepare_hvp_nokwarg(
538538
strict::Val,
539539
f,
540540
backend::AutoFastDifferentiation,
@@ -557,7 +557,7 @@ function DI.prepare_hvp(
557557
hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true
558558
)
559559

560-
gradient_prep = DI.prepare_gradient(f, backend, x, contexts...)
560+
gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...)
561561
return FastDifferentiationHVPPrep(_sig, hvp_exe, hvp_exe!, gradient_prep)
562562
end
563563

@@ -633,7 +633,7 @@ struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG}
633633
hess_exe!::E2!
634634
end
635635

636-
function DI.prepare_hessian(
636+
function DI.prepare_hessian_nokwarg(
637637
strict::Val,
638638
f,
639639
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
@@ -656,7 +656,9 @@ function DI.prepare_hessian(
656656
hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false)
657657
hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true)
658658

659-
gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...)
659+
gradient_prep = DI.prepare_gradient_nokwarg(
660+
strict, f, dense_ad(backend), x, contexts...
661+
)
660662
return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!)
661663
end
662664

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPre
66
jvp_exe!::E1!
77
end
88

9-
function DI.prepare_pushforward(
9+
function DI.prepare_pushforward_nokwarg(
1010
strict::Val,
1111
f!,
1212
y,
@@ -107,7 +107,7 @@ struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG}
107107
vjp_exe!::E1!
108108
end
109109

110-
function DI.prepare_pullback(
110+
function DI.prepare_pullback_nokwarg(
111111
strict::Val,
112112
f!,
113113
y,
@@ -213,7 +213,7 @@ struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{
213213
der_exe!::E1!
214214
end
215215

216-
function DI.prepare_derivative(
216+
function DI.prepare_derivative_nokwarg(
217217
strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
218218
) where {C}
219219
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
@@ -295,7 +295,7 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG}
295295
jac_exe!::E1!
296296
end
297297

298-
function DI.prepare_jacobian(
298+
function DI.prepare_jacobian_nokwarg(
299299
strict::Val,
300300
f!,
301301
y,

0 commit comments

Comments
 (0)