Skip to content

Commit 67fe260

Browse files
committed
added named targets
1 parent 5c7c0fd commit 67fe260

File tree

4 files changed

+41
-34
lines changed

4 files changed

+41
-34
lines changed

projects/RbQ10/Q10.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,27 @@ using Random
66
using LuxCore
77
using CSV, DataFrames
88
using EasyHybrid.MLUtils
9-
9+
using EasyHybrid.AxisKeys
10+
using Zygote
1011
# data
11-
df = CSV.read("/Users/lalonso/Documents/HybridML/data/Rh_AliceHolt_forcing_filled.csv", DataFrame)
12+
df_o = CSV.read("/Users/lalonso/Documents/HybridML/data/Rh_AliceHolt_forcing_filled.csv", DataFrame)
1213

14+
df = copy(df_o)
1315
df[!, :Temp] = df[!, :Temp] .- 273.15 # convert to Celsius
14-
df_forcing = filter(:Respiration_heterotrophic => !isnan, df)
16+
# df = filter(:Respiration_heterotrophic => !isnan, df)
17+
rename!(df, :Respiration_heterotrophic => :Rh) # rename as in hybrid model
1518
df_forcing = df
16-
ds_k = to_keyedArray(Float32.(df_forcing))
17-
yobs = ds_k(:Respiration_heterotrophic)'[:,:]
19+
20+
ds_p_f = to_keyedArray(Float32.(df_forcing)) # predictors + forcing
21+
ds_t = ds_p_f([:Rh]) # do the array so that you conserve the name
1822

1923
NN = Lux.Chain(Dense(2, 15, Lux.relu), Dense(15, 15, Lux.relu), Dense(15, 1));
2024
#? do different initial Q10s
21-
RbQ10 = RespirationRbQ10(NN, (:Rgpot, :Moist), (:Temp,), 2.5f0)
25+
RbQ10 = RespirationRbQ10(NN, (:Rgpot, :Moist), (:Rh, ), (:Temp,), 2.5f0)
2226

2327
# ? play with :Temp as predictors in NN, temperature sensitivity!
2428
# TODO: variance effect due to LSTM vs NN
25-
26-
out = train(RbQ10, (ds_k([:Rgpot, :Moist, :Temp]), yobs), (:Q10, ); nepochs=1000, batchsize=512, opt=Adam(0.01));
29+
out = train(RbQ10, (ds_p_f, ds_t), (:Q10, ); nepochs=1000, batchsize=512, opt=Adam(0.01));
2730

2831

2932
with_theme(theme_light()) do
@@ -39,10 +42,10 @@ with_theme(theme_light()) do
3942
fig = Figure(; size = (1200, 600))
4043
ax_train = Makie.Axis(fig[1, 1], title = "training")
4144
ax_val = Makie.Axis(fig[2, 1], title = "validation")
42-
lines!(ax_train, out.ŷ_train[:], color=:orangered, label = "prediction")
45+
lines!(ax_train, out.ŷ_train.Rh[:], color=:orangered, label = "prediction")
4346
lines!(ax_train, out.y_train[:], color=:dodgerblue, label ="observation")
4447
# validation
45-
lines!(ax_val, out.ŷ_val[:], color=:orangered, label = "prediction")
48+
lines!(ax_val, out.ŷ_val.Rh[:], color=:orangered, label = "prediction")
4649
lines!(ax_val, out.y_val[:], color=:dodgerblue, label ="observation")
4750
axislegend(; position=:lt)
4851
Label(fig[0,1], "Observations vs predictions", tellwidth=false)
@@ -62,22 +65,20 @@ with_theme(theme_light()) do
6265
end
6366

6467

65-
ds_k = to_keyedArray(Float32.(df))
66-
yobs_all = ds_k(:Respiration_heterotrophic)'[:,:]
68+
yobs_all = ds_p_f(:Rh)
6769

68-
ŷ, RbQ10_st = LuxCore.apply(RbQ10, ds_k, out.ps, out.st)
70+
ŷ, RbQ10_st = LuxCore.apply(RbQ10, ds_p_f, out.ps, out.st)
6971

7072
with_theme(theme_light()) do
7173
fig = Figure(; size = (1200, 600))
7274
ax_train = Makie.Axis(fig[1, 1], title = "full time series")
73-
lines!(ax_train, ŷ[:], color=:orangered, label = "prediction")
75+
lines!(ax_train, ŷ.Rh[:], color=:orangered, label = "prediction")
7476
lines!(ax_train, yobs_all[:], color=:dodgerblue, label ="observation")
7577
axislegend(ax_train; position=:lt)
7678
Label(fig[0,1], "Observations vs predictions", tellwidth=false)
7779
fig
7880
end
7981

80-
8182
# ? Rb
8283
lines(out.αst_train.Rb[:])
83-
lines!(ds_k(:Moist)[:])
84+
lines!(ds_p_f(:Moist)[:])

src/models/Respiration_Rb_Q10.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
export RespirationRbQ10
33

