Skip to content

Commit 1d322a1

Browse files
Create OptimizationODE.jl
OptimizationODE.jl added and package created for the same.
1 parent b6ce64c commit 1d322a1

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
module OptimizationODE
2+
3+
using Reexport
4+
@reexport using Optimization
5+
using Optimization.SciMLBase
6+
7+
export ODEGradientDescent
8+
9+
# The optimizer “type”
10+
11+
struct ODEGradientDescent end
12+
13+
# capability flags
14+
SciMLBase.requiresbounds(::ODEGradientDescent) = false
15+
SciMLBase.allowsbounds(::ODEGradientDescent) = false
16+
SciMLBase.allowscallback(::ODEGradientDescent) = false
17+
SciMLBase.supports_opt_cache_interface(::ODEGradientDescent) = true
18+
SciMLBase.requiresgradient(::ODEGradientDescent) = true
19+
SciMLBase.requireshessian(::ODEGradientDescent) = false
20+
SciMLBase.requiresconsjac(::ODEGradientDescent) = false
21+
SciMLBase.requiresconshess(::ODEGradientDescent) = false
22+
23+
# Map standard kwargs to our solver’s args
24+
25+
function __map_optimizer_args!(
26+
cache::OptimizationCache, opt::ODEGradientDescent;
27+
callback = nothing,
28+
maxiters::Union{Number,Nothing}=nothing,
29+
maxtime::Union{Number,Nothing}=nothing,
30+
abstol::Union{Number,Nothing}=nothing,
31+
reltol::Union{Number,Nothing}=nothing,
32+
η::Float64 = 0.1,
33+
tmax::Float64 = 1.0,
34+
dt::Float64 = 0.01,
35+
kwargs...
36+
)
37+
# override our defaults
38+
cache.solver_args = merge(cache.solver_args, (
39+
η = η,
40+
tmax = tmax,
41+
dt = dt,
42+
))
43+
# now apply common options
44+
if !(isnothing(maxiters))
45+
cache.solver_args.maxiters = maxiters
46+
end
47+
if !(isnothing(maxtime))
48+
cache.solver_args.maxtime = maxtime
49+
end
50+
return nothing
51+
end
52+
53+
# 3) Initialize the cache (captures f, u0, bounds, and solver_args)
54+
55+
function SciMLBase.__init(
56+
prob::SciMLBase.OptimizationProblem,
57+
opt::ODEGradientDescent,
58+
data = Optimization.DEFAULT_DATA;
59+
η::Float64 = 0.1,
60+
tmax::Float64 = 1.0,
61+
dt::Float64 = 0.01,
62+
callback = (args...)->false,
63+
progress = false,
64+
kwargs...
65+
)
66+
return OptimizationCache(
67+
prob, opt, data;
68+
η = η,
69+
tmax = tmax,
70+
dt = dt,
71+
callback = callback,
72+
progress = progress,
73+
maxiters = nothing,
74+
maxtime = nothing,
75+
kwargs...
76+
)
77+
end
78+
79+
# 4) The actual solve loop: Euler integration of gradient descent
80+
81+
function SciMLBase.__solve(
82+
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
83+
) where {F,RC,LB,UB,LC,UC,S,O<:ODEGradientDescent,D,P,C}
84+
85+
# unpack initial state & parameters
86+
u0 = cache.u0
87+
η = get(cache.solver_args, , 0.1)
88+
tmax = get(cache.solver_args, :tmax, 1.0)
89+
dt = get(cache.solver_args, :dt, 0.01)
90+
maxiter = get(cache.solver_args, :maxiters, nothing)
91+
92+
# prepare working storage
93+
u = copy(u0)
94+
G = similar(u)
95+
96+
t = 0.0
97+
iter = 0
98+
# Euler loop
99+
while (isnothing(maxiter) || iter < maxiter) && t <= tmax
100+
# compute gradient in‐place
101+
cache.f.grad(G, u, cache.p)
102+
# Euler step
103+
u .-= η .* G
104+
t += dt
105+
iter += 1
106+
end
107+
108+
# final objective
109+
fval = cache.f(u, cache.p)
110+
111+
# record stats: one final f‐eval, iter gradient‐evals
112+
stats = Optimization.OptimizationStats(
113+
iterations = iter,
114+
time = 0.0, # could time() if you like
115+
fevals = 1,
116+
gevals = iter,
117+
hevals = 0
118+
)
119+
120+
return SciMLBase.build_solution(
121+
cache, cache.opt,
122+
u,
123+
fval,
124+
retcode = ReturnCode.Success,
125+
stats = stats
126+
)
127+
end
128+
129+
end # module

0 commit comments

Comments
 (0)