Skip to content

Commit 3b61d27

Browse files
Fix LBFGS/BFGS callback receiving Dual numbers instead of scalar loss values
Fixes #1073

When using LBFGS or BFGS with bounds, Optim.jl wraps the optimizer in Fminbox, which may use ForwardDiff internally for gradient computation. As a result, the callback received ForwardDiff.Dual numbers instead of scalar loss values, causing incorrect (sometimes negative) values to be reported.

Changes:
- Added ForwardDiff as a dependency of OptimizationOptimJL
- Added the `_scalar_value()` utility function to extract scalar values from Dual numbers
- Updated all three `_cb` callback functions to extract scalar values before passing them to user callbacks
- Added a comprehensive test case verifying that callbacks receive correct scalar, non-negative values

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 36fbd85 commit 3b61d27

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

lib/OptimizationOptimJL/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "36348300-93cb-4f02-beb5-3c3902f8871e"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
44
version = "0.4.6"
55
[deps]
6+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
67
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
78
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
89
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -11,7 +12,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1112
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1213

1314
[extras]
14-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1515
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1616
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1717
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -28,4 +28,4 @@ Reexport = "1.2"
2828
SciMLBase = "2.58"
2929

3030
[targets]
31-
test = ["ForwardDiff", "ModelingToolkit", "Random", "ReverseDiff", "Test", "Zygote"]
31+
test = ["ModelingToolkit", "Random", "ReverseDiff", "Test", "Zygote"]

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,19 @@ module OptimizationOptimJL
33
using Reexport
44
@reexport using Optim, OptimizationBase
55
using SciMLBase, SparseArrays
6+
7+
# Import ForwardDiff to handle Dual numbers in callbacks
8+
import ForwardDiff
9+
610
decompose_trace(trace::Optim.OptimizationTrace) = last(trace)
711
decompose_trace(trace::Optim.OptimizationState) = trace
812

13+
# Extract scalar value from potentially Dual-valued trace values
14+
# This is needed because Optim.jl may use ForwardDiff internally for gradient computation,
15+
# resulting in Dual numbers in the trace, but callbacks should receive scalar values
16+
_scalar_value(x) = x # Default case for regular numbers
17+
_scalar_value(x::ForwardDiff.Dual) = ForwardDiff.value(x) # Extract value from Dual
18+
919
SciMLBase.allowsconstraints(::IPNewton) = true
1020
SciMLBase.allowsbounds(opt::Optim.AbstractOptimizer) = true
1121
SciMLBase.allowsbounds(opt::Optim.SimulatedAnnealing) = false
@@ -149,14 +159,16 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
149159
trace_state = decompose_trace(trace)
150160
metadata = trace_state.metadata
151161
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
162+
# Extract scalar value from potentially Dual-valued trace (issue #1073)
163+
loss_val = _scalar_value(trace_state.value)
152164
opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
153165
u = θ,
154166
p = cache.p,
155-
objective = trace_state.value,
167+
objective = loss_val,
156168
grad = get(metadata, "g(x)", nothing),
157169
hess = get(metadata, "h(x)", nothing),
158170
original = trace)
159-
cb_call = cache.callback(opt_state, trace_state.value)
171+
cb_call = cache.callback(opt_state, loss_val)
160172
if !(cb_call isa Bool)
161173
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
162174
end
@@ -270,14 +282,16 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
270282
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
271283
metadata["centroid"] :
272284
metadata["x"]
285+
# Extract scalar value from potentially Dual-valued trace (issue #1073)
286+
loss_val = _scalar_value(trace_state.value)
273287
opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
274288
u = θ,
275289
p = cache.p,
276-
objective = trace_state.value,
290+
objective = loss_val,
277291
grad = get(metadata, "g(x)", nothing),
278292
hess = get(metadata, "h(x)", nothing),
279293
original = trace)
280-
cb_call = cache.callback(opt_state, trace_state.value)
294+
cb_call = cache.callback(opt_state, loss_val)
281295
if !(cb_call isa Bool)
282296
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
283297
end
@@ -357,14 +371,16 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
357371

358372
function _cb(trace)
359373
metadata = decompose_trace(trace).metadata
374+
# Extract scalar value from potentially Dual-valued trace (issue #1073)
375+
loss_val = _scalar_value(trace.value)
360376
opt_state = OptimizationBase.OptimizationState(iter = trace.iteration,
361377
u = metadata["x"],
362378
p = cache.p,
363379
grad = get(metadata, "g(x)", nothing),
364380
hess = get(metadata, "h(x)", nothing),
365-
objective = trace.value,
381+
objective = loss_val,
366382
original = trace)
367-
cb_call = cache.callback(opt_state, trace.value)
383+
cb_call = cache.callback(opt_state, loss_val)
368384
if !(cb_call isa Bool)
369385
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
370386
end

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,35 @@ end
199199
@test_throws ArgumentError (sol = solve(prob, Optim.BFGS())) isa Any # test exception is thrown
200200
@test 10 * sol.objective < l1
201201

202+
# Test for issue #1073: callbacks should receive scalar non-negative loss values
203+
# when using (L)BFGS with bounds and automatic differentiation
204+
@testset "Issue #1073: LBFGS/BFGS callback receives correct scalar loss with bounds" begin
205+
# Collect every loss value passed to the callback (the objective is expected to be a non-negative sum of squares)
206+
loss_vals = Float64[]
207+
function test_callback(state, loss_val)
208+
# Verify loss_val is a scalar, not a Dual number
209+
@test loss_val isa Real
210+
@test !(loss_val isa ForwardDiff.Dual)
211+
# For a sum-of-squares loss, values should be non-negative
212+
push!(loss_vals, loss_val)
213+
return false
214+
end
215+
216+
# Test with LBFGS + bounds (triggers Fminbox wrapping)
217+
optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
218+
prob = OptimizationProblem(optprob, x0, _p; lb = [-1.0, -1.0], ub = [0.8, 0.8])
219+
empty!(loss_vals)
220+
sol = solve(prob, Optim.LBFGS(); callback = test_callback, maxiters = 10)
221+
@test all(>=(0), loss_vals) # All loss values should be non-negative
222+
@test length(loss_vals) > 0 # Callback should have been called
223+
224+
# Test with BFGS + bounds
225+
empty!(loss_vals)
226+
sol = solve(prob, Optim.BFGS(); callback = test_callback, maxiters = 10)
227+
@test all(>=(0), loss_vals) # All loss values should be non-negative
228+
@test length(loss_vals) > 0 # Callback should have been called
229+
end
230+
202231
@testset "cache" begin
203232
objective(x, p) = (p[1] - x[1])^2
204233
x0 = zeros(1)

0 commit comments

Comments
 (0)