Skip to content

Commit e0d1665

Browse files
Dev example as Literate.jl script
1 parent c917a33 commit e0d1665

File tree

2 files changed

+73
-80
lines changed

2 files changed

+73
-80
lines changed

docs/literate/tutorials/example_synthetic_lstm.jl

Lines changed: 72 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
# CC BY-SA 4.0
2-
# =============================================================================
3-
# EasyHybrid Example: Synthetic Data Analysis
4-
# =============================================================================
5-
# This example demonstrates how to use EasyHybrid to train a hybrid model
6-
# on synthetic data for respiration modeling with Q10 temperature sensitivity.
7-
# =============================================================================
8-
9-
# =============================================================================
10-
# Project Setup and Environment
11-
# =============================================================================
12-
using Pkg
1+
# # LSTM Hybrid Model with EasyHybrid.jl
2+
#
3+
# This tutorial demonstrates how to use EasyHybrid to train a hybrid model with LSTM
4+
# neural networks on synthetic data for respiration modeling with Q10 temperature sensitivity.
5+
# The code for this tutorial can be found in [docs/src/literate/tutorials](https://github.com/EarthyScience/EasyHybrid.jl/tree/main/docs/src/literate/tutorials/) => example_synthetic_lstm.jl.
6+
#
7+
# ## 1. Load Packages
138

149
# Set project path and activate environment
10+
using Pkg
1511
project_path = "docs"
1612
Pkg.activate(project_path)
1713
EasyHybrid_path = joinpath(pwd())
@@ -24,19 +20,17 @@ using AxisKeys
2420
using DimensionalData
2521
using Lux
2622

27-
# =============================================================================
28-
# Data Loading and Preprocessing
29-
# =============================================================================
30-
# Load synthetic dataset from GitHub into DataFrame
23+
# ## 2. Data Loading and Preprocessing
24+
25+
# Load synthetic dataset from GitHub
3126
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc")
3227

3328
# Select a subset of data for faster execution
3429
df = df[1:20000, :]
3530

36-
# =============================================================================
37-
# neural network model
38-
# =============================================================================
39-
# Define neural network
31+
# ## 3. Define Neural Network Architectures
32+
33+
# Define a standard feedforward neural network
4034
NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1))
4135

4236
# Define LSTM-based neural network with memory
@@ -47,106 +41,106 @@ NN_Memory = Chain(
4741
Recurrence(LSTMCell(15 => 15), return_sequence = true),
4842
)
4943

50-
# =============================================================================
51-
# Define the Physical Model
52-
# =============================================================================
53-
# RbQ10 model: Respiration model with Q10 temperature sensitivity
54-
# Parameters:
55-
# - ta: air temperature [°C]
56-
# - Q10: temperature sensitivity factor [-]
57-
# - rb: basal respiration rate [μmol/m²/s]
58-
# - tref: reference temperature [°C] (default: 15.0)
44+
# ## 4. Define the Physical Model
45+
46+
"""
47+
RbQ10(; ta, Q10, rb, tref=15.0f0)
48+
49+
Respiration model with Q10 temperature sensitivity.
50+
51+
- `ta`: air temperature [°C]
52+
- `Q10`: temperature sensitivity factor [-]
53+
- `rb`: basal respiration rate [μmol/m²/s]
54+
- `tref`: reference temperature [°C] (default: 15.0)
55+
"""
5956
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
6057
reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
6158
return (; reco, Q10, rb)
6259
end
6360

64-
# =============================================================================
65-
# Define Model Parameters
66-
# =============================================================================
61+
# ## 5. Define Model Parameters
62+
6763
# Parameter specification: (default, lower_bound, upper_bound)
6864
parameters = (
69-
# Parameter name | Default | Lower | Upper | Description
7065
rb = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s]
7166
Q10 = (2.0f0, 1.0f0, 4.0f0), # Temperature sensitivity factor [-]
7267
)
7368

74-
# =============================================================================
75-
# Configure Hybrid Model Components
76-
# =============================================================================
77-
# Define input variables
78-
forcing = [:ta] # Forcing variables (temperature)
69+
# ## 6. Configure Hybrid Model Components
7970

80-
# Target variable
81-
target = [:reco] # Target variable (respiration)
71+
# Define input variables
72+
# Forcing variables (temperature)
73+
forcing = [:ta]
74+
# Predictor variables (solar radiation, and its derivative)
75+
predictors = [:sw_pot, :dsw_pot]
76+
# Target variable (respiration)
77+
target = [:reco]
8278

8379
# Parameter classification
84-
global_param_names = [:Q10] # Global parameters (same for all samples)
85-
neural_param_names = [:rb] # Neural network predicted parameters
80+
# Global parameters (same for all samples)
81+
global_param_names = [:Q10]
82+
# Neural network predicted parameters
83+
neural_param_names = [:rb]
8684

