Skip to content

Commit c1a9351

Browse files
committed
hybrid only BD
1 parent c6d9121 commit c1a9351

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed

projects/BulkDSOC/HybridBD.jl

Whitespace-only changes.
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
using Pkg
2+
Pkg.activate("projects/BulkDSOC")
3+
Pkg.develop(path=pwd())
4+
Pkg.instantiate()
5+
6+
using Revise
7+
using EasyHybrid
8+
using Lux
9+
using Optimisers
10+
using GLMakie
11+
using Random
12+
using LuxCore
13+
using CSV, DataFrames
14+
using EasyHybrid.MLUtils
15+
using Statistics
16+
using Plots
17+
using JLD2
18+
19+
# 04 - hybrid
20+
testid = "04a_hybridBD";
21+
results_dir = joinpath(@__DIR__, "eval");
22+
23+
# input
24+
raw = CSV.read(joinpath(@__DIR__, "data/lucas_preprocessed.csv"), DataFrame; normalizenames=true);
25+
raw = dropmissing(raw); # to be discussed, as now train.jl seems to allow training with sparse data
26+
raw .= Float32.(raw);
27+
df = raw
28+
29+
# mechanistic model
30+
function BD_model(; SOCconc, oBD, mBD)
31+
BD = (oBD .* mBD) ./ (1.724f0 .* SOCconc .* mBD .+ (1f0 .- 1.724f0 .* SOCconc) .* oBD)
32+
return (; BD, SOCconc, oBD, mBD) # supervise both BD and SOCconc
33+
end
34+
35+
# param bounds
36+
parameters = (
37+
SOCconc = (0.01f0, 0.0f0, 1.0f0), # fraction
38+
oBD = (1.30f0, 0.90f0, 1.80f0), # g/cm^3
39+
mBD = (1.50f0, 0.80f0, 2.0f0), # global
40+
)
41+
42+
# define param for hybrid model
43+
neural_param_names = [:SOCconc, :oBD]
44+
global_param_names = [:mBD]
45+
forcing = Symbol[]
46+
targets = [:BD, :SOCconc] # SOCconc is both a param and a target
47+
48+
# just exclude targets explicitly to be safe
49+
predictors = setdiff(Symbol.(names(df)), targets); # first 3 and last 1
50+
nf = length(predictors);
51+
52+
# search space
53+
batch_sizes = [32, 64, 128, 256, 512];
54+
lrs = [1e-3, 1e-4];
55+
acts = [swish, gelu];
56+
57+
# store results
58+
results = []
59+
best_r2 = -Inf
60+
best_bundle = nothing
61+
62+
for bs in batch_sizes, lr in lrs, act in acts
63+
@info "Testing bs=$(bs), lr=$(lr), act=$(act)"
64+
65+
hm = constructHybridModel(
66+
predictors, # single NN uses a Vector of predictors
67+
forcing,
68+
targets,
69+
BD_model,
70+
parameters,
71+
neural_param_names,
72+
global_param_names;
73+
hidden_layers = [256, 128, 64, 32, 16],
74+
activation = act,
75+
scale_nn_outputs = true,
76+
input_batchnorm = true,
77+
start_from_default = true
78+
)
79+
80+
res = train(
81+
hm, df, ();
82+
nepochs = 200,
83+
batchsize = bs,
84+
opt = AdamW(lr),
85+
training_loss = :mse,
86+
loss_types = [:mse, :r2],
87+
shuffleobs = true,
88+
file_name = nothing,
89+
random_seed = 42,
90+
patience = 20,
91+
yscale = identity,
92+
monitor_names = [:oBD, :mBD],
93+
agg = mean,
94+
return_model = :best,
95+
show_progress = false
96+
)
97+
98+
# retrieve the best epoch metrics: mse and r2
99+
agg_name = Symbol("mean")
100+
r2s = map(vh -> getproperty(vh, agg_name), res.val_history.r2)
101+
mses = map(vh -> getproperty(vh, agg_name), res.val_history.mse)
102+
best_idx = findmax(r2s)[2] # index of best r2
103+
best_r2_here = r2s[best_idx]
104+
best_mse_here = mses[best_idx]
105+
106+
push!(results, (bs, lr, act, best_r2_here, best_mse_here, best_idx))
107+
108+
# keep the whole bundle if better
109+
if !isnan(best_r2_here) && best_r2_here > best_r2
110+
best_r2 = best_r2_here
111+
112+
# map global mBD -> physical
113+
mBD_phys = EasyHybrid.scale_single_param(:mBD, res.ps[:mBD], hm.parameters) |> vec |> first
114+
mBD_raw = res.ps[:mBD][1] # unconstrained optimizer value
115+
116+
# per-sample oBD
117+
oBD_phys = (hasproperty(res, :val_diffs) && hasproperty(res.val_diffs, :oBD)) ?
118+
collect(res.val_diffs.oBD) : nothing
119+
120+
best_bundle = (
121+
ps = deepcopy(res.ps),
122+
st = deepcopy(res.st),
123+
model = hm,
124+
val_obs_pred = deepcopy(res.val_obs_pred),
125+
val_diffs = hasproperty(res, :val_diffs) ? deepcopy(res.val_diffs) : nothing,
126+
meta = (bs=bs, lr=lr, act=act, best_epoch=best_idx,
127+
r2=best_r2_here, mse=best_mse_here),
128+
# convenience fields
129+
mBD_physical = mBD_phys,
130+
mBD_unconstr = mBD_raw,
131+
oBD_phys = oBD_phys
132+
)
133+
end
134+
end
135+
136+
df_results = DataFrame(
137+
batch_size = [r[1] for r in results],
138+
learning_rate = [r[2] for r in results],
139+
activation = [string(r[3]) for r in results],
140+
r2 = [r[4] for r in results],
141+
mse = [r[5] for r in results],
142+
best_epoch = [r[6] for r in results]
143+
)
144+
145+
out_file = joinpath(results_dir, "$(testid)_parameter_search.csv")
146+
CSV.write(out_file, df_results)
147+
148+
149+
# print best model
150+
@assert best_bundle !== nothing "No valid model found for $testid"
151+
bm = best_bundle
152+
@save joinpath(results_dir, "$(testid)_best_model.jld2") \
153+
ps=best_bundle.ps st=best_bundle.st model=best_bundle.model \
154+
val_obs_pred=best_bundle.val_obs_pred val_diffs=best_bundle.val_diffs \
155+
meta=best_bundle.meta \
156+
mBD_physical=best_bundle.mBD_physical mBD_unconstr=best_bundle.mBD_unconstr \
157+
oBD_phys=best_bundle.oBD_phys
158+
# @load joinpath(results_dir, "best_model_$(tgt).jld2") ps st model val_obs_pred meta
159+
@info "Best for $testid: bs=$(bm.meta.bs), lr=$(bm.meta.lr), act=$(bm.meta.act), epoch=$(bm.meta.best_epoch), R2=$(round(best_r2, digits=4))"
160+
161+
# load predictions
162+
jld = joinpath(results_dir, "$(testid)_best_model.jld2")
163+
@assert isfile(jld) "Missing $(jld). Did you train & save best model for $(tname)?"
164+
@load jld val_obs_pred meta
165+
# split output table
166+
val_tables = Dict{Symbol,DataFrame}()
167+
for t in targets
168+
# expected: t (true), t_pred (pred), and maybe :index if the framework saved it
169+
have_pred = Symbol(t, :_pred)
170+
req = Set((t, have_pred))
171+
@assert issubset(req, Symbol.(names(val_obs_pred))) "val_obs_pred missing $(collect(req)) for $(t). Columns: $(names(val_obs_pred))"
172+
keep = [:index, t, have_pred]
173+
val_tables[t] = val_obs_pred[:, keep]
174+
end
175+
176+
177+
# helper for metrics calculation
178+
r2_mse(y_true, y_pred) = begin
179+
ss_res = sum((y_true .- y_pred).^2)
180+
ss_tot = sum((y_true .- mean(y_true)).^2)
181+
r2 = 1 - ss_res / ss_tot
182+
mse = mean((y_true .- y_pred).^2)
183+
(r2, mse)
184+
end
185+
186+
# accuracy plots for SOCconc, BD, CF in original space
187+
for tname in targets
188+
df_out = val_tables[tname]
189+
@assert all(in(Symbol.(names(df_out))).([tname, Symbol("$(tname)_pred")])) "Expected columns $(tname) and $(tname)_pred in saved val table."
190+
191+
y_val_true = back_transform(df_out[:, tname], tname, MINMAX)
192+
y_val_pred = back_transform(df_out[:, Symbol("$(tname)_pred")], tname, MINMAX)
193+
194+
r2, mse = r2_mse(y_val_true, y_val_pred)
195+
196+
plt = histogram2d(
197+
y_val_true, y_val_pred;
198+
nbins=(40, 40), cbar=true, xlab="True", ylab="Predicted",
199+
title = string(tname, "\nR²=", round(r2, digits=3), ", MSE=", round(mse, digits=3)),
200+
normalize=false
201+
)
202+
lims = extrema(vcat(y_val_true, y_val_pred))
203+
Plots.plot!(plt, [lims[1], lims[2]], [lims[1], lims[2]];
204+
color=:black, linewidth=2, label="1:1 line",
205+
aspect_ratio=:equal, xlims=lims, ylims=lims
206+
)
207+
savefig(plt, joinpath(results_dir, "$(testid)_accuracy_$(tname).png"))
208+
end
209+
210+
# BD vs SOCconc predictions
211+
plt = histogram2d(
212+
df_soc[:,:BD_pred], df_soc[:,:SOCconc_pred];
213+
nbins = (30, 30),
214+
cbar = true,
215+
xlab = "BD",
216+
ylab = "SOCconc",
217+
color = cgrad(:bamako, rev=true),
218+
normalize = false,
219+
size = (460, 400)
220+
)
221+
savefig(plt, joinpath(results_dir, "$(testid)_BD.vs.SOCconc.png"));
222+
223+
224+
# save / print parameters: mBD and per-sample oBD
225+
# mBD global
226+
mBD_learned = EasyHybrid.scale_single_param(:mBD, bm.ps[:mBD], bm.model.parameters) |> vec |> first
227+
@info "Learned mBD ≈ $(round(mBD_learned, digits=4))"
228+
229+
# Try to fetch per-sample oBD predictions from val_diffs (if the trainer provided them)
230+
oBD_vals = nothing
231+
if bm.val_diffs !== nothing && hasproperty(bm.val_diffs, :oBD)
232+
oBD_vals = Array(bm.val_diffs.oBD) # should be a vector matching val rows
233+
@info "Collected $(length(oBD_vals)) oBD predictions from validation."
234+
@save joinpath(results_dir, "$(testid)_val_oBD.jld2") oBD_vals
235+
end

0 commit comments

Comments
 (0)