Skip to content

Commit a7f2ce9

Browse files
authored
Merge pull request #9 from EarthyScience/la/lbfgs
lbfgs
2 parents fd6f702 + 0cba55d commit a7f2ce9

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

examples/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@ AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
33
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
44
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
55
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
67
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
78
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
89
EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
910
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
1011
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1112
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
13+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1214
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
15+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
16+
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
1317
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1418
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

examples/Q10_lbfgs.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using EasyHybrid
2+
using Lux
3+
using MLUtils
4+
using Random
5+
using Optimization
6+
using OptimizationOptimisers
7+
using ComponentArrays
8+
using GLMakie
9+
using Random
10+
using LuxCore
11+
using CSV, DataFrames
12+
using Statistics
13+
using Printf
14+
15+
df = CSV.read("/Users/lalonso/Documents/HybridML/data/Rh_AliceHolt_forcing_filled.csv", DataFrame)
16+
17+
df[!, :Temp] = df[!, :Temp] .- 273.15 # convert to Celsius
18+
df_forcing = filter(:Respiration_heterotrophic => !isnan, df)
19+
# df_forcing = df
20+
ds_k = to_keyedArray(Float32.(df_forcing))
21+
yobs = ds_k(:Respiration_heterotrophic)'[:,:]
22+
23+
NN = Lux.Chain(Dense(2, 15, Lux.relu), Dense(15, 15, Lux.relu), Dense(15, 1));
24+
#? do different initial Q10s
25+
RbQ10 = RespirationRbQ10(NN, (:Rgpot, :Moist), (:Temp,), 2.5f0)
26+
27+
data = (ds_k([:Rgpot, :Moist, :Temp]), yobs)
28+
29+
(x_train, y_train), (x_val, y_val) = splitobs(data; at=0.8, shuffle=false)
30+
dataloader = DataLoader((x_train, y_train), batchsize=512, shuffle=true);
31+
32+
ps, st = LuxCore.setup(Random.default_rng(), RbQ10)
33+
34+
ps_ca = ComponentArray(ps)
35+
smodel = StatefulLuxLayer{false}(RbQ10, nothing, st)
36+
# deal with the `Rb` state also here, (; Rb, st), since this is the output from LuxCore.apply.
37+
# ! note that for now is set to `{false}`.
38+
39+
function callback(state, l)
40+
state.iter % 2 == 1 && @printf "Iteration: %5d, Loss: %.6f\n" state.iter l
41+
return l < 0.2 ## Terminate if loss is smaller than
42+
end
43+
44+
function lossfn_optim(ps_ca, data)
45+
ds, y = data
46+
# ! unpack nan indices here as well
47+
ŷ, _ = smodel(ds, ps_ca)
48+
return Statistics.mean(abs2, ŷ .- y)
49+
end
50+
51+
lossfn_optim(ps_ca, (ds_k, yobs))
52+
53+
opt_func = OptimizationFunction(lossfn_optim, Optimization.AutoZygote())
54+
opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)
55+
56+
epochs = 25
57+
res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, epochs)
58+
# res_shopia = solve(opt_prob, Optimization.Sophia(); callback, maxiters=epochs)
59+
60+
# ! finetune a bit with L-BFGS
61+
# ? LBFGS needs to this `convert(Float64, res_adam.u)` which it fails!
62+
# ! but there is more, see issue: https://github.com/LuxDL/Lux.jl/issues/1260
63+
64+
# using ForwardDiff
65+
# opt_func = OptimizationFunction(lossfn_optim, Optimization.AutoForwardDiff())
66+
# opt_prob2 = remake(opt_prob, u0=res_adam.u)
67+
opt_prob = OptimizationProblem(opt_func, res_adam.u, dataloader)
68+
res_lbfgs = solve(opt_prob, Optimization.LBFGS(); callback, maxiters=epochs)

0 commit comments

Comments
 (0)