Skip to content

Commit cbd4d34

Browse files
Merge pull request #134 from SciML/fminboxminibatch
Pass args to gradient function in Fminbox/SAMIN dispatch
2 parents e1eaaf4 + c17a5eb commit cbd4d34

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/solve.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ function __solve(prob::OptimizationProblem, opt::Optim.AbstractOptimizer,
146146
end
147147
cur, state = iterate(data, state)
148148
cb_call
149-
end
149+
end
150150

151151
if !(isnothing(maxiters)) && maxiters <= 0.0
152152
error("The number of maxiters has to be a non-negative and non-zero number.")
@@ -203,13 +203,13 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
203203

204204
cur, state = iterate(data)
205205

206-
function _cb(trace)
207-
cb_call = !(opt isa Optim.SAMIN) && opt.method == NelderMead() ? cb(decompose_trace(trace).metadata["centroid"],x...) : cb(decompose_trace(trace).metadata["x"],x...)
208-
if !(typeof(cb_call) <: Bool)
206+
function _cb(trace)
207+
cb_call = !(opt isa Optim.SAMIN) && opt.method == NelderMead() ? cb(decompose_trace(trace).metadata["centroid"],x...) : cb(decompose_trace(trace).metadata["x"],x...)
208+
if !(typeof(cb_call) <: Bool)
209209
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
210210
end
211211
cur, state = iterate(data, state)
212-
cb_call
212+
cb_call
213213
end
214214

215215
if !(isnothing(maxiters)) && maxiters <= 0.0
@@ -233,7 +233,7 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
233233

234234
return _loss(θ)
235235
end
236-
optim_f = OnceDifferentiable(_loss, f.grad, fg!, prob.u0)
236+
optim_f = OnceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, prob.u0)
237237

238238
original = Optim.optimize(optim_f, prob.lb, prob.ub, prob.u0, opt,
239239
!(isnothing(maxiters)) ? Optim.Options(;

0 commit comments

Comments
 (0)