Skip to content

Commit 7c28259

Browse files
Merge pull request #1075 from ChrisRackauckas-Claude/fix-lbfgs-callback-dual-numbers
Fix LBFGS/BFGS callback receiving Dual numbers instead of scalar loss values
2 parents dbeed2d + e4721a1 commit 7c28259

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ SciMLBase.requiresbounds(opt::Optim.SAMIN) = true
1818
end
1919
@static if isdefined(OptimizationBase, :supports_opt_cache_interface)
2020
OptimizationBase.supports_opt_cache_interface(opt::Optim.AbstractOptimizer) = true
21-
OptimizationBase.supports_opt_cache_interface(opt::Union{Optim.Fminbox, Optim.SAMIN}) = true
21+
OptimizationBase.supports_opt_cache_interface(opt::Union{
22+
Optim.Fminbox, Optim.SAMIN}) = true
2223
OptimizationBase.supports_opt_cache_interface(opt::Optim.ConstrainedOptimizer) = true
2324
end
2425
function SciMLBase.requiresgradient(opt::Optim.AbstractOptimizer)
@@ -149,14 +150,17 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
149150
trace_state = decompose_trace(trace)
150151
metadata = trace_state.metadata
151152
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
153+
# Extract scalar value from potentially Dual-valued trace (issue #1073)
154+
# Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
155+
loss_val = SciMLBase.value(trace_state.value)
152156
opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
153157
u = θ,
154158
p = cache.p,
155-
objective = trace_state.value,
159+
objective = loss_val,
156160
grad = get(metadata, "g(x)", nothing),
157161
hess = get(metadata, "h(x)", nothing),
158162
original = trace)
159-
cb_call = cache.callback(opt_state, trace_state.value)
163+
cb_call = cache.callback(opt_state, loss_val)
160164
if !(cb_call isa Bool)
161165
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
162166
end
@@ -270,14 +274,17 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
270274
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
271275
metadata["centroid"] :
272276
metadata["x"]
277+
# Extract scalar value from potentially Dual-valued trace (issue #1073)
278+
# Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
279+
loss_val = SciMLBase.value(trace_state.value)
273280
opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
274281
u = θ,
275282
p = cache.p,
276-
objective = trace_state.value,
283+
objective = loss_val,
277284
grad = get(metadata, "g(x)", nothing),
278285
hess = get(metadata, "h(x)", nothing),
279286
original = trace)
280-
cb_call = cache.callback(opt_state, trace_state.value)
287+
cb_call = cache.callback(opt_state, loss_val)
281288
if !(cb_call isa Bool)
282289
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
283290
end
@@ -357,14 +364,17 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
357364

358365
function _cb(trace)
359366
metadata = decompose_trace(trace).metadata
367+
# Extract scalar value from potentially Dual-valued trace (issue #1073)
368+
# Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
369+
loss_val = SciMLBase.value(trace.value)
360370
opt_state = OptimizationBase.OptimizationState(iter = trace.iteration,
361371
u = metadata["x"],
362372
p = cache.p,
363373
grad = get(metadata, "g(x)", nothing),
364374
hess = get(metadata, "h(x)", nothing),
365-
objective = trace.value,
375+
objective = loss_val,
366376
original = trace)
367-
cb_call = cache.callback(opt_state, trace.value)
377+
cb_call = cache.callback(opt_state, loss_val)
368378
if !(cb_call isa Bool)
369379
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
370380
end

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ end
150150
G[1] = -2.0 * (1.0 - x[1]) - 400.0 * (x[2] - x[1]^2) * x[1]
151151
G[2] = 200.0 * (x[2] - x[1]^2)
152152
end
153-
optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), OptimizationBase.AutoZygote(),
153+
optprob = OptimizationFunction(
154+
(x, p) -> -rosenbrock(x, p), OptimizationBase.AutoZygote(),
154155
grad = g!)
155156
prob = OptimizationProblem(optprob, x0, _p; sense = OptimizationBase.MaxSense)
156157
sol = solve(prob, BFGS())
@@ -171,7 +172,8 @@ end
171172
@test 10 * sol.objective < l1
172173

173174
prob = OptimizationProblem(
174-
optprob, x0, _p; sense = OptimizationBase.MaxSense, lb = [-1.0, -1.0], ub = [0.8, 0.8])
175+
optprob, x0, _p; sense = OptimizationBase.MaxSense, lb = [-1.0, -1.0], ub = [
176+
0.8, 0.8])
175177
sol = solve(prob, BFGS())
176178
@test 10 * sol.objective < l1
177179

@@ -199,6 +201,34 @@ end
199201
@test_throws ArgumentError (sol = solve(prob, Optim.BFGS())) isa Any # test exception is thrown
200202
@test 10 * sol.objective < l1
201203

204+
# Test for issue #1073: callbacks should receive scalar non-negative loss values
205+
# when using (L)BFGS with bounds and automatic differentiation
206+
@testset "Issue #1073: LBFGS/BFGS callback receives correct scalar loss with bounds" begin
207+
# Create a non-negative loss function (sum of squares)
208+
loss_vals = Float64[]
209+
function test_callback(state, loss_val)
210+
# Verify loss_val is a scalar Float64, not a Dual number
211+
@test loss_val isa Float64
212+
# For a sum-of-squares loss, values should be non-negative
213+
push!(loss_vals, loss_val)
214+
return false
215+
end
216+
217+
# Test with LBFGS + bounds (triggers Fminbox wrapping)
218+
optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
219+
prob = OptimizationProblem(optprob, x0, _p; lb = [-1.0, -1.0], ub = [0.8, 0.8])
220+
empty!(loss_vals)
221+
sol = solve(prob, Optim.LBFGS(); callback = test_callback, maxiters = 10)
222+
@test all(>=(0), loss_vals) # All loss values should be non-negative
223+
@test length(loss_vals) > 0 # Callback should have been called
224+
225+
# Test with BFGS + bounds
226+
empty!(loss_vals)
227+
sol = solve(prob, Optim.BFGS(); callback = test_callback, maxiters = 10)
228+
@test all(>=(0), loss_vals) # All loss values should be non-negative
229+
@test length(loss_vals) > 0 # Callback should have been called
230+
end
231+
202232
@testset "cache" begin
203233
objective(x, p) = (p[1] - x[1])^2
204234
x0 = zeros(1)

0 commit comments

Comments
 (0)