|
| 1 | +using Pkg |
| 2 | +Pkg.activate("projects/BulkDSOC") |
| 3 | +Pkg.develop(path=pwd()) |
| 4 | +Pkg.instantiate() |
| 5 | + |
| 6 | +using Revise |
| 7 | +using EasyHybrid |
| 8 | +using Lux |
| 9 | +using Optimisers |
| 10 | +using GLMakie |
| 11 | +using Random |
| 12 | +using LuxCore |
| 13 | +using CSV, DataFrames |
| 14 | +using EasyHybrid.MLUtils |
| 15 | +using Statistics |
| 16 | +using Plots |
| 17 | +using JLD2 |
| 18 | + |
# 04 - hybrid
# `testid` is the filename prefix of every artifact written by this script.
testid = "04a_hybridBD";
# All outputs are written to an `eval` folder located next to this script.
results_dir = joinpath(@__DIR__, "eval");

# input
raw = CSV.read(joinpath(@__DIR__, "data/lucas_preprocessed.csv"), DataFrame; normalizenames=true);
raw = dropmissing(raw); # to be discussed, as now train.jl seems to allow training with sparse data
# Rebind instead of `raw .= Float32.(raw)`: broadcasting *into* an existing DataFrame
# writes into the existing column vectors, so the columns would keep the element types
# CSV parsed (e.g. Float64/Int64) and the table would never actually become Float32.
raw = Float32.(raw);
df = raw
| 28 | + |
# mechanistic model
"""
    BD_model(; SOCconc, oBD, mBD)

Compute bulk density `BD` from `SOCconc`, `oBD` and `mBD` as

    BD = (oBD * mBD) / (f * mBD + (1 - f) * oBD),   f = 1.724 * SOCconc

All operations are element-wise, so scalars and arrays of compatible shape are accepted.
NOTE(review): 1.724 is presumably the SOC-to-organic-matter conversion factor — confirm.

Returns a named tuple `(; BD, SOCconc, oBD, mBD)`; the inputs are passed through
unchanged so that `SOCconc` can be supervised alongside `BD`.
"""
function BD_model(; SOCconc, oBD, mBD)
    som_fraction = 1.724f0 .* SOCconc
    numerator = oBD .* mBD
    denominator = som_fraction .* mBD .+ (1f0 .- som_fraction) .* oBD
    BD = numerator ./ denominator
    return (; BD, SOCconc, oBD, mBD)
end
| 34 | + |
# Parameter specification handed to the hybrid model constructor.
# NOTE(review): each triple looks like (default, lower bound, upper bound) — confirm
# against the layout EasyHybrid expects.
parameters = (
    SOCconc = (0.01f0, 0.0f0, 1.0f0),    # fraction
    oBD     = (1.30f0, 0.90f0, 1.80f0),  # g/cm^3
    mBD     = (1.50f0, 0.80f0, 2.0f0),   # global
)

# Roles of the parameters and columns in the hybrid model:
# SOCconc and oBD are listed as neural parameters, mBD as the single global parameter.
neural_param_names = [:SOCconc, :oBD]
global_param_names = [:mBD]
forcing = Symbol[]                       # this model uses no forcing variables
targets = [:BD, :SOCconc]                # SOCconc is both a param and a target

# Every column of `df` that is not a target is used as a predictor.
predictors = setdiff(propertynames(df), targets);
nf = length(predictors);

# Hyper-parameter grid explored by the search loop.
batch_sizes = [32, 64, 128, 256, 512];
lrs = [1e-3, 1e-4];
acts = [swish, gelu];

