Skip to content

Commit 51bd57a

Browse files
Fix solve interface to use CommonSolve dispatches correctly
The `__init` and `__solve` functions are internal hooks, but the actual dispatches should be to CommonSolve.jl's `init`, `solve`, and `solve!` functions (which are imported via SciMLBase). Changes: - Import `init`, `solve`, `solve!`, `__init`, and `__solve` from SciMLBase - Change function definitions from `SciMLBase.solve` to `solve` to properly extend the CommonSolve interface - Remove `SciMLBase.` prefix from function calls to use the imported functions directly - Keep type annotations with `SciMLBase.` prefix (these are correct) This matches the pattern used in other SciML packages like OrdinaryDiffEq.jl and aligns with the CommonSolve.jl interface design. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent bf9036b commit 51bd57a

File tree

2 files changed

+26
-25
lines changed

2 files changed

+26
-25
lines changed

lib/OptimizationBase/src/OptimizationBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import SciMLBase: OptimizationProblem,
1212
allowsconstraints, requiresconstraints,
1313
allowscallback, requiresgradient,
1414
requireshessian, requiresconsjac,
15-
requiresconshess, supports_opt_cache_interface
15+
requiresconshess, supports_opt_cache_interface,
16+
__init, __solve, init, solve, solve!
1617
export ObjSense, MaxSense, MinSense
1718
export allowsbounds, requiresbounds, allowsconstraints, requiresconstraints,
1819
allowscallback, requiresgradient, requireshessian,

lib/OptimizationBase/src/solve.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,51 +92,51 @@ from NLopt for an example. The common local optimizer arguments are:
9292
- `local_reltol`: relative tolerance in changes of the objective value
9393
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
9494
"""
95-
function SciMLBase.solve(prob::SciMLBase.OptimizationProblem, alg, args...;
95+
function solve(prob::SciMLBase.OptimizationProblem, alg, args...;
9696
kwargs...)::SciMLBase.AbstractOptimizationSolution
97-
if SciMLBase.supports_opt_cache_interface(alg)
98-
SciMLBase.solve!(SciMLBase.init(prob, alg, args...; kwargs...))
97+
if supports_opt_cache_interface(alg)
98+
solve!(init(prob, alg, args...; kwargs...))
9999
else
100100
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
101101
throw(SciMLBase.NonConcreteEltypeError(eltype(prob.u0)))
102102
end
103103
_check_opt_alg(prob, alg; kwargs...)
104-
SciMLBase.__solve(prob, alg, args...; kwargs...)
104+
__solve(prob, alg, args...; kwargs...)
105105
end
106106
end
107107

108-
function SciMLBase.solve(
108+
function solve(
109109
prob::SciMLBase.EnsembleProblem{T}, args...; kwargs...) where {T <:
110110
SciMLBase.OptimizationProblem}
111-
return SciMLBase.__solve(prob, args...; kwargs...)
111+
return __solve(prob, args...; kwargs...)
112112
end
113113

114114
function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
115-
!SciMLBase.allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
115+
!allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
116116
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support box constraints. Either remove the `lb` or `ub` bounds passed to `OptimizationProblem` or use a different algorithm."))
117-
SciMLBase.requiresbounds(alg) && isnothing(prob.lb) &&
117+
requiresbounds(alg) && isnothing(prob.lb) &&
118118
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires box constraints. Either pass `lb` and `ub` bounds to `OptimizationProblem` or use a different algorithm."))
119-
!SciMLBase.allowsconstraints(alg) && !isnothing(prob.f.cons) &&
119+
!allowsconstraints(alg) && !isnothing(prob.f.cons) &&
120120
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support constraints. Either remove the `cons` function passed to `OptimizationFunction` or use a different algorithm."))
121-
SciMLBase.requiresconstraints(alg) && isnothing(prob.f.cons) &&
121+
requiresconstraints(alg) && isnothing(prob.f.cons) &&
122122
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraints, pass them with the `cons` kwarg in `OptimizationFunction`."))
123123
# Check that if constraints are present and the algorithm supports constraints, both lcons and ucons are provided
124-
SciMLBase.allowsconstraints(alg) && !isnothing(prob.f.cons) &&
124+
allowsconstraints(alg) && !isnothing(prob.f.cons) &&
125125
(isnothing(prob.lcons) || isnothing(prob.ucons)) &&
126126
throw(ArgumentError("Constrained optimization problem requires both `lcons` and `ucons` to be provided to OptimizationProblem. " *
127127
"Example: OptimizationProblem(optf, u0, p; lcons=[-Inf], ucons=[0.0])"))
128-
!SciMLBase.allowscallback(alg) && haskey(kwargs, :callback) &&
128+
!allowscallback(alg) && haskey(kwargs, :callback) &&
129129
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support callbacks, remove the `callback` keyword argument from the `solve` call."))
130-
SciMLBase.requiresgradient(alg) &&
130+
requiresgradient(alg) &&
131131
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
132132
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires gradients, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoForwardDiff())` or pass it in with `grad` kwarg."))
133-
SciMLBase.requireshessian(alg) &&
133+
requireshessian(alg) &&
134134
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
135135
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires hessians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(); kwargs...)` or pass them in with `hess` kwarg."))
136-
SciMLBase.requiresconsjac(alg) &&
136+
requiresconsjac(alg) &&
137137
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
138138
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraint jacobians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(); kwargs...)` or pass them in with `cons` kwarg."))
139-
SciMLBase.requiresconshess(alg) &&
139+
requiresconshess(alg) &&
140140
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
141141
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraint hessians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(), AutoFiniteDiff(hess=true); kwargs...)` or pass them in with `cons` kwarg."))
142142
return
@@ -184,13 +184,13 @@ These arguments can be passed as `kwargs...` to `init`.
184184
185185
See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
186186
"""
187-
function SciMLBase.init(prob::SciMLBase.OptimizationProblem, alg, args...;
187+
function init(prob::SciMLBase.OptimizationProblem, alg, args...;
188188
kwargs...)::SciMLBase.AbstractOptimizationCache
189189
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
190190
throw(SciMLBase.NonConcreteEltypeError(eltype(prob.u0)))
191191
end
192192
_check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
193-
cache = SciMLBase.__init(prob, alg, args...; prob.kwargs..., kwargs...)
193+
cache = __init(prob, alg, args...; prob.kwargs..., kwargs...)
194194
return cache
195195
end
196196

@@ -203,19 +203,19 @@ Solves the given optimization cache.
203203
204204
See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
205205
"""
206-
function SciMLBase.solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
207-
SciMLBase.__solve(cache)
206+
function solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
207+
__solve(cache)
208208
end
209209

210210
# needs to be defined for each cache
211-
SciMLBase.supports_opt_cache_interface(alg) = false
212-
function SciMLBase.__solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution end
213-
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, alg, args...;
211+
supports_opt_cache_interface(alg) = false
212+
function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution end
213+
function __init(prob::SciMLBase.OptimizationProblem, alg, args...;
214214
kwargs...)::SciMLBase.AbstractOptimizationCache
215215
throw(OptimizerMissingError(alg))
216216
end
217217

218218
# if no cache interface is supported at least the following method has to be defined
219-
function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
219+
function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
220220
throw(OptimizerMissingError(alg))
221221
end

0 commit comments

Comments
 (0)