
Commit 17705af

Merge pull request #65 from SciML/dataarg
Add data argument to solve
2 parents: deeaea9 + d93a7be

File tree

5 files changed: +163 -73 lines
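
Per the PR title and the diffs below, `solve` gains an extra positional `data` argument: an iterator whose elements are splatted into the objective (and its instantiated derivative closures) as trailing arguments, which is what enables minibatched training loops. A minimal usage sketch, assuming the two-argument `OptimizationFunction(f, adtype)` constructor and the `ADAM` optimizer re-exported from Flux by the first hunk; the loss function, data, and step size are illustrative, not taken from this PR:

using GalacticOptim
using Base.Iterators: ncycle

# Toy least-squares objective; the minibatch (x, y) arrives as trailing arguments.
loss(θ, p, x, y) = sum(abs2, y .- θ[1] .* x .- θ[2])

x = rand(32); y = 2 .* x .+ 1 .+ 0.1 .* randn(32)
data = ncycle([(x, y)], 200)                     # replay the same toy batch 200 times

optf = OptimizationFunction(loss, GalacticOptim.AutoZygote())
prob = OptimizationProblem(optf, zeros(2))
sol  = solve(prob, ADAM(0.05), data)             # data iterator as the extra positional argument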

src/GalacticOptim.jl

Lines changed: 2 additions & 2 deletions
@@ -4,9 +4,9 @@ using Reexport
 @reexport using DiffEqBase
 using Requires
 using DiffResults, ForwardDiff, Zygote, ReverseDiff, Tracker, FiniteDiff
-using Optim, Flux
+@reexport using Optim, Flux
 using Logging, ProgressLogging, Printf, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
-using ArrayInterface
+using ArrayInterface, Base.Iterators
 
 using ForwardDiff: DEFAULT_CHUNK_THRESHOLD
 import DiffEqBase: OptimizationProblem, OptimizationFunction, AbstractADType
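
Two side effects of this hunk worth noting: `@reexport using Optim, Flux` should make the optimizer constructors available from `GalacticOptim` itself, and `Base.Iterators` supplies helpers such as `ncycle`/`repeated` for building the data iterators used by the new `solve` argument. A small sketch of that assumption (based on this diff, not on the package docs):

using GalacticOptim                              # Optim and Flux optimizers re-exported (per this diff)
opt1 = BFGS()                                    # from Optim
opt2 = ADAM(0.01)                                # from Flux
x, y = rand(8), rand(8)
batches = Base.Iterators.repeated((x, y), 100)   # 100 identical toy batches for solve's data argument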

src/function.jl

Lines changed: 32 additions & 32 deletions
@@ -41,26 +41,26 @@ function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p, num_cons =
 
     chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize
 
-    _f = θ -> first(f.f(θ,p))
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}())
-        grad = (res,θ) -> ForwardDiff.gradient!(res, _f, θ, gradcfg)
+        gradcfg = (args...) -> ForwardDiff.GradientConfig(x -> _f(x, args...), x, ForwardDiff.Chunk{chunksize}())
+        grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, gradcfg(args...), Val{false}())
     else
         grad = f.grad
     end
 
     if f.hess === nothing
-        hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}())
-        hess = (res,θ) -> ForwardDiff.hessian!(res, _f, θ, hesscfg)
+        hesscfg = (args...) -> ForwardDiff.HessianConfig(x -> _f(x, args...), x, ForwardDiff.Chunk{chunksize}())
+        hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ, hesscfg(args...), Val{false}())
     else
         hess = f.hess
     end
 
     if f.hv === nothing
-        hv = function (H,θ,v)
+        hv = function (H,θ,v, args...)
             res = ArrayInterface.zeromatrix(θ)
-            hess(res, θ)
+            hess(res, θ, args...)
             H .= res*v
         end
     else
@@ -101,34 +101,34 @@ end
 function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
     num_cons != 0 && error("AutoZygote does not currently support constraints")
 
-    _f = θ -> f(θ,p)[1]
+    _f = (θ, args...) -> f(θ,p,args...)[1]
     if f.grad === nothing
-        grad = (res,θ) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Zygote.gradient(_f, θ)[1]) : res .= Zygote.gradient(_f, θ)[1]
+        grad = (res, θ, args...) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Zygote.gradient(x -> _f(x, args...), θ)[1]) : res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
     else
         grad = f.grad
     end
 
     if f.hess === nothing
-        hess = function (res,θ)
+        hess = function (res, θ, args...)
             if res isa DiffResults.DiffResult
                 DiffResults.hessian!(res, ForwardDiff.jacobian(θ) do θ
-                    Zygote.gradient(_f,θ)[1]
+                    Zygote.gradient(x -> _f(x, args...), θ)[1]
                 end)
             else
                 res .= ForwardDiff.jacobian(θ) do θ
-                    Zygote.gradient(_f,θ)[1]
-                end
+                    Zygote.gradient(x ->_f(x, args...), θ)[1]
+                end
             end
         end
     else
         hess = f.hess
     end
 
     if f.hv === nothing
-        hv = function (H,θ,v)
-            _θ = ForwardDiff.Dual.(θ,v)
+        hv = function (H, θ, v, args...)
+            _θ = ForwardDiff.Dual.(θ, v)
             res = DiffResults.GradientResult(_θ)
-            grad(res,_θ)
+            grad(res, _θ, args...)
             H .= getindex.(ForwardDiff.partials.(DiffResults.gradient(res)),1)
         end
     else
@@ -141,23 +141,23 @@ end
 function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParameters(), num_cons = 0)
     num_cons != 0 && error("AutoReverseDiff does not currently support constraints")
 
-    _f = θ -> f.f(θ,p)[1]
+    _f = (θ, args...) -> first(f.f(θ,p, args...))
 
     if f.grad === nothing
-        grad = (res,θ) -> ReverseDiff.gradient!(res, _f, θ, ReverseDiff.GradientConfig(θ))
+        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, ReverseDiff.GradientConfig(θ))
     else
         grad = f.grad
     end
 
     if f.hess === nothing
-        hess = function (res,θ)
+        hess = function (res, θ, args...)
             if res isa DiffResults.DiffResult
                 DiffResults.hessian!(res, ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(_f,θ)[1]
+                    ReverseDiff.gradient(x -> _f(x, args...), θ)[1]
                 end)
             else
                 res .= ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(_f,θ)
+                    ReverseDiff.gradient(x ->_f(x, args...), θ)
                 end
             end
         end
@@ -167,10 +167,10 @@ function instantiate_function(f, x, ::AutoReverseDiff, p=DiffEqBase.NullParamete
 
 
     if f.hv === nothing
-        hv = function (H,θ,v)
+        hv = function (H,θ,v, args...)
             _θ = ForwardDiff.Dual.(θ,v)
             res = DiffResults.GradientResult(_θ)
-            grad(res,_θ)
+            grad(res, _θ, args...)
             H .= getindex.(ForwardDiff.partials.(DiffResults.gradient(res)),1)
         end
     else
@@ -183,22 +183,22 @@ end
 
 function instantiate_function(f, x, ::AutoTracker, p, num_cons = 0)
     num_cons != 0 && error("AutoTracker does not currently support constraints")
-    _f = θ -> f.f(θ,p)[1]
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        grad = (res,θ) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Tracker.data(Tracker.gradient(_f, θ)[1])) : res .= Tracker.data(Tracker.gradient(_f, θ)[1])
+        grad = (res, θ, args...) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Tracker.data(Tracker.gradient(x -> _f(x, args...), θ)[1])) : res .= Tracker.data(Tracker.gradient(x -> _f(x, args...), θ)[1])
     else
         grad = f.grad
     end
 
     if f.hess === nothing
-        hess = (res, θ) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg")
+        hess = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg")
     else
         hess = f.hess
     end
 
     if f.hv === nothing
-        hv = (res, θ) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` and `hv` kwargs")
+        hv = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` and `hv` kwargs")
     else
         hv = f.hv
     end
@@ -209,24 +209,24 @@ end
 
 function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
     num_cons != 0 && error("AutoFiniteDiff does not currently support constraints")
-    _f = θ -> f.f(θ,p)[1]
+    _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        grad = (res,θ) -> FiniteDiff.finite_difference_gradient!(res, _f, θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
+        grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x -> _f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
     else
         grad = f.grad
     end
 
     if f.hess === nothing
-        hess = (res,θ) -> FiniteDiff.finite_difference_hessian!(res, _f, θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
+        hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x -> _f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
     else
         hess = f.hess
     end
 
     if f.hv === nothing
-        hv = function (H,θ,v)
+        hv = function (H, θ, v, args...)
             res = ArrayInterface.zeromatrix(θ)
-            hess(res, θ)
+            hess(res, θ, args...)
             H .= res*v
         end
     else
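
Every `instantiate_function` method in this file follows the same pattern: the instantiated `_f`, `grad`, `hess`, and `hv` closures gain trailing `args...` (the current data batch), which are captured in an inner closure so the AD backend still differentiates with respect to `θ` only. The `hv` closures additionally reuse the `ForwardDiff.Dual` trick, evaluating the gradient at `θ` perturbed along `v` and reading the partials off the result to get a Hessian-vector product without forming the full Hessian. A self-contained sketch of the closure pattern, with illustrative names rather than the package's internal API:

using ForwardDiff

# Objective taking parameters θ, problem parameters p, and a data batch (x, y).
f(θ, p, x, y) = sum(abs2, y .- θ[1] .* x .- θ[2])

# Capture the batch in a closure; only θ is differentiated.
_f   = (θ, args...) -> first(f(θ, nothing, args...))
grad = (res, θ, args...) -> ForwardDiff.gradient!(res, z -> _f(z, args...), θ)

θ, res = rand(2), zeros(2)
grad(res, θ, rand(16), rand(16))   # the batch flows through to f untouched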

0 commit comments