@@ -41,26 +41,26 @@ function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p, num_cons =
 
     chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize
 
-    _f = θ -> first(f.f(θ,p))
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}())
-        grad = (res,θ) -> ForwardDiff.gradient!(res, _f, θ, gradcfg)
+        gradcfg = (args...) -> ForwardDiff.GradientConfig(x -> _f(x, args...), x, ForwardDiff.Chunk{chunksize}())
+        grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, gradcfg(args...), Val{false}())
     else
        grad = f.grad
    end

    if f.hess === nothing
-        hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}())
-        hess = (res,θ) -> ForwardDiff.hessian!(res, _f, θ, hesscfg)
+        hesscfg = (args...) -> ForwardDiff.HessianConfig(x -> _f(x, args...), x, ForwardDiff.Chunk{chunksize}())
+        hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ, hesscfg(args...), Val{false}())
    else
        hess = f.hess
    end

    if f.hv === nothing
-        hv = function (H,θ,v)
+        hv = function (H,θ,v, args...)
            res = ArrayInterface.zeromatrix(θ)
-            hess(res, θ)
+            hess(res, θ, args...)
            H .= res*v
        end
    else
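
This hunk threads a trailing `args...` through the wrapped objective so per-call data (e.g. a minibatch) can reach `f.f`. Since the `GradientConfig`/`HessianConfig` must now be built against each `args...`-specific closure, the configs become closures themselves, and `Val{false}()` disables ForwardDiff's tag check, which would otherwise reject the freshly wrapped function. A minimal sketch of the resulting calling pattern, with a hypothetical objective and data (none of the names below come from this commit):

```julia
using ForwardDiff

p = nothing                                  # fixed parameters, unused here
loss(θ, p, batch) = sum(abs2, θ .- batch)    # hypothetical per-batch objective

_f   = (θ, args...) -> first(loss(θ, p, args...))
grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ,
    ForwardDiff.GradientConfig(x -> _f(x, args...), θ), Val{false}())

θ, batch = rand(3), rand(3)
res = similar(θ)
grad(res, θ, batch)                          # the extra argument flows into loss
@assert res ≈ 2 .* (θ .- batch)
```

One cost worth noting: rebuilding the config on every call gives up the preallocation that `GradientConfig` exists to provide, trading some performance for the added flexibility.
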
@@ -101,34 +101,34 @@ end
 function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
     num_cons != 0 && error("AutoZygote does not currently support constraints")
 
-    _f = θ -> f(θ,p)[1]
+    _f = (θ, args...) -> f(θ,p,args...)[1]
     if f.grad === nothing
-        grad = (res,θ) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Zygote.gradient(_f, θ)[1]) : res .= Zygote.gradient(_f, θ)[1]
+        grad = (res, θ, args...) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Zygote.gradient(x -> _f(x, args...), θ)[1]) : res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
    else
        grad = f.grad
    end

    if f.hess === nothing
-        hess = function (res,θ)
+        hess = function (res, θ, args...)
            if res isa DiffResults.DiffResult
                DiffResults.hessian!(res, ForwardDiff.jacobian(θ) do θ
-                    Zygote.gradient(_f, θ)[1]
+                    Zygote.gradient(x -> _f(x, args...), θ)[1]
                end)
            else
                res .= ForwardDiff.jacobian(θ) do θ
-                    Zygote.gradient(_f, θ)[1]
-                end
+                    Zygote.gradient(x -> _f(x, args...), θ)[1]
+                end
            end
        end
    else
        hess = f.hess
    end

    if f.hv === nothing
-        hv = function (H,θ,v)
-            _θ = ForwardDiff.Dual.(θ,v)
+        hv = function (H, θ, v, args...)
+            _θ = ForwardDiff.Dual.(θ, v)
            res = DiffResults.GradientResult(_θ)
-            grad(res,_θ)
+            grad(res, _θ, args...)
            H .= getindex.(ForwardDiff.partials.(DiffResults.gradient(res)),1)
        end
    else
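
The Zygote backend gets the same `args...` closure treatment; its Hessian continues to be computed forward-over-reverse, i.e. as the ForwardDiff Jacobian of the Zygote gradient. A self-contained sketch of that composition on a toy quadratic (objective and sizes are illustrative only, not from this commit):

```julia
using Zygote, ForwardDiff, LinearAlgebra

_f = θ -> sum(abs2, θ)      # toy objective whose Hessian is exactly 2I

# forward-over-reverse: ForwardDiff differentiates the reverse-mode gradient
hessian(θ) = ForwardDiff.jacobian(θ -> Zygote.gradient(_f, θ)[1], θ)

@assert hessian(rand(3)) ≈ 2 * Matrix(I, 3, 3)
```
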
@@ -141,23 +141,23 @@ end
 function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParameters(), num_cons = 0)
     num_cons != 0 && error("AutoReverseDiff does not currently support constraints")
 
-    _f = θ -> f.f(θ,p)[1]
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        grad = (res,θ) -> ReverseDiff.gradient!(res, _f, θ, ReverseDiff.GradientConfig(θ))
+        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, ReverseDiff.GradientConfig(θ))
    else
        grad = f.grad
    end

    if f.hess === nothing
-        hess = function (res,θ)
+        hess = function (res, θ, args...)
            if res isa DiffResults.DiffResult
                DiffResults.hessian!(res, ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(_f, θ)[1]
+                    ReverseDiff.gradient(x -> _f(x, args...), θ)[1]
                end)
            else
                res .= ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(_f, θ)
+                    ReverseDiff.gradient(x -> _f(x, args...), θ)
                end
            end
        end
@@ -167,10 +167,10 @@ function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParamete
 
 
     if f.hv === nothing
-        hv = function (H,θ,v)
+        hv = function (H,θ,v, args...)
            _θ = ForwardDiff.Dual.(θ,v)
            res = DiffResults.GradientResult(_θ)
-            grad(res,_θ)
+            grad(res, _θ, args...)
            H .= getindex.(ForwardDiff.partials.(DiffResults.gradient(res)),1)
        end
    else
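
The `hv` fallback here (as in the Zygote method above) never materializes the Hessian: seeding `θ` with `ForwardDiff.Dual.(θ, v)` makes the gradient carry the directional derivative along `v`, so reading off the partials yields `H*v` in a single gradient pass. A standalone sketch of the trick on the same toy quadratic as before:

```julia
using ForwardDiff, DiffResults

_f = θ -> sum(abs2, θ)                  # Hessian is 2I, so H*v == 2v
θ, v = rand(3), rand(3)

_θ  = ForwardDiff.Dual.(θ, v)           # seed each component with direction v
res = DiffResults.GradientResult(_θ)
ForwardDiff.gradient!(res, _f, _θ)
Hv  = getindex.(ForwardDiff.partials.(DiffResults.gradient(res)), 1)
@assert Hv ≈ 2v
```
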
@@ -183,22 +183,22 @@ end
 
 function instantiate_function(f, x, ::AutoTracker, p, num_cons = 0)
     num_cons != 0 && error("AutoTracker does not currently support constraints")
-    _f = θ -> f.f(θ,p)[1]
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        grad = (res,θ) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Tracker.data(Tracker.gradient(_f, θ)[1])) : res .= Tracker.data(Tracker.gradient(_f, θ)[1])
+        grad = (res, θ, args...) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Tracker.data(Tracker.gradient(x -> _f(x, args...), θ)[1])) : res .= Tracker.data(Tracker.gradient(x -> _f(x, args...), θ)[1])
    else
        grad = f.grad
    end

    if f.hess === nothing
-        hess = (res, θ) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg")
+        hess = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg")
    else
        hess = f.hess
    end

    if f.hv === nothing
-        hv = (res, θ) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` and `hv` kwargs")
+        hv = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` and `hv` kwargs")
    else
        hv = f.hv
    end
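
Tracker cannot differentiate its own pullbacks, so the Hessian and Hessian-vector closures remain informative errors; a user who needs second-order information must supply it. A hedged sketch of an analytical Hessian written against the new `args...`-accepting signature (Rosenbrock is illustrative; the exact constructor call for passing it is assumed, not shown):

```julia
rosenbrock(θ, p) = (1 - θ[1])^2 + 100 * (θ[2] - θ[1]^2)^2

# In-place analytical Hessian with the trailing args... the new closures expect;
# it would be handed to the library through the `hess` kwarg named in the
# error messages above.
function rosen_hess!(H, θ, args...)
    H[1, 1] = 2 - 400 * θ[2] + 1200 * θ[1]^2
    H[1, 2] = H[2, 1] = -400 * θ[1]
    H[2, 2] = 200.0
    return H
end
```
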
@@ -209,24 +209,24 @@ end
 
 function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
     num_cons != 0 && error("AutoFiniteDiff does not currently support constraints")
-    _f = θ -> f.f(θ,p)[1]
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        grad = (res,θ) -> FiniteDiff.finite_difference_gradient!(res, _f, θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
+        grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x -> _f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
    else
        grad = f.grad
    end

    if f.hess === nothing
-        hess = (res,θ) -> FiniteDiff.finite_difference_hessian!(res, _f, θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
+        hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x -> _f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
    else
        hess = f.hess
    end

    if f.hv === nothing
-        hv = function (H,θ,v)
+        hv = function (H, θ, v, args...)
            res = ArrayInterface.zeromatrix(θ)
-            hess(res, θ)
+            hess(res, θ, args...)
            H .= res*v
        end
    else
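
Unlike the Dual-seeding backends, the finite-difference `hv` fallback forms the full dense Hessian and then multiplies by `v`, which is O(n²) in memory and only reasonable for modest dimensions. A sketch of that fallback in isolation (toy objective and buffer again illustrative):

```julia
using FiniteDiff

_f = θ -> sum(abs2, θ)           # Hessian is 2I
θ, v = rand(3), rand(3)

H = zeros(3, 3)                  # dense buffer: the O(n^2) cost of this fallback
FiniteDiff.finite_difference_hessian!(H, _f, θ)
@assert isapprox(H * v, 2v; atol = 1e-4)
```
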