# Accumulators filled by the search loop: one tuple per run, plus the best run so far.
results = Any[]
best_r2 = -Inf
best_bundle = nothing
| 61 | + |
# Grid search: build and train one hybrid model per (batch size, learning rate,
# activation) combination, record its best validation metrics in `results`, and keep
# the full training output of the run with the highest validation R2 in `best_bundle`.
for bs in batch_sizes, lr in lrs, act in acts
    # `best_r2` and `best_bundle` are script-level globals. Without this declaration,
    # running the file non-interactively (`julia file.jl` / `include`) makes the
    # assignments below create loop-local variables, and the comparison against
    # `best_r2` throws `UndefVarError` before any result is stored.
    global best_r2, best_bundle

    @info "Testing bs=$(bs), lr=$(lr), act=$(act)"

    hm = constructHybridModel(
        predictors,               # single NN uses a Vector of predictors
        forcing,
        targets,
        BD_model,
        parameters,
        neural_param_names,
        global_param_names;
        hidden_layers = [256, 128, 64, 32, 16],
        activation = act,
        scale_nn_outputs = true,
        input_batchnorm = true,
        start_from_default = true
    )

    res = train(
        hm, df, ();
        nepochs = 200,
        batchsize = bs,
        opt = AdamW(lr),
        training_loss = :mse,
        loss_types = [:mse, :r2],
        shuffleobs = true,
        file_name = nothing,
        random_seed = 42,
        patience = 20,
        yscale = identity,
        monitor_names = [:oBD, :mBD],
        agg = mean,
        return_model = :best,
        show_progress = false
    )

    # Pull the per-epoch validation metrics out of the training history.
    # NOTE(review): the history entries are read via the property `mean`, presumably
    # named after the `agg = mean` aggregation passed to `train` — confirm.
    agg_name = :mean
    r2s = map(vh -> getproperty(vh, agg_name), res.val_history.r2)
    mses = map(vh -> getproperty(vh, agg_name), res.val_history.mse)
    best_idx = argmax(r2s)        # epoch index with the highest validation R2
    best_r2_here = r2s[best_idx]
    best_mse_here = mses[best_idx]

    push!(results, (bs, lr, act, best_r2_here, best_mse_here, best_idx))

    # Keep the whole bundle if this run beats the best one seen so far.
    # A NaN R2 never qualifies.
    if !isnan(best_r2_here) && best_r2_here > best_r2
        best_r2 = best_r2_here

        # map global mBD -> physical
        mBD_phys = EasyHybrid.scale_single_param(:mBD, res.ps[:mBD], hm.parameters) |> vec |> first
        mBD_raw = res.ps[:mBD][1]   # unconstrained optimizer value

        # per-sample oBD, only if the trainer exposed it in `val_diffs`
        oBD_phys = (hasproperty(res, :val_diffs) && hasproperty(res.val_diffs, :oBD)) ?
                   collect(res.val_diffs.oBD) : nothing

        best_bundle = (
            ps = deepcopy(res.ps),
            st = deepcopy(res.st),
            model = hm,
            val_obs_pred = deepcopy(res.val_obs_pred),
            val_diffs = hasproperty(res, :val_diffs) ? deepcopy(res.val_diffs) : nothing,
            meta = (bs=bs, lr=lr, act=act, best_epoch=best_idx,
                    r2=best_r2_here, mse=best_mse_here),
            # convenience fields
            mBD_physical = mBD_phys,
            mBD_unconstr = mBD_raw,
            oBD_phys = oBD_phys
        )
    end
end
| 135 | + |
# Tabulate the grid search: one row per run, columns in the order the tuples were
# pushed to `results` (batch size, learning rate, activation, R2, MSE, best epoch).
df_results = DataFrame(
    batch_size    = [r[1] for r in results],
    learning_rate = [r[2] for r in results],
    activation    = [string(r[3]) for r in results],  # store the function by name
    r2            = [r[4] for r in results],
    mse           = [r[5] for r in results],
    best_epoch    = [r[6] for r in results]
)

# Create the output folder if needed; writing into a missing directory would fail
# here and in every later save of this script.
mkpath(results_dir)
out_file = joinpath(results_dir, "$(testid)_parameter_search.csv")
CSV.write(out_file, df_results)
| 147 | + |
| 148 | + |
# Save and report the best model found by the grid search.
# Stop early with a clear message if no run produced a usable (non-NaN) R2.
best_bundle === nothing && error("No valid model found for $testid")
bm = best_bundle

# Julia has no backslash line continuation (`\` is the left-division operator), so a
# multi-line `@save path \ a=... \ b=...` is not a valid call. `jldsave` takes the
# entries as keyword arguments and stores them under the same names.
jldsave(
    joinpath(results_dir, "$(testid)_best_model.jld2");
    ps = bm.ps,
    st = bm.st,
    model = bm.model,
    val_obs_pred = bm.val_obs_pred,
    val_diffs = bm.val_diffs,
    meta = bm.meta,
    mBD_physical = bm.mBD_physical,
    mBD_unconstr = bm.mBD_unconstr,
    oBD_phys = bm.oBD_phys,
)
# To reload later:
#   @load joinpath(results_dir, "$(testid)_best_model.jld2") ps st model val_obs_pred meta
@info "Best for $testid: bs=$(bm.meta.bs), lr=$(bm.meta.lr), act=$(bm.meta.act), epoch=$(bm.meta.best_epoch), R2=$(round(bm.meta.r2, digits=4))"
| 160 | + |
# load predictions
jld = joinpath(results_dir, "$(testid)_best_model.jld2")
# The message interpolates `testid`; the loop variable `tname` used here before does
# not exist at this point, so a failing assert would itself raise `UndefVarError`.
@assert isfile(jld) "Missing $(jld). Did you train & save best model for $(testid)?"
@load jld val_obs_pred meta