87-
# =============================================================================
88-
# Single NN Hybrid Model Training
89-
# =============================================================================
90-
using GLMakie
91-
# Create single NN hybrid model using the unified constructor
92-
predictors = [:sw_pot, :dsw_pot] # Predictor variables (solar radiation, and its derivative)
85+
# ## 7. Construct LSTM Hybrid Model
9386

87+
# Create LSTM hybrid model using the unified constructor
9488
hlstm = constructHybridModel(
95-
predictors, # Input features
96-
forcing, # Forcing variables
97-
target, # Target variables
98-
RbQ10, # Process-based model function
99-
parameters, # Parameter definitions
100-
neural_param_names, # NN-predicted parameters
101-
global_param_names, # Global parameters
89+
predictors,
90+
forcing,
91+
target,
92+
RbQ10,
93+
parameters,
94+
neural_param_names,
95+
global_param_names,
10296
hidden_layers = NN_Memory, # Neural network architecture
10397
scale_nn_outputs = true, # Scale neural network outputs
10498
input_batchnorm = false # Apply batch normalization to inputs
10599
)
106100

107-
# =================================================================================
108-
# show steps for data preparation, happens under the hood in the end.
101+
# ## 8. Data Preparation Steps (Demonstration)
102+
103+
# The following steps demonstrate what happens under the hood during training.
104+
# In practice, you can skip to Section 9 and use the `train` function directly.
109105

110106
# :KeyedArray and :DimArray are supported
111107
x, y = prepare_data(hlstm, df, array_type = :DimArray)
112108

113-
# new split_into_sequences with input_window, output_window, shift and lead_time
109+
# New split_into_sequences with input_window, output_window, shift and lead_time
114110
# for many-to-one, many-to-many, and different prediction lead times and overlap
115111
xs, ys = split_into_sequences(x, y; input_window = 20, output_window = 2, shift = 1, lead_time = 0)
116112
ys_nan = .!isnan.(ys)
117113

118-
# split data as in train
119-
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = 10, output_window = 3, shift = 1, lead_time = 1));
114+
# Split data as in train
115+
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = 10, output_window = 3, shift = 1, lead_time = 1))
120116

121117
typeof(sdf)
122-
(x_train, y_train), (x_val, y_val) = sdf;
118+
(x_train, y_train), (x_val, y_val) = sdf
123119
x_train
124120
y_train
125121
y_train_nan = .!isnan.(y_train)
126122

127-
# put into train loader to compose minibatches
123+
# Put into train loader to compose minibatches
128124
train_dl = EasyHybrid.DataLoader((x_train, y_train); batchsize = 32)
129125

130-
# run hybrid model forwards
126+
# Run hybrid model forwards
131127
x_first = first(train_dl)[1]
132128
y_first = first(train_dl)[2]
133129

134130
ps, st = Lux.setup(Random.default_rng(), hlstm)
135131
frun = hlstm(x_first, ps, st)
136132

137-
# extract predicted yhat
133+
# Extract predicted yhat
138134
reco_mod = frun[1].reco
139135

140-
# bring observations in same shape
136+
# Bring observations in same shape
141137
reco_obs = dropdims(y_first, dims = 1)
142138
reco_nan = .!isnan.(reco_obs)
143139

144-
# compute loss
140+
# Compute loss
145141
EasyHybrid.compute_loss(hlstm, ps, st, (x_train, (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true))
146142

147-
# =============================================================================
148-
# train on DataFrame
149-
# =============================================================================
143+
# ## 9. Train LSTM Hybrid Model
150144

151145
out_lstm = train(
152146
hlstm,
@@ -164,24 +158,22 @@ out_lstm = train(
164158
array_type = :DimArray
165159
)
166160

161+
# ## 10. Train Single NN Hybrid Model (Optional)
167162

168-
#####################################################################################
169-
# is neural network still running?
170-
163+
# For comparison, we can also train a hybrid model with a standard feedforward neural network
171164
hm = constructHybridModel(
172-
predictors, # Input features
173-
forcing, # Forcing variables
174-
target, # Target variables
175-
RbQ10, # Process-based model function
176-
parameters, # Parameter definitions
177-
neural_param_names, # NN-predicted parameters
178-
global_param_names, # Global parameters
165+
predictors,
166+
forcing,
167+
target,
168+
RbQ10,
169+
parameters,
170+
neural_param_names,
171+
global_param_names,
179172
hidden_layers = NN, # Neural network architecture
180173
scale_nn_outputs = true, # Scale neural network outputs
181174
input_batchnorm = false, # Apply batch normalization to inputs
182175
)
183176

184-
185177
# Train the hybrid model
186178
single_nn_out = train(
187179
hm,

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ makedocs(;
6060
"Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md",
6161
"Slurm" => "tutorials/slurm.md",
6262
"Cross-validation" => "tutorials/folds.md",
63+
"LSTM Hybrid Model" => "tutorials/example_synthetic_lstm.md",
6364
"Loss Functions" => "tutorials/losses.md",
6465
],
6566
"Research" => [

0 commit comments

Comments
 (0)