Skip to content

Commit 435dea6

Browse files
authored
feat: test type consistency between preparation and execution (#745)
1 parent bac2d02 commit 435dea6

File tree

59 files changed

+2586
-1317
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+2586
-1317
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.44"
4+
version = "0.6.45"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
22
(; f, backend) = dw
33
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))
4+
prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,))
55
function pullbackfunc(dy)
66
tx = DI.pullback(f, prep_same, backend, x, (dy,))
77
return (NoTangent(), only(tx))

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,47 @@
11
## Pullback
22

3-
struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
3+
struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG}
4+
_sig::Val{SIG}
45
y::Y
56
pb::PB
67
end
78

89
function DI.prepare_pullback(
9-
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
10+
strict::Val,
11+
f,
12+
backend::AutoReverseChainRules,
13+
x,
14+
ty::NTuple,
15+
contexts::Vararg{DI.GeneralizedConstant,C};
1016
) where {C}
11-
return DI.NoPullbackPrep()
17+
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
18+
return DI.NoPullbackPrep(_sig)
1219
end
1320

1421
function DI.prepare_pullback_same_point(
1522
f,
16-
::DI.NoPullbackPrep,
23+
prep::DI.NoPullbackPrep,
1724
backend::AutoReverseChainRules,
1825
x,
1926
ty::NTuple,
20-
contexts::Vararg{DI.GeneralizedConstant,C},
27+
contexts::Vararg{DI.GeneralizedConstant,C};
2128
) where {C}
29+
DI.check_prep(f, prep, backend, x, ty, contexts...)
30+
_sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep))
2231
rc = ruleconfig(backend)
2332
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
24-
return ChainRulesPullbackPrepSamePoint(y, pb)
33+
return ChainRulesPullbackPrepSamePoint(_sig, y, pb)
2534
end
2635

2736
function DI.value_and_pullback(
2837
f,
29-
::DI.NoPullbackPrep,
38+
prep::DI.NoPullbackPrep,
3039
backend::AutoReverseChainRules,
3140
x,
3241
ty::NTuple,
3342
contexts::Vararg{DI.GeneralizedConstant,C},
3443
) where {C}
44+
DI.check_prep(f, prep, backend, x, ty, contexts...)
3545
rc = ruleconfig(backend)
3646
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
3747
tx = map(ty) do dy
@@ -43,11 +53,12 @@ end
4353
function DI.value_and_pullback(
4454
f,
4555
prep::ChainRulesPullbackPrepSamePoint,
46-
::AutoReverseChainRules,
56+
backend::AutoReverseChainRules,
4757
x,
4858
ty::NTuple,
4959
contexts::Vararg{DI.GeneralizedConstant,C},
5060
) where {C}
61+
DI.check_prep(f, prep, backend, x, ty, contexts...)
5162
(; y, pb) = prep
5263
tx = map(ty) do dy
5364
unthunk(pb(dy)[2])
@@ -58,11 +69,12 @@ end
5869
function DI.pullback(
5970
f,
6071
prep::ChainRulesPullbackPrepSamePoint,
61-
::AutoReverseChainRules,
72+
backend::AutoReverseChainRules,
6273
x,
6374
ty::NTuple,
6475
contexts::Vararg{DI.GeneralizedConstant,C},
6576
) where {C}
77+
DI.check_prep(f, prep, backend, x, ty, contexts...)
6678
(; pb) = prep
6779
tx = map(ty) do dy
6880
unthunk(pb(dy)[2])

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1010

1111
## Pushforward
1212

13-
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep()
13+
function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple)
14+
_sig = DI.signature(f, backend, x, tx; strict)
15+
return DI.NoPushforwardPrep(_sig)
16+
end
1417

15-
function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
18+
function DI.pushforward(
19+
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
20+
)
21+
DI.check_prep(f, prep, backend, x, tx)
1622
ty = map(tx) do dx
1723
# code copied from Diffractor.jl
1824
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
@@ -25,6 +31,7 @@ end
2531
function DI.value_and_pushforward(
2632
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
2733
)
34+
DI.check_prep(f, prep, backend, x, tx)
2835
return f(x), DI.pushforward(f, prep, backend, x, tx)
2936
end
3037

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4+
strict::Val,
45
f::F,
5-
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
67
x,
78
tx::NTuple,
8-
contexts::Vararg{DI.Context,C},
9+
contexts::Vararg{DI.Context,C};
910
) where {F,C}
10-
return DI.NoPushforwardPrep()
11+
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
12+
return DI.NoPushforwardPrep(_sig)
1113
end
1214