# Split the validation table into one small table per target:
# observed column `t`, predicted column `t_pred`, and `:index` when it was saved.
val_tables = Dict{Symbol,DataFrame}()
for t in targets
    have_pred = Symbol(t, :_pred)
    cols = Symbol.(names(val_obs_pred))
    req = Set((t, have_pred))
    @assert issubset(req, cols) "val_obs_pred missing $(collect(req)) for $(t). Columns: $(names(val_obs_pred))"
    # `:index` is optional, so select it only if the column is actually present.
    keep = :index in cols ? [:index, t, have_pred] : [t, have_pred]
    val_tables[t] = val_obs_pred[:, keep]
end
| 175 | + |
| 176 | + |
# helper for metrics calculation
"""
    r2_mse(y_true, y_pred)

Return the tuple `(r2, mse)` for predictions `y_pred` against observations `y_true`,
where `r2 = 1 - SS_res / SS_tot` (with `SS_tot` taken around `mean(y_true)`) and `mse`
is the mean of the squared residuals.
"""
function r2_mse(y_true, y_pred)
    sq_err = (y_true .- y_pred) .^ 2          # squared residuals, reused for both metrics
    sq_dev = (y_true .- mean(y_true)) .^ 2    # squared deviations from the observed mean
    r2 = 1 - sum(sq_err) / sum(sq_dev)
    mse = mean(sq_err)
    return (r2, mse)
end
| 185 | + |
# Accuracy plots: one observed-vs-predicted 2D histogram per entry of `targets`,
# annotated with R² and MSE and saved as a PNG in `results_dir`.
# NOTE(review): `back_transform` and `MINMAX` are not defined in the visible part of
# this file — confirm they are in scope before running this section.
for tname in targets
    pred_name = Symbol("$(tname)_pred")
    df_out = val_tables[tname]
    present = Symbol.(names(df_out))
    @assert all(c -> c in present, [tname, pred_name]) "Expected columns $(tname) and $(tname)_pred in saved val table."

    y_val_true = back_transform(df_out[:, tname], tname, MINMAX)
    y_val_pred = back_transform(df_out[:, pred_name], tname, MINMAX)

    r2, mse = r2_mse(y_val_true, y_val_pred)
    plot_title = string(tname, "\nR²=", round(r2, digits=3), ", MSE=", round(mse, digits=3))

    plt = histogram2d(
        y_val_true, y_val_pred;
        nbins = (40, 40),
        cbar = true,
        xlab = "True",
        ylab = "Predicted",
        title = plot_title,
        normalize = false
    )

    # Use one common range for both axes so the 1:1 line runs corner to corner.
    lo, hi = extrema(vcat(y_val_true, y_val_pred))
    Plots.plot!(
        plt, [lo, hi], [lo, hi];
        color = :black, linewidth = 2, label = "1:1 line",
        aspect_ratio = :equal, xlims = (lo, hi), ylims = (lo, hi)
    )
    savefig(plt, joinpath(results_dir, "$(testid)_accuracy_$(tname).png"))
end
| 209 | + |
# BD vs SOCconc predictions
# Both prediction columns are read from `val_obs_pred`, the validation table loaded
# from the saved best model. The name `df_soc` used here before is never defined in
# this script.
plt = histogram2d(
    val_obs_pred[:, :BD_pred], val_obs_pred[:, :SOCconc_pred];
    nbins = (30, 30),
    cbar = true,
    xlab = "BD",
    ylab = "SOCconc",
    color = cgrad(:bamako, rev=true),
    normalize = false,
    size = (460, 400)
)
savefig(plt, joinpath(results_dir, "$(testid)_BD.vs.SOCconc.png"));
| 222 | + |
| 223 | + |
# Report the learned parameters of the best model: the global mBD and, when
# available, the per-sample oBD values recorded on the validation set.

# Global mBD: rescale the stored parameter with the model's parameter table and take
# the first element of the result.
mBD_learned = first(vec(EasyHybrid.scale_single_param(:mBD, bm.ps[:mBD], bm.model.parameters)))
@info "Learned mBD ≈ $(round(mBD_learned, digits=4))"

# Per-sample oBD: only present if the trainer stored `val_diffs` with an `oBD` entry;
# otherwise `oBD_vals` stays `nothing` and nothing is written.
oBD_vals = (!isnothing(bm.val_diffs) && hasproperty(bm.val_diffs, :oBD)) ?
           Array(bm.val_diffs.oBD) : nothing   # should be a vector matching val rows
if !isnothing(oBD_vals)
    @info "Collected $(length(oBD_vals)) oBD predictions from validation."
    @save joinpath(results_dir, "$(testid)_val_oBD.jld2") oBD_vals
end
| 235 | +end |
0 commit comments