24 changes: 17 additions & 7 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
@@ -18,7 +18,8 @@ SciMLBase.requiresbounds(opt::Optim.SAMIN) = true
end
@static if isdefined(OptimizationBase, :supports_opt_cache_interface)
OptimizationBase.supports_opt_cache_interface(opt::Optim.AbstractOptimizer) = true
OptimizationBase.supports_opt_cache_interface(opt::Union{Optim.Fminbox, Optim.SAMIN}) = true
OptimizationBase.supports_opt_cache_interface(opt::Union{
Optim.Fminbox, Optim.SAMIN}) = true
OptimizationBase.supports_opt_cache_interface(opt::Optim.ConstrainedOptimizer) = true
end
function SciMLBase.requiresgradient(opt::Optim.AbstractOptimizer)
@@ -149,14 +150,17 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
trace_state = decompose_trace(trace)
metadata = trace_state.metadata
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
# Extract scalar value from potentially Dual-valued trace (issue #1073)
# Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
loss_val = SciMLBase.value(trace_state.value)
opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
u = θ,
p = cache.p,
objective = trace_state.value,
objective = loss_val,
grad = get(metadata, "g(x)", nothing),
hess = get(metadata, "h(x)", nothing),
original = trace)
cb_call = cache.callback(opt_state, trace_state.value)
cb_call = cache.callback(opt_state, loss_val)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
end
@@ -270,14 +274,17 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
metadata["centroid"] :
metadata["x"]
# Extract scalar value from potentially Dual-valued trace (issue #1073)
# Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
loss_val = SciMLBase.value(trace_state.value)
opt_state = OptimizationBase.OptimizationState(iter = trace_state.iteration,
u = θ,
p = cache.p,
objective = trace_state.value,
objective = loss_val,
grad = get(metadata, "g(x)", nothing),
hess = get(metadata, "h(x)", nothing),
original = trace)
cb_call = cache.callback(opt_state, trace_state.value)
cb_call = cache.callback(opt_state, loss_val)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
end
@@ -357,14 +364,17 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{

function _cb(trace)
metadata = decompose_trace(trace).metadata
# Extract scalar value from potentially Dual-valued trace (issue #1073)
# Using SciMLBase.value to handle ForwardDiff.Dual numbers from Fminbox
loss_val = SciMLBase.value(trace.value)
opt_state = OptimizationBase.OptimizationState(iter = trace.iteration,
u = metadata["x"],
p = cache.p,
grad = get(metadata, "g(x)", nothing),
hess = get(metadata, "h(x)", nothing),
objective = trace.value,
objective = loss_val,
original = trace)
cb_call = cache.callback(opt_state, trace.value)
cb_call = cache.callback(opt_state, loss_val)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
end
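For context on the change above: with box constraints, Optim.jl wraps the objective in `Fminbox`, and when differentiation goes through ForwardDiff the traced objective value can arrive as a `ForwardDiff.Dual` rather than a plain `Float64`. The diff routes it through `SciMLBase.value` before passing it to the user callback, at each of the three callback sites touched above. Below is a minimal standalone sketch of that Dual-to-scalar extraction, using `ForwardDiff.value` directly for illustration; the `Dual` constructed here is synthetic, not an actual Optim.jl trace entry.

```julia
using ForwardDiff

# A Dual number carries the primal value plus partial derivatives; a logging
# callback only wants the primal, i.e. the scalar loss.
d = ForwardDiff.Dual(1.5, 2.0, -0.5)   # synthetic Dual: primal 1.5, two partials

loss_val = ForwardDiff.value(d)        # 1.5, a plain Float64
@assert loss_val isa Float64
```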
34 changes: 32 additions & 2 deletions lib/OptimizationOptimJL/test/runtests.jl
@@ -150,7 +150,8 @@ end
G[1] = -2.0 * (1.0 - x[1]) - 400.0 * (x[2] - x[1]^2) * x[1]
G[2] = 200.0 * (x[2] - x[1]^2)
end
optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), OptimizationBase.AutoZygote(),
optprob = OptimizationFunction(
(x, p) -> -rosenbrock(x, p), OptimizationBase.AutoZygote(),
grad = g!)
prob = OptimizationProblem(optprob, x0, _p; sense = OptimizationBase.MaxSense)
sol = solve(prob, BFGS())
@@ -171,7 +172,8 @@ end
@test 10 * sol.objective < l1

prob = OptimizationProblem(
optprob, x0, _p; sense = OptimizationBase.MaxSense, lb = [-1.0, -1.0], ub = [0.8, 0.8])
optprob, x0, _p; sense = OptimizationBase.MaxSense, lb = [-1.0, -1.0], ub = [
0.8, 0.8])
sol = solve(prob, BFGS())
@test 10 * sol.objective < l1

@@ -199,6 +201,34 @@ end
@test_throws ArgumentError (sol = solve(prob, Optim.BFGS())) isa Any # test exception is thrown
@test 10 * sol.objective < l1

# Test for issue #1073: callbacks should receive scalar non-negative loss values
# when using (L)BFGS with bounds and automatic differentiation
@testset "Issue #1073: LBFGS/BFGS callback receives correct scalar loss with bounds" begin
# Collect callback loss values; the rosenbrock objective used below is a sum of squares, hence non-negative
loss_vals = Float64[]
function test_callback(state, loss_val)
# Verify loss_val is a scalar Float64, not a Dual number
@test loss_val isa Float64
# For a sum-of-squares loss, values should be non-negative
push!(loss_vals, loss_val)
return false
end

# Test with LBFGS + bounds (triggers Fminbox wrapping)
optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
prob = OptimizationProblem(optprob, x0, _p; lb = [-1.0, -1.0], ub = [0.8, 0.8])
empty!(loss_vals)
sol = solve(prob, Optim.LBFGS(); callback = test_callback, maxiters = 10)
@test all(>=(0), loss_vals) # All loss values should be non-negative
@test length(loss_vals) > 0 # Callback should have been called

# Test with BFGS + bounds
empty!(loss_vals)
sol = solve(prob, Optim.BFGS(); callback = test_callback, maxiters = 10)
@test all(>=(0), loss_vals) # All loss values should be non-negative
@test length(loss_vals) > 0 # Callback should have been called
end

@testset "cache" begin
objective(x, p) = (p[1] - x[1])^2
x0 = zeros(1)
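End to end, the behavior the new testset exercises looks like this from the user's side. A minimal sketch, assuming `OptimizationBase` and `OptimizationOptimJL` are loaded and `Optim` is in scope as in the test file above; the printing callback is illustrative and not part of the PR.

```julia
using OptimizationBase, OptimizationOptimJL

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
x0 = zeros(2)
p = [1.0, 100.0]

optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
prob = OptimizationProblem(optf, x0, p; lb = [-1.0, -1.0], ub = [0.8, 0.8])

# The bounds force the Fminbox path; with this change the callback's `loss`
# argument is a plain Float64 (never a ForwardDiff.Dual), so it can be
# compared and logged directly.
sol = solve(prob, Optim.LBFGS();
    callback = (state, loss) -> (println("loss = ", loss); false),
    maxiters = 10)
```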