diff --git a/src/lbfgsb.jl b/src/lbfgsb.jl index fcab0ae59..3cc89c609 100644 --- a/src/lbfgsb.jl +++ b/src/lbfgsb.jl @@ -116,13 +116,15 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.cons(cons_tmp, cache.u0) ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp))) + iter_count = Ref(0) _loss = function (θ) x = cache.f(θ, cache.p) + iter_count[] += 1 cons_tmp .= zero(eltype(θ)) cache.f.cons(cons_tmp, θ) cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] - opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p) + opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p, iter = iter_count[]) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end @@ -206,10 +208,11 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache, cache.opt, res[2], cache.f(res[2], cache.p)[1], stats = stats, retcode = opt_ret) else + iter_count = Ref(0) _loss = function (θ) x = cache.f(θ, cache.p) - - opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p) + iter_count[] += 1 + opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p, iter = iter_count[]) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end