module OptimizationODE

using Reexport
@reexport using Optimization
using Optimization.SciMLBase

export ODEGradientDescent

# 1) The optimizer "type"

struct ODEGradientDescent end

# capability flags
SciMLBase.requiresbounds(::ODEGradientDescent) = false
SciMLBase.allowsbounds(::ODEGradientDescent) = false
SciMLBase.allowscallback(::ODEGradientDescent) = false
SciMLBase.supports_opt_cache_interface(::ODEGradientDescent) = true
SciMLBase.requiresgradient(::ODEGradientDescent) = true
SciMLBase.requireshessian(::ODEGradientDescent) = false
SciMLBase.requiresconsjac(::ODEGradientDescent) = false
SciMLBase.requiresconshess(::ODEGradientDescent) = false
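# Taken together, these traits advertise an unconstrained, first-order local
# method: a gradient is required, but no Hessian, bounds, constraints, or user
# callbacks are supported, and the solver opts into the OptimizationCache
# (init/solve!) interface.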

# 2) Map standard kwargs to our solver's args

function __map_optimizer_args!(
    cache::OptimizationCache, opt::ODEGradientDescent;
    callback = nothing,
    maxiters::Union{Number,Nothing} = nothing,
    maxtime::Union{Number,Nothing} = nothing,
    abstol::Union{Number,Nothing} = nothing,
    reltol::Union{Number,Nothing} = nothing,
    η::Float64 = 0.1,
    tmax::Float64 = 1.0,
    dt::Float64 = 0.01,
    kwargs...
)
    # override our defaults
    cache.solver_args = merge(cache.solver_args, (
        η = η,
        tmax = tmax,
        dt = dt,
    ))
    # now apply common options; solver_args is a NamedTuple, so it is immutable
    # and must be rebuilt with merge rather than mutated field-by-field
    if !isnothing(maxiters)
        cache.solver_args = merge(cache.solver_args, (maxiters = maxiters,))
    end
    if !isnothing(maxtime)
        cache.solver_args = merge(cache.solver_args, (maxtime = maxtime,))
    end
    return nothing
end
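# For example (illustrative values, not part of the module), a call such as
# `__map_optimizer_args!(cache, opt; maxiters = 100, dt = 0.05)` would leave
# `cache.solver_args` containing `(η = 0.1, tmax = 1.0, dt = 0.05, maxiters = 100)`
# merged on top of whatever entries the cache already carried.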

# 3) Initialize the cache (captures f, u0, bounds, and solver_args)

function SciMLBase.__init(
    prob::SciMLBase.OptimizationProblem,
    opt::ODEGradientDescent,
    data = Optimization.DEFAULT_DATA;
    η::Float64 = 0.1,
    tmax::Float64 = 1.0,
    dt::Float64 = 0.01,
    callback = (args...) -> false,
    progress = false,
    kwargs...
)
    return OptimizationCache(
        prob, opt, data;
        η = η,
        tmax = tmax,
        dt = dt,
        callback = callback,
        progress = progress,
        maxiters = nothing,
        maxtime = nothing,
        kwargs...
    )
end
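# With the cache interface enabled above, this constructor is what a call such
# as `cache = Optimization.init(prob, ODEGradientDescent(); η = 0.05, dt = 0.01)`
# reaches before `Optimization.solve!(cache)` runs the loop below (keyword
# values are illustrative only).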

# 4) The actual solve loop: Euler integration of gradient descent

function SciMLBase.__solve(
    cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
) where {F,RC,LB,UB,LC,UC,S,O<:ODEGradientDescent,D,P,C}

    # unpack initial state & parameters
    u0 = cache.u0
    η = get(cache.solver_args, :η, 0.1)
    tmax = get(cache.solver_args, :tmax, 1.0)
    dt = get(cache.solver_args, :dt, 0.01)
    maxiter = get(cache.solver_args, :maxiters, nothing)

    # prepare working storage
    u = copy(u0)
    G = similar(u)

    t = 0.0
    iter = 0
    t0 = time()
    # explicit Euler integration of the gradient-flow ODE du/dt = -η ∇f(u)
    while (isnothing(maxiter) || iter < maxiter) && t <= tmax
        # compute gradient in-place
        cache.f.grad(G, u, cache.p)
        # Euler step: u ← u - dt * η * ∇f(u)
        u .-= dt .* η .* G
        t += dt
        iter += 1
    end

    # final objective
    fval = cache.f(u, cache.p)

    # record stats: one final f-eval, iter gradient-evals
    stats = Optimization.OptimizationStats(
        iterations = iter,
        time = time() - t0,   # wall-clock time spent in the loop
        fevals = 1,
        gevals = iter,
        hevals = 0
    )

    return SciMLBase.build_solution(
        cache, cache.opt,
        u,
        fval,
        retcode = ReturnCode.Success,
        stats = stats
    )
end

end # module
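
# Minimal usage sketch (not part of the module). It assumes the package is
# loaded alongside Optimization.jl and ForwardDiff.jl so that AutoForwardDiff()
# can supply the required gradient; the quadratic objective, initial point, and
# keyword values below are illustrative only.
#
#   using Optimization, OptimizationODE, ForwardDiff
#
#   quad(u, p) = sum(abs2, u .- p)
#   f    = OptimizationFunction(quad, Optimization.AutoForwardDiff())
#   prob = OptimizationProblem(f, zeros(2), [1.0, 2.0])
#   sol  = solve(prob, ODEGradientDescent(); η = 1.0, tmax = 10.0, dt = 0.01)
#   sol.u, sol.objective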