44
"""
5-
RespirationRbQ10(NN, predictors, forcing, Q10)
5+
RespirationRbQ10(NN, predictors, forcing, targets, Q10)
66
7-
A linear hybrid model with a neural network `NN`, `predictors` and `forcing` terms.
7+
A linear hybrid model with a neural network `NN`, `predictors`, `targets` and `forcing` terms.
88
"""
9-
struct RespirationRbQ10{D, T1, T2, T3} <: LuxCore.AbstractLuxContainerLayer{(:NN, :predictors, :forcing, :Q10)}
9+
struct RespirationRbQ10{D, T1, T2, T3, T4} <: LuxCore.AbstractLuxContainerLayer{(:NN, :predictors, :forcing, :targets, :Q10)}
1010
NN
1111
predictors
1212
forcing
13+
targets
1314
Q10
14-
function RespirationRbQ10(NN::D, predictors::T1, forcing::T2, Q10::T3) where {D, T1, T2, T3}
15-
new{D, T1, T2, T3}(NN, collect(predictors), collect(forcing), [Q10])
15+
function RespirationRbQ10(NN::D, predictors::T1, forcing::T2, targets::T3, Q10::T4) where {D, T1, T2, T3, T4}
16+
new{D, T1, T2, T3, T4}(NN, collect(predictors), collect(targets), collect(forcing), [Q10])
1617
end
1718
end
1819

@@ -28,7 +29,7 @@ function LuxCore.initialstates(::AbstractRNG, layer::RespirationRbQ10)
2829
end
2930

3031
"""
31-
RespirationRbQ10(NN, predictors, forcing, Q10)(ds_k)
32+
RespirationRbQ10(NN, predictors, forcing, targets, Q10)(ds_k)
3233
3334
# Model definition `ŷ = Rb(αᵢ(t)) * Q10^((T(t) - T_ref)/10)`
3435
@@ -37,11 +38,11 @@ ŷ (respiration rate) is computed as a function of the neural network output `R
3738
"""
3839
function (hm::RespirationRbQ10)(ds_k, ps, st::NamedTuple)
3940
p = ds_k(hm.predictors)
40-
x = ds_k(hm.forcing)
41+
x = Array(ds_k(hm.forcing)) # don't propagate names after this
4142

4243
Rb, st = LuxCore.apply(hm.NN, p, ps.ps, st.st) #! NN(αᵢ(t)) ≡ Rb(T(t), M(t))
4344

44-
= Rb .* ps.Q10 .^(0.1f0 * (x .- 15.0f0)) # ? should 15°C be the reference temperature also an input variable?
45+
Rh = Rb .* ps.Q10 .^(0.1f0 * (x .- 15.0f0)) # ? should 15°C be the reference temperature also an input variable?
4546

46-
return , (; Rb, st)
47+
return (; Rh), (; Rb, st)
4748
end

src/train.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ function train(hybridModel, data, save_ps; nepochs=200, batchsize=10, opt=Adam(0
1212
opt_state = Optimisers.setup(opt, ps)
1313

1414
# ? initial losses
15-
is_no_nan_t = .!isnan.(y_train)[1, :]
16-
is_no_nan_v = .!isnan.(y_val)[1, :]
15+
is_no_nan_t = .!isnan.(y_train)
16+
is_no_nan_v = .!isnan.(y_val)
1717
l_init_train = lossfn(hybridModel, x_train, (y_train, is_no_nan_t), ps, st)
1818
l_init_val = lossfn(hybridModel, x_val, (y_val, is_no_nan_v), ps, st)
1919

@@ -24,8 +24,8 @@ function train(hybridModel, data, save_ps; nepochs=200, batchsize=10, opt=Adam(0
2424
for epoch in 1:nepochs
2525
for (x, y) in train_loader
2626
# ? check NaN indices before going forward, and pass filtered `x, y`.
27-
is_no_nan = .!isnan.(y)[1, :]
28-
if length(is_no_nan)>0
27+
is_no_nan = .!isnan.(y)
28+
if length(is_no_nan)>0 # ! be careful here, multivariate needs fine tuning
2929
grads = Zygote.gradient((ps) -> lossfn(hybridModel, x, (y, is_no_nan), ps, st), ps)[1]
3030
Optimisers.update!(opt_state, ps, grads)
3131
end

src/utils/losses.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@ end
1212

1313

1414
"""
15-
lossfn(HM::RespirationRbQ10, ds, y, ps, st)
15+
lossfn(HM::RespirationRbQ10, ds_p, (ds_t, ds_t_nan), ps, st)
1616
"""
17-
function lossfn(HM::RespirationRbQ10, ds, (y, no_nan), ps, st)
18-
ŷ, αst = HM(ds, ps, st)
19-
_, st = αst
20-
loss = mean((y[no_nan] .- ŷ[no_nan]).^2)
17+
function lossfn(HM::RespirationRbQ10, ds_p, (ds_t, ds_t_nan), ps, st)
18+
ŷ, _ = HM(ds_p, ps, st)
19+
y = ds_t(HM.targets)
20+
y_nan = ds_t_nan(HM.targets)
21+
22+
loss = 0.0
23+
for k in axiskeys(y, 1)
24+
loss += mean(abs2, (ŷ[k][y_nan(k)] .- y(k)[y_nan(k)]))
25+
end
2126
return loss
2227
end

0 commit comments

Comments
 (0)