1+ using EasyHybrid
2+ using Lux
3+ using MLUtils
4+ using Random
5+ using Optimization
6+ using OptimizationOptimisers
7+ using ComponentArrays
8+ using GLMakie
9+ using Random
10+ using LuxCore
11+ using CSV, DataFrames
12+ using Statistics
13+ using Printf
14+
15+ df = CSV. read (" /Users/lalonso/Documents/HybridML/data/Rh_AliceHolt_forcing_filled.csv" , DataFrame)
16+
17+ df[! , :Temp ] = df[! , :Temp ] .- 273.15 # convert to Celsius
18+ df_forcing = filter (:Respiration_heterotrophic => ! isnan, df)
19+ # df_forcing = df
20+ ds_k = to_keyedArray (Float32 .(df_forcing))
21+ yobs = ds_k (:Respiration_heterotrophic )' [:,:]
22+
23+ NN = Lux. Chain (Dense (2 , 15 , Lux. relu), Dense (15 , 15 , Lux. relu), Dense (15 , 1 ));
24+ # ? do different initial Q10s
25+ RbQ10 = RespirationRbQ10 (NN, (:Rgpot , :Moist ), (:Temp ,), 2.5f0 )
26+
27+ data = (ds_k ([:Rgpot , :Moist , :Temp ]), yobs)
28+
29+ (x_train, y_train), (x_val, y_val) = splitobs (data; at= 0.8 , shuffle= false )
30+ dataloader = DataLoader ((x_train, y_train), batchsize= 512 , shuffle= true );
31+
32+ ps, st = LuxCore. setup (Random. default_rng (), RbQ10)
33+
34+ ps_ca = ComponentArray (ps)
35+ smodel = StatefulLuxLayer {false} (RbQ10, nothing , st)
36+ # deal with the `Rb` state also here, (; Rb, st), since this is the output from LuxCore.apply.
37+ # ! note that for now is set to `{false}`.
38+
39+ function callback (state, l)
40+ state. iter % 2 == 1 && @printf " Iteration: %5d, Loss: %.6f\n " state. iter l
41+ return l < 0.2 # # Terminate if loss is smaller than
42+ end
43+
44+ function lossfn_optim (ps_ca, data)
45+ ds, y = data
46+ # ! unpack nan indices here as well
47+ ŷ, _ = smodel (ds, ps_ca)
48+ return Statistics. mean (abs2, ŷ .- y)
49+ end
50+
51+ lossfn_optim (ps_ca, (ds_k, yobs))
52+
53+ opt_func = OptimizationFunction (lossfn_optim, Optimization. AutoZygote ())
54+ opt_prob = OptimizationProblem (opt_func, ps_ca, dataloader)
55+
56+ epochs = 25
57+ res_adam = solve (opt_prob, Optimisers. Adam (0.001 ); callback, epochs)
58+ # res_shopia = solve(opt_prob, Optimization.Sophia(); callback, maxiters=epochs)
59+
60+ # ! finetune a bit with L-BFGS
61+ # ? LBFGS needs to this `convert(Float64, res_adam.u)` which it fails!
62+ # ! but there is more, see issue: https://github.com/LuxDL/Lux.jl/issues/1260
63+
64+ # using ForwardDiff
65+ # opt_func = OptimizationFunction(lossfn_optim, Optimization.AutoForwardDiff())
66+ # opt_prob2 = remake(opt_prob, u0=res_adam.u)
67+ opt_prob = OptimizationProblem (opt_func, res_adam. u, dataloader)
68+ res_lbfgs = solve (opt_prob, Optimization. LBFGS (); callback, maxiters= epochs)
0 commit comments