1315
function DI.value_and_pushforward(
1416
f::F,
15-
::DI.NoPushforwardPrep,
17+
prep::DI.NoPushforwardPrep,
1618
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1719
x,
1820
tx::NTuple{1},
1921
contexts::Vararg{DI.Context,C},
2022
) where {F,C}
23+
DI.check_prep(f, prep, backend, x, tx, contexts...)
2124
mode = forward_withprimal(backend)
2225
f_and_df = get_f_and_df(f, backend, mode)
2326
dx = only(tx)
@@ -29,12 +32,13 @@ end
2932

3033
function DI.value_and_pushforward(
3134
f::F,
32-
::DI.NoPushforwardPrep,
35+
prep::DI.NoPushforwardPrep,
3336
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
3437
x,
3538
tx::NTuple{B},
3639
contexts::Vararg{DI.Context,C},
3740
) where {F,B,C}
41+
DI.check_prep(f, prep, backend, x, tx, contexts...)
3842
mode = forward_withprimal(backend)
3943
f_and_df = get_f_and_df(f, backend, mode, Val(B))
4044
x_and_tx = BatchDuplicated(x, tx)
@@ -45,12 +49,13 @@ end
4549

4650
function DI.pushforward(
4751
f::F,
48-
::DI.NoPushforwardPrep,
52+
prep::DI.NoPushforwardPrep,
4953
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5054
x,
5155
tx::NTuple{1},
5256
contexts::Vararg{DI.Context,C},
5357
) where {F,C}
58+
DI.check_prep(f, prep, backend, x, tx, contexts...)
5459
mode = forward_noprimal(backend)
5560
f_and_df = get_f_and_df(f, backend, mode)
5661
dx = only(tx)
@@ -62,12 +67,13 @@ end
6267

6368
function DI.pushforward(
6469
f::F,
65-
::DI.NoPushforwardPrep,
70+
prep::DI.NoPushforwardPrep,
6671
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6772
x,
6873
tx::NTuple{B},
6974
contexts::Vararg{DI.Context,C},
7075
) where {F,B,C}
76+
DI.check_prep(f, prep, backend, x, tx, contexts...)
7177
mode = forward_noprimal(backend)
7278
f_and_df = get_f_and_df(f, backend, mode, Val(B))
7379
x_and_tx = BatchDuplicated(x, tx)
@@ -85,6 +91,7 @@ function DI.value_and_pushforward!(
8591
tx::NTuple,
8692
contexts::Vararg{DI.Context,C},
8793
) where {F,C}
94+
DI.check_prep(f, prep, backend, x, tx, contexts...)
8895
# dy cannot be passed anyway
8996
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
9097
foreach(copyto!, ty, new_ty)
@@ -100,6 +107,7 @@ function DI.pushforward!(
100107
tx::NTuple,
101108
contexts::Vararg{DI.Context,C},
102109
) where {F,C}
110+
DI.check_prep(f, prep, backend, x, tx, contexts...)
103111
# dy cannot be passed anyway
104112
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
105113
foreach(copyto!, ty, new_ty)
@@ -108,32 +116,33 @@ end
108116

109117
## Gradient
110118

111-
struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
119+
struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
120+
_sig::Val{SIG}
121+
_valB::Val{B}
112122
shadows::O
113123
end
114124

115-
function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
116-
return EnzymeForwardGradientPrep{B,O}(shadows)
117-
end
118-
119125
function DI.prepare_gradient(
126+
strict::Val,
120127
f::F,
121128
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
122129
x,
123-
contexts::Vararg{DI.Constant,C},
130+
contexts::Vararg{DI.Constant,C};
124131
) where {F,C}
132+
_sig = DI.signature(f, backend, x, contexts...; strict)
125133
valB = to_val(DI.pick_batchsize(backend, x))
126134
shadows = create_shadows(valB, x)
127-
return EnzymeForwardGradientPrep(valB, shadows)
135+
return EnzymeForwardGradientPrep(_sig, valB, shadows)
128136
end
129137

130138
function DI.gradient(
131139
f::F,
132-
prep::EnzymeForwardGradientPrep{B},
140+
prep::EnzymeForwardGradientPrep{SIG,B},
133141
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
134142
x,
135143
contexts::Vararg{DI.Constant,C},
136-
) where {F,B,C}
144+
) where {F,SIG,B,C}
145+
DI.check_prep(f, prep, backend, x, contexts...)
137146
mode = forward_noprimal(backend)
138147
f_and_df = get_f_and_df(f, backend, mode)
139148
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -145,11 +154,12 @@ end
145154

