@@ -119,8 +119,11 @@ function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
119
119
end
120
120
121
121
function DI. prepare_gradient (
122
- f:: F , backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} , x
123
- ) where {F}
122
+ f:: F ,
123
+ backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
124
+ x,
125
+ contexts:: Vararg{DI.Constant,C} ,
126
+ ) where {F,C}
124
127
valB = to_val (DI. pick_batchsize (backend, x))
125
128
shadows = create_shadows (valB, x)
126
129
return EnzymeForwardGradientPrep (valB, shadows)
@@ -131,23 +134,31 @@ function DI.gradient(
131
134
prep:: EnzymeForwardGradientPrep{B} ,
132
135
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
133
136
x,
134
- ) where {F,B}
137
+ contexts:: Vararg{DI.Constant,C} ,
138
+ ) where {F,B,C}
135
139
mode = forward_noprimal (backend)
136
140
f_and_df = get_f_and_df (f, backend, mode)
137
- derivs = gradient (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
138
- return only (derivs)
141
+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
142
+ derivs = gradient (
143
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
144
+ )
145
+ return first (derivs)
139
146
end
140
147
141
148
function DI. value_and_gradient (
142
149
f:: F ,
143
150
prep:: EnzymeForwardGradientPrep{B} ,
144
151
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
145
152
x,
146
- ) where {F,B}
153
+ contexts:: Vararg{DI.Constant,C} ,
154
+ ) where {F,B,C}
147
155
mode = forward_withprimal (backend)
148
156
f_and_df = get_f_and_df (f, backend, mode)
149
- (; derivs, val) = gradient (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
150
- return val, only (derivs)
157
+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
158
+ (; derivs, val) = gradient (
159
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
160
+ )
161
+ return val, first (derivs)
151
162
end
152
163
153
164
function DI. gradient! (
@@ -156,8 +167,9 @@ function DI.gradient!(
156
167
prep:: EnzymeForwardGradientPrep{B} ,
157
168
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
158
169
x,
159
- ) where {F,B}
160
- return copyto! (grad, DI. gradient (f, prep, backend, x))
170
+ contexts:: Vararg{DI.Constant,C} ,
171
+ ) where {F,B,C}
172
+ return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
161
173
end
162
174
163
175
function DI. value_and_gradient! (
@@ -166,8 +178,9 @@ function DI.value_and_gradient!(
166
178
prep:: EnzymeForwardGradientPrep{B} ,
167
179
backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
168
180
x,
169
- ) where {F,B}
170
- y, new_grad = DI. value_and_gradient (f, prep, backend, x)
181
+ contexts:: Vararg{DI.Constant,C} ,
182
+ ) where {F,B,C}
183
+ y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
171
184
return y, copyto! (grad, new_grad)
172
185
end
173
186
@@ -185,9 +198,12 @@ function EnzymeForwardOneArgJacobianPrep(
185
198
end
186
199
187
200
function DI. prepare_jacobian (
188
- f:: F , backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} , x
189
- ) where {F}
190
- y = f (x)
201
+ f:: F ,
202
+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
203
+ x,
204
+ contexts:: Vararg{DI.Constant,C} ,
205
+ ) where {F,C}
206
+ y = f (x, map (DI. unwrap, contexts)... )
191
207
valB = to_val (DI. pick_batchsize (backend, x))
192
208
shadows = create_shadows (valB, x)
193
209
return EnzymeForwardOneArgJacobianPrep (valB, shadows, length (y))
@@ -198,11 +214,15 @@ function DI.jacobian(
198
214
prep:: EnzymeForwardOneArgJacobianPrep{B} ,
199
215
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
200
216
x,
201
- ) where {F,B}
217
+ contexts:: Vararg{DI.Constant,C} ,
218
+ ) where {F,B,C}
202
219
mode = forward_noprimal (backend)
203
220
f_and_df = get_f_and_df (f, backend, mode)
204
- derivs = jacobian (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
205
- jac_tensor = only (derivs)
221
+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
222
+ derivs = jacobian (
223
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
224
+ )
225
+ jac_tensor = first (derivs)
206
226
return maybe_reshape (jac_tensor, prep. output_length, length (x))
207
227
end
208
228
@@ -211,11 +231,15 @@ function DI.value_and_jacobian(
211
231
prep:: EnzymeForwardOneArgJacobianPrep{B} ,
212
232
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
213
233
x,
214
- ) where {F,B}
234
+ contexts:: Vararg{DI.Constant,C} ,
235
+ ) where {F,B,C}
215
236
mode = forward_withprimal (backend)
216
237
f_and_df = get_f_and_df (f, backend, mode)
217
- (; derivs, val) = jacobian (mode, f_and_df, x; chunk= Val (B), shadows= prep. shadows)
218
- jac_tensor = only (derivs)
238
+ annotated_contexts = translate (backend, mode, Val (B), contexts... )
239
+ (; derivs, val) = jacobian (
240
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep. shadows
241
+ )
242
+ jac_tensor = first (derivs)
219
243
return val, maybe_reshape (jac_tensor, prep. output_length, length (x))
220
244
end
221
245
@@ -225,8 +249,9 @@ function DI.jacobian!(
225
249
prep:: EnzymeForwardOneArgJacobianPrep ,
226
250
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
227
251
x,
228
- ) where {F}
229
- return copyto! (jac, DI. jacobian (f, prep, backend, x))
252
+ contexts:: Vararg{DI.Constant,C} ,
253
+ ) where {F,C}
254
+ return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
230
255
end
231
256
232
257
function DI. value_and_jacobian! (
@@ -235,7 +260,8 @@ function DI.value_and_jacobian!(
235
260
prep:: EnzymeForwardOneArgJacobianPrep ,
236
261
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
237
262
x,
238
- ) where {F}
239
- y, new_jac = DI. value_and_jacobian (f, prep, backend, x)
263
+ contexts:: Vararg{DI.Constant,C} ,
264
+ ) where {F,C}
265
+ y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
240
266
return y, copyto! (jac, new_jac)
241
267
end
0 commit comments