
Commit e4721a1

Fix LBFGS/BFGS callback receiving Dual numbers instead of scalar loss values
Fixes #1073.

When using LBFGS or BFGS with bounds, Optim.jl wraps the optimizer in Fminbox, which may use ForwardDiff internally for gradient computation. As a result, the callback received ForwardDiff.Dual numbers instead of scalar loss values, causing incorrect (sometimes negative) values to be reported.

Changes:

- Use SciMLBase.value() to extract scalar values from potentially Dual-valued traces.
- Update all three _cb callback functions to use SciMLBase.value().
- Add a test case verifying that callbacks receive correct, non-negative scalar values.
- No new dependencies required (SciMLBase already provides the functionality).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 3b61d27 commit e4721a1
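
For orientation, here is a minimal sketch of the scenario the commit message describes, using an illustrative sum-of-squares loss rather than code taken from issue #1073. Bounds cause the LBFGS solve to be wrapped in Fminbox; after this fix the callback should observe plain Float64 losses.

```julia
# Hypothetical reproduction sketch for issue #1073 (names and values are
# illustrative). With bounds, Optim.jl wraps LBFGS in Fminbox, which may
# differentiate through the objective with ForwardDiff.
using OptimizationBase, OptimizationOptimJL

sumsq(x, p) = sum(abs2, x)  # non-negative loss
f = OptimizationFunction(sumsq, OptimizationBase.AutoForwardDiff())
prob = OptimizationProblem(f, [0.5, 0.5]; lb = [-1.0, -1.0], ub = [1.0, 1.0])

cb = function (state, loss)
    # After the fix, `loss` is a scalar Float64, never a ForwardDiff.Dual,
    # so a sum-of-squares loss can no longer appear negative here.
    println("loss = ", loss)
    return false  # returning true would halt the optimization
end

sol = solve(prob, LBFGS(); callback = cb)
```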

File tree: 3 files changed, +16 −21 lines

lib/OptimizationOptimJL/Project.toml

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,6 @@ uuid = "36348300-93cb-4f02-beb5-3c3902f8871e"
 authors = ["Vaibhav Dixit <[email protected]> and contributors"]
 version = "0.4.6"
 [deps]
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -12,6 +11,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

 [extras]
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -28,4 +28,4 @@ Reexport = "1.2"
 SciMLBase = "2.58"

 [targets]
-test = ["ModelingToolkit", "Random", "ReverseDiff", "Test", "Zygote"]
+test = ["ForwardDiff", "ModelingToolkit", "Random", "ReverseDiff", "Test", "Zygote"]

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 8 additions & 14 deletions
@@ -3,19 +3,9 @@ module OptimizationOptimJL
 using Reexport
 @reexport using Optim, OptimizationBase
 using SciMLBase, SparseArrays
-
-# Import ForwardDiff to handle Dual numbers in callbacks
-import ForwardDiff
-
 decompose_trace(trace::Optim.OptimizationTrace) = last(trace)
 decompose_trace(trace::Optim.OptimizationState) = trace

-# Extract scalar value from potentially Dual-valued trace values
-# This is needed because Optim.jl may use ForwardDiff internally for gradient computation,
-# resulting in Dual numbers in the trace, but callbacks should receive scalar values
-_scalar_value(x) = x # Default case for regular numbers
-_scalar_value(x::ForwardDiff.Dual) = ForwardDiff.value(x) # Extract value from Dual
-
 SciMLBase.allowsconstraints(::IPNewton) = true
 SciMLBase.allowsbounds(opt::Optim.AbstractOptimizer) = true
 SciMLBase.allowsbounds(opt::Optim.SimulatedAnnealing) = false
@@ -28,7 +18,8 @@ SciMLBase.requiresbounds(opt::Optim.SAMIN) = true
 end
 @static if isdefined(OptimizationBase, :supports_opt_cache_interface)
     OptimizationBase.supports_opt_cache_interface(opt::Optim.AbstractOptimizer) = true
-    OptimizationBase.supports_opt_cache_interface(opt::Union{Optim.Fminbox, Optim.SAMIN}) = true
+    OptimizationBase.supports_opt_cache_interface(opt::Union{
+        Optim.Fminbox, Optim.SAMIN}) = true
     OptimizationBase.supports_opt_cache_interface(opt::Optim.ConstrainedOptimizer) = true
 end
 function SciMLBase.requiresgradient(opt::Optim.AbstractOptimizer)
@@ -160,7 +151,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
         metadata = trace_state.metadata
         θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
         # Extract scalar value from potentially Dual-valued trace (issue #1073)
-        loss_val = _scalar_value(trace_state.value)
+        # Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
+        loss_val = SciMLBase.value(trace_state.value)
         opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
             u = θ,
             p = cache.p,
@@ -283,7 +275,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
                   metadata["centroid"] :
                   metadata["x"]
         # Extract scalar value from potentially Dual-valued trace (issue #1073)
-        loss_val = _scalar_value(trace_state.value)
+        # Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
+        loss_val = SciMLBase.value(trace_state.value)
         opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
             u = θ,
             p = cache.p,
@@ -372,7 +365,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
     function _cb(trace)
         metadata = decompose_trace(trace).metadata
         # Extract scalar value from potentially Dual-valued trace (issue #1073)
-        loss_val = _scalar_value(trace.value)
+        # Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
+        loss_val = SciMLBase.value(trace.value)
         opt_state = OptimizationBase.OptimizationState(iter = trace.iteration,
             u = metadata["x"],
             p = cache.p,
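
For reference, `SciMLBase.value` performs essentially the same dispatch as the deleted `_scalar_value` helper: identity on ordinary numbers, primal extraction on dual numbers. A standalone sketch of that behavior (mirroring the removed helper above, not quoting SciMLBase's actual source):

```julia
import ForwardDiff

# The two-method dispatch the deleted _scalar_value implemented:
scalar(x) = x                                       # ordinary numbers pass through
scalar(x::ForwardDiff.Dual) = ForwardDiff.value(x)  # Dual -> primal value

d = ForwardDiff.Dual(3.5, 1.0)  # primal 3.5 with one partial derivative
scalar(d)    # 3.5
scalar(2.0)  # 2.0
```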

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 6 additions & 5 deletions
@@ -150,7 +150,8 @@ end
         G[1] = -2.0 * (1.0 - x[1]) - 400.0 * (x[2] - x[1]^2) * x[1]
         G[2] = 200.0 * (x[2] - x[1]^2)
     end
-    optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), OptimizationBase.AutoZygote(),
+    optprob = OptimizationFunction(
+        (x, p) -> -rosenbrock(x, p), OptimizationBase.AutoZygote(),
         grad = g!)
     prob = OptimizationProblem(optprob, x0, _p; sense = OptimizationBase.MaxSense)
     sol = solve(prob, BFGS())
@@ -171,7 +172,8 @@ end
     @test 10 * sol.objective < l1

     prob = OptimizationProblem(
-        optprob, x0, _p; sense = OptimizationBase.MaxSense, lb = [-1.0, -1.0], ub = [0.8, 0.8])
+        optprob, x0, _p; sense = OptimizationBase.MaxSense, lb = [-1.0, -1.0], ub = [
+            0.8, 0.8])
     sol = solve(prob, BFGS())
     @test 10 * sol.objective < l1

@@ -205,9 +207,8 @@ end
     # Create a non-negative loss function (sum of squares)
     loss_vals = Float64[]
     function test_callback(state, loss_val)
-        # Verify loss_val is a scalar, not a Dual number
-        @test loss_val isa Real
-        @test !(loss_val isa ForwardDiff.Dual)
+        # Verify loss_val is a scalar Float64, not a Dual number
+        @test loss_val isa Float64
         # For a sum-of-squares loss, values should be non-negative
         push!(loss_vals, loss_val)
         return false
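
The same check can also be exercised through the explicit Fminbox wrapper the commit message names. A hedged sketch of how the new test plausibly fits together; only test_callback mirrors the hunk above, while the problem setup is an assumption for illustration:

```julia
using OptimizationBase, OptimizationOptimJL, Test

# Create a non-negative loss function (sum of squares)
sumsq(x, p) = sum(abs2, x .- p)
f = OptimizationFunction(sumsq, OptimizationBase.AutoForwardDiff())
prob = OptimizationProblem(f, [0.5, 0.5], [0.1, 0.2];
    lb = [-1.0, -1.0], ub = [1.0, 1.0])

loss_vals = Float64[]
function test_callback(state, loss_val)
    # Verify loss_val is a scalar Float64, not a Dual number
    @test loss_val isa Float64
    push!(loss_vals, loss_val)
    return false
end

sol = solve(prob, Optim.Fminbox(BFGS()); callback = test_callback)
# For a sum-of-squares loss, every reported value should be non-negative
@test all(>=(0.0), loss_vals)
```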
