Skip to content

Commit b221953

Browse files
Added new subpackage for Sophia.jl
1 parent 343d8d3 commit b221953

File tree

5 files changed

+103
-6
lines changed

5 files changed

+103
-6
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name = "OptimizationSophia"
2+
uuid = "892fee11-dca1-40d6-b698-84ba0d87399a"
3+
authors = ["paramthakkar123 <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
8+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
9+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
10+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
11+
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
12+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
13+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
15+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
16+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
17+
18+
[compat]
19+
ComponentArrays = "0.15.29"
20+
Lux = "1.16.0"
21+
MLUtils = "0.4.8"
22+
Optimization = "4.5.0"
23+
OptimizationBase = "2.10.0"
24+
OrdinaryDiffEqTsit5 = "1.2.0"
25+
Random = "1.11.0"
26+
SciMLSensitivity = "7.88.0"
27+
Test = "1.11.0"
28+
Zygote = "0.7.10"

src/sophia.jl renamed to lib/OptimizationSophia/src/OptimizationSophia.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
module OptimizationSophia
2+
3+
using OptimizationBase.SciMLBase
4+
using OptimizationBase: OptimizationCache
5+
using Optimization
6+
17
struct Sophia
28
η::Float64
39
βs::Tuple{Float64, Float64}
@@ -119,3 +125,5 @@ function SciMLBase.__solve(cache::OptimizationCache{
119125
θ,
120126
x)
121127
end
128+
129+
end
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using OptimizationBase, Optimization
2+
using OptimizationBase.SciMLBase: solve, OptimizationFunction, OptimizationProblem
3+
using OptimizationSophia
4+
using Lux, MLUtils, Random, ComponentArrays
5+
using SciMLSensitivity
6+
using Test
7+
using Zygote
8+
using OrdinaryDiffEqTsit5
9+
10+
function dudt_(u, p, t)
11+
ann(u, p, st)[1] .* u
12+
end
13+
14+
function newtons_cooling(du, u, p, t)
15+
temp = u[1]
16+
k, temp_m = p
17+
du[1] = dT = -k * (temp - temp_m)
18+
end
19+
20+
function true_sol(du, u, p, t)
21+
true_p = [log(2) / 8.0, 100.0]
22+
newtons_cooling(du, u, true_p, t)
23+
end
24+
25+
function callback(state, l) #callback function to observe training
26+
display(l)
27+
return l < 1e-2
28+
end
29+
30+
function predict_adjoint(fullp, time_batch)
31+
Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch))
32+
end
33+
34+
function loss_adjoint(fullp, p)
35+
(batch, time_batch) = p
36+
pred = predict_adjoint(fullp, time_batch)
37+
sum(abs2, batch .- pred)
38+
end
39+
40+
u0 = Float32[200.0]
41+
datasize = 30
42+
tspan = (0.0f0, 1.5f0)
43+
rng = Random.default_rng()
44+
45+
ann = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1, tanh))
46+
pp, st = Lux.setup(rng, ann)
47+
pp = ComponentArray(pp)
48+
49+
prob = ODEProblem{false}(dudt_, u0, tspan, pp)
50+
51+
t = range(tspan[1], tspan[2], length = datasize)
52+
true_prob = ODEProblem(true_sol, u0, tspan)
53+
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
54+
55+
k = 10
56+
train_loader = MLUtils.DataLoader((ode_data, t), batchsize = k)
57+
58+
l1 = loss_adjoint(pp, (train_loader.data[1], train_loader.data[2]))[1]
59+
60+
optfun = OptimizationFunction(loss_adjoint,
61+
OptimizationBase.AutoZygote())
62+
optprob = OptimizationProblem(optfun, pp, train_loader)
63+
64+
res1 = solve(optprob,
65+
OptimizationSophia.Sophia(), callback = callback,
66+
maxiters = 2000)
67+
@test 10res1.objective < l1

src/Optimization.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ export ObjSense, MaxSense, MinSense
2323
include("utils.jl")
2424
include("state.jl")
2525
include("lbfgsb.jl")
26-
include("sophia.jl")
2726
include("auglag.jl")
2827

2928
export solve

test/minibatch.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ optfun = OptimizationFunction(loss_adjoint,
5858
Optimization.AutoZygote())
5959
optprob = OptimizationProblem(optfun, pp, train_loader)
6060

61-
res1 = Optimization.solve(optprob,
62-
Optimization.Sophia(), callback = callback,
63-
maxiters = 2000)
64-
@test 10res1.objective < l1
65-
6661
optfun = OptimizationFunction(loss_adjoint,
6762
Optimization.AutoForwardDiff())
6863
optprob = OptimizationProblem(optfun, pp, train_loader)

0 commit comments

Comments
 (0)