1
1
# # Pushforward
2
2
3
3
function DI. prepare_pushforward (
4
+ strict:: Val ,
4
5
f:: F ,
5
- :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
6
+ backend :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
6
7
x,
7
8
tx:: NTuple ,
8
- contexts:: Vararg{DI.Context,C} ,
9
+ contexts:: Vararg{DI.Context,C} ;
9
10
) where {F,C}
10
- return DI. NoPushforwardPrep ()
11
+ _sig = DI. signature (f, backend, x, tx, contexts... ; strict)
12
+ return DI. NoPushforwardPrep (_sig)
11
13
end
12
14
13
15
function DI. value_and_pushforward (
14
16
f:: F ,
15
- :: DI.NoPushforwardPrep ,
17
+ prep :: DI.NoPushforwardPrep ,
16
18
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
17
19
x,
18
20
tx:: NTuple{1} ,
19
21
contexts:: Vararg{DI.Context,C} ,
20
22
) where {F,C}
23
+ DI. check_prep (f, prep, backend, x, tx, contexts... )
21
24
mode = forward_withprimal (backend)
22
25
f_and_df = get_f_and_df (f, backend, mode)
23
26
dx = only (tx)
29
32
30
33
function DI. value_and_pushforward (
31
34
f:: F ,
32
- :: DI.NoPushforwardPrep ,
35
+ prep :: DI.NoPushforwardPrep ,
33
36
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
34
37
x,
35
38
tx:: NTuple{B} ,
36
39
contexts:: Vararg{DI.Context,C} ,
37
40
) where {F,B,C}
41
+ DI. check_prep (f, prep, backend, x, tx, contexts... )
38
42
mode = forward_withprimal (backend)
39
43
f_and_df = get_f_and_df (f, backend, mode, Val (B))
40
44
x_and_tx = BatchDuplicated (x, tx)
45
49
46
50
function DI. pushforward (
47
51
f:: F ,
48
- :: DI.NoPushforwardPrep ,
52
+ prep :: DI.NoPushforwardPrep ,
49
53
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
50
54
x,
51
55
tx:: NTuple{1} ,
52
56
contexts:: Vararg{DI.Context,C} ,
53
57
) where {F,C}
58
+ DI. check_prep (f, prep, backend, x, tx, contexts... )
54
59
mode = forward_noprimal (backend)
55
60
f_and_df = get_f_and_df (f, backend, mode)
56
61
dx = only (tx)
62
67
63
68
function DI. pushforward (
64
69
f:: F ,
65
- :: DI.NoPushforwardPrep ,
70
+ prep :: DI.NoPushforwardPrep ,
66
71
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
67
72
x,
68
73
tx:: NTuple{B} ,
69
74
contexts:: Vararg{DI.Context,C} ,
70
75
) where {F,B,C}
76
+ DI. check_prep (f, prep, backend, x, tx, contexts... )
71
77
mode = forward_noprimal (backend)
72
78
f_and_df = get_f_and_df (f, backend, mode, Val (B))
73
79
x_and_tx = BatchDuplicated (x, tx)
@@ -85,6 +91,7 @@ function DI.value_and_pushforward!(
85
91
tx:: NTuple ,
86
92
contexts:: Vararg{DI.Context,C} ,
87
93
) where {F,C}
94
+ DI. check_prep (f, prep, backend, x, tx, contexts... )
88
95
# dy cannot be passed anyway
89
96
y, new_ty = DI. value_and_pushforward (f, prep, backend, x, tx, contexts... )
90
97
foreach (copyto!, ty, new_ty)
@@ -100,6 +107,7 @@ function DI.pushforward!(
100
107
tx:: NTuple ,
101
108
contexts:: Vararg{DI.Context,C} ,
102
109
) where {F,C}
110
+ DI. check_prep (f, prep, backend, x, tx, contexts... )
103
111
# dy cannot be passed anyway
104
112
new_ty = DI. pushforward (f, prep, backend, x, tx, contexts... )
105
113
foreach (copyto!, ty, new_ty)
@@ -108,32 +116,33 @@ end
108
116
109
117
# # Gradient
110
118
111
- struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
119
+ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
120
+ _sig:: Val{SIG}
121
+ _valB:: Val{B}
112
122
shadows:: O
113
123
end
114
124
115
- function EnzymeForwardGradientPrep (:: Val{B} , shadows:: O ) where {B,O}
116
- return EnzymeForwardGradientPrep {B,O} (shadows)
117
- end
118
-
119
125
function DI. prepare_gradient (
126
+ strict:: Val ,
120
127
f:: F ,
121
128
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
122
129
x,
123
- contexts:: Vararg{DI.Constant,C} ,
130
+ contexts:: Vararg{DI.Constant,C} ;
124
131
) where {F,C}
132
+ _sig = DI. signature (f, backend, x, contexts... ; strict)
125
133
valB = to_val (DI. pick_batchsize (backend, x))
126
134
shadows = create_shadows (valB, x)
127
- return EnzymeForwardGradientPrep (valB, shadows)
135
+ return EnzymeForwardGradientPrep (_sig, valB, shadows)
128
136
end
129
137
130
138
function DI. gradient (
131
139
f:: F ,
132
- prep:: EnzymeForwardGradientPrep{B} ,
140
+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
133
141
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
134
142
x,
135
143
contexts:: Vararg{DI.Constant,C} ,
136
- ) where {F,B,C}
144
+ ) where {F,SIG,B,C}
145
+ DI. check_prep (f, prep, backend, x, contexts... )
137
146
mode = forward_noprimal (backend)
138
147
f_and_df = get_f_and_df (f, backend, mode)
139
148
annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -145,11 +154,12 @@ end
145
154
146
155
function DI. value_and_gradient (
147
156
f:: F ,
148
- prep:: EnzymeForwardGradientPrep{B} ,
157
+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
149
158
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
150
159
x,
151
160
contexts:: Vararg{DI.Constant,C} ,
152
- ) where {F,B,C}
161
+ ) where {F,SIG,B,C}
162
+ DI. check_prep (f, prep, backend, x, contexts... )
153
163
mode = forward_withprimal (backend)
154
164
f_and_df = get_f_and_df (f, backend, mode)
155
165
annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -162,58 +172,59 @@ end
162
172
function DI. gradient! (
163
173
f:: F ,
164
174
grad,
165
- prep:: EnzymeForwardGradientPrep{B} ,
175
+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
166
176
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
167
177
x,
168
178
contexts:: Vararg{DI.Constant,C} ,
169
- ) where {F,B,C}
179
+ ) where {F,SIG,B,C}
180
+ DI. check_prep (f, prep, backend, x, contexts... )
170
181
return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
171
182
end
172
183
173
184
function DI. value_and_gradient! (
174
185
f:: F ,
175
186
grad,
176
- prep:: EnzymeForwardGradientPrep{B} ,
187
+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
177
188
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
178
189
x,
179
190
contexts:: Vararg{DI.Constant,C} ,
180
- ) where {F,B,C}
191
+ ) where {F,SIG,B,C}
192
+ DI. check_prep (f, prep, backend, x, contexts... )
181
193
y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
182
194
return y, copyto! (grad, new_grad)
183
195
end
184
196
185
197
# # Jacobian
186
198
187
- struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
199
+ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
200
+ _sig:: Val{SIG}
201
+ _valB:: Val{B}
188
202
shadows:: O
189
203
output_length:: Int
190
204
end
191
205
192
- function EnzymeForwardOneArgJacobianPrep (
193
- :: Val{B} , shadows:: O , output_length:: Integer
194
- ) where {B,O}
195
- return EnzymeForwardOneArgJacobianPrep {B,O} (shadows, output_length)
196
- end
197
-
198
206
function DI. prepare_jacobian (
207
+ strict:: Val ,
199
208
f:: F ,
200
209
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
201
210
x,
202
- contexts:: Vararg{DI.Constant,C} ,
211
+ contexts:: Vararg{DI.Constant,C} ;
203
212
) where {F,C}
213
+ _sig = DI. signature (f, backend, x, contexts... ; strict)
204
214
y = f (x, map (DI. unwrap, contexts)... )
205
215
valB = to_val (DI. pick_batchsize (backend, x))
206
216
shadows = create_shadows (valB, x)
207
- return EnzymeForwardOneArgJacobianPrep (valB, shadows, length (y))
217
+ return EnzymeForwardOneArgJacobianPrep (_sig, valB, shadows, length (y))
208
218
end
209
219
210
220
function DI. jacobian (
211
221
f:: F ,
212
- prep:: EnzymeForwardOneArgJacobianPrep{B} ,
222
+ prep:: EnzymeForwardOneArgJacobianPrep{SIG, B} ,
213
223
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
214
224
x,
215
225
contexts:: Vararg{DI.Constant,C} ,
216
- ) where {F,B,C}
226
+ ) where {F,SIG,B,C}
227
+ DI. check_prep (f, prep, backend, x, contexts... )
217
228
mode = forward_noprimal (backend)
218
229
f_and_df = get_f_and_df (f, backend, mode)
219
230
annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -226,11 +237,12 @@ end
226
237
227
238
function DI. value_and_jacobian (
228
239
f:: F ,
229
- prep:: EnzymeForwardOneArgJacobianPrep{B} ,
240
+ prep:: EnzymeForwardOneArgJacobianPrep{SIG, B} ,
230
241
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
231
242
x,
232
243
contexts:: Vararg{DI.Constant,C} ,
233
- ) where {F,B,C}
244
+ ) where {F,SIG,B,C}
245
+ DI. check_prep (f, prep, backend, x, contexts... )
234
246
mode = forward_withprimal (backend)
235
247
f_and_df = get_f_and_df (f, backend, mode)
236
248
annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -249,6 +261,7 @@ function DI.jacobian!(
249
261
x,
250
262
contexts:: Vararg{DI.Constant,C} ,
251
263
) where {F,C}
264
+ DI. check_prep (f, prep, backend, x, contexts... )
252
265
return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
253
266
end
254
267
@@ -260,6 +273,7 @@ function DI.value_and_jacobian!(
260
273
x,
261
274
contexts:: Vararg{DI.Constant,C} ,
262
275
) where {F,C}
276
+ DI. check_prep (f, prep, backend, x, contexts... )
263
277
y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
264
278
return y, copyto! (jac, new_jac)
265
279
end
0 commit comments