146155
function DI.value_and_gradient(
147156
f::F,
148-
prep::EnzymeForwardGradientPrep{B},
157+
prep::EnzymeForwardGradientPrep{SIG,B},
149158
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
150159
x,
151160
contexts::Vararg{DI.Constant,C},
152-
) where {F,B,C}
161+
) where {F,SIG,B,C}
162+
DI.check_prep(f, prep, backend, x, contexts...)
153163
mode = forward_withprimal(backend)
154164
f_and_df = get_f_and_df(f, backend, mode)
155165
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -162,58 +172,59 @@ end
162172
function DI.gradient!(
163173
f::F,
164174
grad,
165-
prep::EnzymeForwardGradientPrep{B},
175+
prep::EnzymeForwardGradientPrep{SIG,B},
166176
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
167177
x,
168178
contexts::Vararg{DI.Constant,C},
169-
) where {F,B,C}
179+
) where {F,SIG,B,C}
180+
DI.check_prep(f, prep, backend, x, contexts...)
170181
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
171182
end
172183

173184
function DI.value_and_gradient!(
174185
f::F,
175186
grad,
176-
prep::EnzymeForwardGradientPrep{B},
187+
prep::EnzymeForwardGradientPrep{SIG,B},
177188
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
178189
x,
179190
contexts::Vararg{DI.Constant,C},
180-
) where {F,B,C}
191+
) where {F,SIG,B,C}
192+
DI.check_prep(f, prep, backend, x, contexts...)
181193
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
182194
return y, copyto!(grad, new_grad)
183195
end
184196

185197
## Jacobian
186198

187-
struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
199+
struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
200+
_sig::Val{SIG}
201+
_valB::Val{B}
188202
shadows::O
189203
output_length::Int
190204
end
191205

192-
function EnzymeForwardOneArgJacobianPrep(
193-
::Val{B}, shadows::O, output_length::Integer
194-
) where {B,O}
195-
return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length)
196-
end
197-
198206
function DI.prepare_jacobian(
207+
strict::Val,
199208
f::F,
200209
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
201210
x,
202-
contexts::Vararg{DI.Constant,C},
211+
contexts::Vararg{DI.Constant,C};
203212
) where {F,C}
213+
_sig = DI.signature(f, backend, x, contexts...; strict)
204214
y = f(x, map(DI.unwrap, contexts)...)
205215
valB = to_val(DI.pick_batchsize(backend, x))
206216
shadows = create_shadows(valB, x)
207-
return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
217+
return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y))
208218
end
209219

210220
function DI.jacobian(
211221
f::F,
212-
prep::EnzymeForwardOneArgJacobianPrep{B},
222+
prep::EnzymeForwardOneArgJacobianPrep{SIG,B},
213223
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
214224
x,
215225
contexts::Vararg{DI.Constant,C},
216-
) where {F,B,C}
226+
) where {F,SIG,B,C}
227+
DI.check_prep(f, prep, backend, x, contexts...)
217228
mode = forward_noprimal(backend)
218229
f_and_df = get_f_and_df(f, backend, mode)
219230
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -226,11 +237,12 @@ end
226237

227238
function DI.value_and_jacobian(
228239
f::F,
229-
prep::EnzymeForwardOneArgJacobianPrep{B},
240+
prep::EnzymeForwardOneArgJacobianPrep{SIG,B},
230241
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
231242
x,
232243
contexts::Vararg{DI.Constant,C},
233-
) where {F,B,C}
244+
) where {F,SIG,B,C}
245+
DI.check_prep(f, prep, backend, x, contexts...)
234246
mode = forward_withprimal(backend)
235247
f_and_df = get_f_and_df(f, backend, mode)
236248
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -249,6 +261,7 @@ function DI.jacobian!(
249261
x,
250262
contexts::Vararg{DI.Constant,C},
251263
) where {F,C}
264+
DI.check_prep(f, prep, backend, x, contexts...)
252265
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
253266
end
254267

@@ -260,6 +273,7 @@ function DI.value_and_jacobian!(
260273
x,
261274
contexts::Vararg{DI.Constant,C},
262275
) where {F,C}
276+
DI.check_prep(f, prep, backend, x, contexts...)
263277
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
264278
return y, copyto!(jac, new_jac)
265279
end

0 commit comments

Comments
 (0)