Commit bd0483c

Add external forcing types to gcm-driven scm calibration, allowing for both shallow and deep convection. Includes option for defining stretched calibration grid.
Add cal grid to edmf_ensemble_stats.
gcm driven calibration updates: restart, plotting support for variable grids, add plot prior and normalize var scripts, add microphysics cal option.
Add NN mixing length closure, add load pretrain weights logic in prior.
Add clippings for NN inputs, dev calibration.
Add scm_runner tools to run SCMs (in parallel) across cases for a given parameter set.
scm runner updates.
Limit nn mixing length by smag length + (1/z), and add new dz input variable.
Remove problematic deep convective cases from library.
Update minibatcher to specify cases for each epoch.
Add serialize_std_model with default in run_cal to allow different priors for NN weights and biases.
Increase request limits for top-level slurm sbatch script, preventing minor memory leak from killing calibration.
Increase noise by an order of mag.
Update, improve restart calibration + allow to work with NNs.
Limit number of processes per ensemble member to 5 to avoid queue waits.
Add back deep conv cases, clip ql, qi above 0 in les cases.
Fix get_optimal_particle nearest neighbor mean getter.
Simplify nn prior creation at beginning of run_calibration.
Request less time on nodes for scm runs.
Limit num_cpu to 5 in restart script.
Increase noise, use batch size = 20.
Add cfsite info to plot_ensemble plots.
Add radiation metrics to runner model config.
Add nn_helper functions.
Add precal prior toml.
Increase noise and batch size (30).
Increase prior sigma for NN (0.05 for weights).
Lower prior std for mixing_length_diss_coeff.
Turn off accelerator.
Make shallow library default in get_les_metadata.
Make leaky relu NN default.
Change default noise with leaky relu (std_weight = 0.1, std_bias = 0.00001).
Increase default runtime for restart for ens members to 180 mins.
Increase resource requests for scm_runner.
Make leaky relu NN default (and write to scratch).
Pass args to scm_runner in sbatch script.
Add CLI option for scm_runner, more flexible parsing of cfsite (for different AMIP experiments).
Switch between NN and standard priors in restart calibration script.
Add batch size logic in edmf_ensemble_stats.
1 parent 7954db9 commit bd0483c

30 files changed: +2114 -253 lines changed

calibration/experiments/gcm_driven_scm/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
1111
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1212
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1313
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
14+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1415
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
1516
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
1617
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
@@ -19,5 +20,5 @@ NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1920
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
2021

2122
[compat]
22-
ClimaCalibrate = "=0.0.3"
23-
EnsembleKalmanProcesses = "2"
23+
EnsembleKalmanProcesses = "2.1.2"
24+
ClimaCalibrate = "0.0.8"

calibration/experiments/gcm_driven_scm/README.md

Lines changed: 25 additions & 3 deletions
@@ -1,10 +1,33 @@
11
# Overview of Calibration Pipeline for GCM-Driven Single Column Model (EDMF)
22

3-
This setup provides tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibrations statistics are computed only on the lower 4km (`z_max`), where LES output is available.
3+
## Pipeline Components
4+
5+
### Configuration Files
6+
- `experiment_config.yml` - Configuration of calibration settings
7+
- Defines spatiotemporal calibration window
8+
- Specifies required pipeline file paths
9+
- Controls batch processing parameters
10+
11+
- `model_config_**.yml` - Configuration for ClimaAtmos single column model
12+
- Defines model-specific parameters
13+
- Allows for multiple model configuration variants
14+
15+
## Best Practices
16+
- Ensure `batch_size` matches available LES configurations
17+
- Verify normalization factors for each variable
18+
- Monitor ensemble convergence using provided plotting tools
19+
20+
This setup provides tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibration statistics are only computed on the calibration grid `z_cal_grid`.
421

522
LES profiles are available for different geolocations ("cfsites"), spanning seasons, forcing host models, and climates (AMIP, AMIP4K). A given LES simulation is referred to as a "configuration". Calibrations employ batching by default and stack multiple configurations (a number equal to the `batch_size`) in a given iteration. The observation vector for a single configuration is formed by concatenating profiles across calibration variables, where each geophysical variable is normalized to have approximately unit variance and zero mean. These variable-by-variable normalization factors are precomputed (`norm_factors_dict`) and applied to all observations. Following this operation, the spatiotemporal calibration window is applied and temporal means are computed to form the observation vector `y`. Because variables are normalized to have 0 mean and unit variance, a constant diagonal noise matrix is used (configurable as `const_noise`).
623

724

25+
### Observation Map
26+
1. **Time-mean**: Time-mean of profiles taken between [`y_t_start_sec`, `y_t_end_sec`] for `y` and [`g_t_start_sec`, `g_t_end_sec`] for `G`.
27+
2. **Interpolation**: Case-specific (i.e., "shallow", "deep") interpolation to the calibration grid, defined with stretch-grid parameters in `z_cal_grid`.
28+
3. **Normalization**: Variable-specific normalization using the mean and standard deviation defined in `norm_factors_by_var`. Optionally, take log of variables using `log_vars` before normalization.
29+
4. **Concatenation**: Stack across cases in a batch, forming `y`, `G`.
30+
831
## Getting Started
932

1033
### Define calibration and model configurations:
@@ -20,6 +43,5 @@ LES profiles are available for different geolocations ("cfsites"), spanning seas
2043
### Analyze output with:
2144
- `julia --project plot_ensemble.jl` - plots vertical profiles of all ensemble members in a given iteration, given path to calibration output
2245
- `julia --project edmf_ensemble_stats.jl` - computes and plots metrics offline [i.e., root mean squared error (RMSE)] as a function of iteration, given path to calibration output.
23-
- `julia --project plot_eki.jl` - plot eki metrics [loss, var-weighted loss] and `y`, `g` vectors vs iteration, display best particles
24-
46+
- `julia --project plot_eki.jl` - plot eki metrics [loss, variance-weighted loss] and `y`, `G` vectors vs iteration, display best particles
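
The four "Observation Map" steps listed in the README can be illustrated with a minimal, dependency-free Julia sketch. The function names (`interp_linear`, `obs_map_var`), the clamped linear interpolation, and the toy data are assumptions made for illustration; the pipeline's own helpers (for example `get_obs`) are not shown in this diff and may differ.

```julia
# Linear interpolation of `v` (given on increasing heights `z`) onto `z_new`.
# Query points outside the source range take the nearest end value.
function interp_linear(z, v, z_new)
    out = zeros(Float64, length(z_new))
    for (k, zq) in enumerate(z_new)
        if zq <= z[1]
            out[k] = v[1]
        elseif zq >= z[end]
            out[k] = v[end]
        else
            i = searchsortedlast(z, zq)              # z[i] <= zq < z[i + 1]
            w = (zq - z[i]) / (z[i + 1] - z[i])
            out[k] = (1 - w) * v[i] + w * v[i + 1]
        end
    end
    return out
end

# Steps 1-3 for one variable of one case (hypothetical helper).
# `profiles` is (n_z, n_t), `t` holds output times [s], `z_cal` is the calibration grid.
function obs_map_var(profiles, z, t, z_cal; t_start, t_end, μ, σ, take_log = false)
    in_window = (t .>= t_start) .& (t .<= t_end)
    v = vec(sum(profiles[:, in_window], dims = 2)) ./ count(in_window)  # 1. time-mean
    v = interp_linear(z, v, z_cal)                                      # 2. interpolation
    if take_log
        v = log.(v)                                                     # optional log (cf. `log_vars`)
    end
    return (v .- μ) ./ σ                                                # 3. normalization
end

# 4. Concatenation: stack variables (and cases in a batch) into one vector.
z = [0.0, 100.0, 200.0]
t = [0.0, 10.0, 20.0]
profiles = repeat(z, 1, 3)                 # toy data: constant in time
z_cal = [50.0, 150.0]
y_a = obs_map_var(profiles, z, t, z_cal; t_start = 10.0, t_end = 20.0, μ = 100.0, σ = 50.0)
y_b = obs_map_var(profiles .+ 1.0, z, t, z_cal; t_start = 10.0, t_end = 20.0, μ = 0.0, σ = 1.0, take_log = true)
y = vcat(y_a, y_b)
```

Applying the same function to model output with the `g_t_start_sec` / `g_t_end_sec` window would give the matching `G` vector under the same assumptions.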
2547

calibration/experiments/gcm_driven_scm/edmf_ensemble_stats.jl

Lines changed: 32 additions & 20 deletions
@@ -1,12 +1,8 @@
11
#!/usr/bin/env julia
22

3-
import ClimaComms
4-
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
5-
63
using ArgParse
74
using Distributed
8-
addprocs()
9-
5+
addprocs(1)
106

117
@everywhere begin
128
using EnsembleKalmanProcesses: TOMLInterface
@@ -60,6 +56,12 @@ function parse_args()
6056
return parse_with_settings(s)
6157
end
6258

59+
@everywhere function validate_ensemble_member(iteration_dir, batch_size)
60+
config_dirs = filter(x -> isdir(joinpath(iteration_dir, x)), readdir(iteration_dir))
61+
num_configs = count(x -> startswith(x, "config_"), config_dirs)
62+
return num_configs == batch_size
63+
end
64+
6365
function main()
6466
args = parse_args()
6567

@@ -87,6 +89,7 @@ function main()
8789
cal_vars = config_dict["y_var_names"]
8890
const_noise_by_var = config_dict["const_noise_by_var"]
8991
n_iterations = config_dict["n_iterations"]
92+
batch_size = config_dict["batch_size"]
9093
model_config_dict =
9194
YAML.load_file(joinpath(output_dir, "configs", "model_config.yml"))
9295

@@ -95,9 +98,6 @@ function main()
9598
end
9699

97100
ref_paths, _ = get_les_calibration_library()
98-
comms_ctx = ClimaComms.SingletonCommsContext()
99-
atmos_config = CA.AtmosConfig(model_config_dict; comms_ctx)
100-
zc_model = get_z_grid(atmos_config, z_max = z_max)
101101

102102
@everywhere function calculate_statistics(y_var)
103103
non_nan_values = y_var[.!isnan.(y_var)]
@@ -124,9 +124,9 @@ function main()
124124
cal_vars,
125125
const_noise_by_var,
126126
ref_paths,
127-
zc_model,
128127
reduction,
129128
ensemble_size,
129+
batch_size,
130130
)
131131
println("Processing Iteration: $iteration")
132132
stats_df = DataFrame(
@@ -141,13 +141,18 @@ function main()
141141
rmse_std = Union{Missing, Float64}[],
142142
)
143143
config_indices = get_batch_indicies_in_iteration(iteration, output_dir)
144+
iteration_dir = joinpath(output_dir, "iteration_$(lpad(iteration, 3, '0'))")
145+
146+
valid_ensemble_members = filter(config_i -> validate_ensemble_member(joinpath(iteration_dir, "member_$(lpad(config_i, 3, '0'))"), batch_size), config_indices)
147+
144148
for var_name in var_names
145149
means = Float64[]
146150
maxs = Float64[]
147151
mins = Float64[]
148152
sum_squared_errors = zeros(Float64, ensemble_size)
149-
for config_i in config_indices
150-
data = ensemble_data(
153+
154+
for config_i in valid_ensemble_members
155+
data, zc_model = ensemble_data(
151156
process_profile_variable,
152157
iteration,
153158
config_i,
@@ -157,6 +162,7 @@ function main()
157162
output_dir = output_dir,
158163
z_max = z_max,
159164
n_vert_levels = n_vert_levels,
165+
return_z_interp = true,
160166
)
161167
for i in 1:size(data, 2)
162168
y_var = data[:, i]
@@ -166,25 +172,31 @@ function main()
166172
push!(mins, col_min)
167173
end
168174
if in(var_name, cal_vars)
175+
ref_path = ref_paths[config_i]
176+
cfsite_number, _, _, _ = parse_les_path(ref_path)
177+
forcing_type = get_cfsite_type(cfsite_number)
178+
179+
ti = config_dict["y_t_start_sec"]
180+
ti = isa(ti, Real) ? ti : ti[forcing_type]  # scalar (Int or Float) applies to all forcing types
181+
tf = config_dict["y_t_end_sec"]
182+
tf = isa(tf, Real) ? tf : tf[forcing_type]  # scalar (Int or Float) applies to all forcing types
183+
169184
y_true, Σ_obs, norm_vec_obs = get_obs(
170-
ref_paths[config_i],
185+
ref_path,
171186
[var_name],
172187
zc_model;
173-
ti = config_dict["y_t_start_sec"],
174-
tf = config_dict["y_t_end_sec"],
188+
ti = ti,
189+
tf = tf,
175190
Σ_const = const_noise_by_var,
176191
z_score_norm = false,
177192
)
178-
sum_squared_errors +=
179-
compute_ensemble_squared_error(data, y_true)
193+
sum_squared_errors += compute_ensemble_squared_error(data, y_true)
180194
end
181195
end
196+
182197
if in(var_name, cal_vars)
183-
# Compute RMSE per ensemble member
184198
rmse_per_member = sqrt.(sum_squared_errors / n_vert_levels)
185-
# Filter out NaNs (failed simulations)
186199
valid_rmse = rmse_per_member[.!isnan.(rmse_per_member)]
187-
non_nan_simulation_count = length(valid_rmse)
188200
mean_rmse = mean(valid_rmse)
189201
min_rmse = minimum(valid_rmse)
190202
max_rmse = maximum(valid_rmse)
@@ -226,9 +238,9 @@ function main()
226238
cal_vars,
227239
const_noise_by_var,
228240
ref_paths,
229-
zc_model,
230241
reduction,
231242
ensemble_size,
243+
batch_size,
232244
),
233245
iterations_list,
234246
)
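
The per-member RMSE reduction in `edmf_ensemble_stats.jl` can be sketched in isolation as follows. `squared_error_per_member` is a hypothetical stand-in for `compute_ensemble_squared_error`, whose implementation is not part of this diff, and returning `missing` when every member failed is an assumption suggested by the `Union{Missing, Float64}` columns of `stats_df`.

```julia
# Hypothetical stand-in for `compute_ensemble_squared_error`:
# `data` is (n_levels, n_members); returns one summed squared error per member.
squared_error_per_member(data, y_true) = vec(sum((data .- y_true) .^ 2, dims = 1))

# Summarize RMSE across ensemble members, skipping failed (NaN) simulations.
function rmse_summary(data, y_true)
    n_levels = size(data, 1)
    rmse = sqrt.(squared_error_per_member(data, y_true) ./ n_levels)
    valid = rmse[.!isnan.(rmse)]          # a failed member propagates NaN
    if isempty(valid)
        return (mean = missing, min = missing, max = missing, n_valid = 0)
    end
    return (
        mean = sum(valid) / length(valid),
        min = minimum(valid),
        max = maximum(valid),
        n_valid = length(valid),
    )
end

# Toy ensemble: 2 vertical levels, 3 members, the third one failed.
data = [1.0 2.0 NaN; 1.0 2.0 NaN]
y_true = [1.0, 1.0]
stats = rmse_summary(data, y_true)
```

Guarding the empty case avoids calling `minimum` on an empty vector, which would otherwise throw when no member in an iteration completes.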
Lines changed: 50 additions & 18 deletions
@@ -1,25 +1,57 @@
1-
prior_path: prior_prognostic_pi_entr.toml
2-
ensemble_size: 100
3-
n_iterations: 12
1+
2+
3+
prior_path: prior_prognostic_pi_entr_smooth_entr_detr_impl_0M_v2.toml
4+
5+
# ensemble_size: 130
6+
ensemble_size: 200
7+
n_iterations: 5
48
batch_size: 2 # number of cases per iteration
5-
model_config : model_config_prognostic.yml # options {model_config_prognostic.yml, model_config_diagnostic.yml}
6-
output_dir : output/exp_1 # output dir
7-
y_var_names: [thetaa, hus, clw] # calibration variables clw
9+
# batch_size: 5 # number of cases per iteration
10+
# model_config : model_config_prognostic.yml # options {model_config_prognostic.yml, model_config_diagnostic.yml}
11+
model_config : model_config_prognostic_impl.yml
12+
# output_dir : /groups/esm/cchristo/climaatmos_scm_calibrations/output_ml_mix/exp_43 # output dir
13+
14+
output_dir : /central/scratch/cchristo/edmf_impl_dev4/exp25
15+
y_var_names: [thetaa, hus, clw] # calibration variables
816
log_vars: ["clw"] # take log(var) when forming y, g
9-
z_max : 4000 # spatial subsetting: use statistics from [0, z_max] (in [m]) for calibration
10-
dims_per_var : 29 # num dimensions per variable (num cells in vertical profile below z_max)
17+
# log_vars: []
18+
19+
nice_loc_ug: 0.01
20+
nice_loc_gg: 0.5
21+
22+
z_max: null
23+
z_cal_grid: # calibration grid (stretch-grid parameters). In general, `z_elem` should be the same for all types
24+
shallow:
25+
z_max: 4000.0
26+
z_elem: 30
27+
dz_bottom: 30
28+
deep:
29+
z_max: 15000.0
30+
z_elem: 30
31+
dz_bottom: 30
32+
dims_per_var : 30 # num dimensions per variable (num cells in vertical profile below z_max)
1133
# eki_timestep: 0.1 # timestep of eki, if using default
12-
y_t_start_sec : 475200.0 # start time of LES averaging window [s] : 5.5 days
13-
y_t_end_sec : 518400.0 # end time of LES averaging window [s] : 6 days (LES length = 6 days)
14-
g_t_start_sec : 216000.0 # start time of SCM averaging window [s] : 2.5 days
15-
g_t_end_sec : 259200.0 # end time of SCM averaging window [s] : 3 days (SCM length = 3 days)
34+
35+
y_t_start_sec: # start time of LES averaging window [s]
36+
shallow: 475200.0 # 5.5 days
37+
deep: 302400.0 # 3.5 days
38+
y_t_end_sec: # end time of LES averaging window [s]
39+
shallow: 518400.0 # 6 days (LES length = 6 days)
40+
deep: 345600.0 # 4 days (LES length = 4 days)
41+
g_t_start_sec: 216000.0 # start time of SCM averaging window [s] : 2.5 days
42+
g_t_end_sec: 259200.0 # end time of SCM averaging window [s] : 3 days (SCM length = 3 days)
1643

1744
norm_factors_by_var:
18-
thetaa: [298.828, 8.617]
19-
hus: [0.00676, 0.00423]
20-
clw: [-9.808, 3.116] # log norm factors
45+
thetaa: [301.218, 15.235]
46+
hus: [0.00672, 0.00477]
47+
clw: [-9.579, 3.164] # log norm factors
48+
# cli: [-11.697, 1.304] # log norm factors
2149

2250
const_noise_by_var:
23-
thetaa: 0.00005
24-
hus: 0.00005
25-
clw: 0.00005
51+
thetaa: 0.0016
52+
hus: 0.0016
53+
clw: 0.0045
54+
# clw: 0.0016
55+
# cli: 0.01
56+
57+
pretrained_nn_path: "/home/cchristo/ml_mixing_length/nn_666p_leaky_relu.jld2"
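
To illustrate how the three `z_cal_grid` parameters (`z_max`, `z_elem`, `dz_bottom`) can pin down a stretched grid, here is a minimal sketch that uses simple geometric stretching. This stretching rule and the function name `geometric_grid` are assumptions for illustration only; the stretching function the calibration pipeline actually applies is not shown in this diff and may differ.

```julia
# Cell faces and centers for `z_elem` cells whose thickness grows geometrically
# from `dz_bottom` so that the top face lands on `z_max` (illustrative assumption).
function geometric_grid(; z_max, z_elem, dz_bottom)
    dz_bottom * z_elem <= z_max ||
        error("dz_bottom * z_elem exceeds z_max; no growth ratio >= 1 exists")
    total(r) = sum(dz_bottom * r^(k - 1) for k in 1:z_elem)
    lo, hi = 1.0, 2.0
    while total(hi) < z_max          # widen the bracket if needed
        hi *= 2
    end
    for _ in 1:100                   # bisection on the growth ratio
        mid = (lo + hi) / 2
        if total(mid) < z_max
            lo = mid
        else
            hi = mid
        end
    end
    r = (lo + hi) / 2
    dz = [dz_bottom * r^(k - 1) for k in 1:z_elem]
    faces = vcat(0.0, cumsum(dz))
    centers = (faces[1:(end - 1)] .+ faces[2:end]) ./ 2
    return (faces = faces, centers = centers, ratio = r)
end

# Values taken from the `shallow` and `deep` entries of the config above.
shallow = geometric_grid(z_max = 4000.0, z_elem = 30, dz_bottom = 30)
deep = geometric_grid(z_max = 15000.0, z_elem = 30, dz_bottom = 30)
```

Both grids have 30 cells, consistent with `dims_per_var: 30`, so profiles from shallow and deep cases yield observation vectors of equal length per variable.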

calibration/experiments/gcm_driven_scm/get_les_metadata.jl

Lines changed: 37 additions & 5 deletions
@@ -3,10 +3,17 @@ using Glob
33
"""
44
"""
55

6+
# cfSite numbers
7+
CFSITE_TYPES = Dict("shallow" => (collect(4:15)..., collect(17:23)...),
8+
"deep" => (collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100))
9+
610
function get_les_calibration_library()
711
les_library = get_shallow_LES_library()
8-
# AMIP4K data: July, NE Pacific
9-
cfsite_numbers = (17, 23)
12+
# AMIP data: July, NE Pacific
13+
# cfsite_numbers = (17, 18, 22, 23, 30, 94)
14+
# cfsite_numbers = (17, 22, 23, 30, 33, 94)
15+
# cfsite_numbers = (17, 21, 23, 30, 33,)# 94)
16+
cfsite_numbers = (30, 33,)# 94)
1017
les_kwargs = (forcing_model = "HadGEM2-A", month = 7, experiment = "amip")
1118
ref_paths = [
1219
get_stats_path(get_cfsite_les_dir(cfsite_number; les_kwargs...)) for
@@ -15,6 +22,20 @@ function get_les_calibration_library()
1522
return (ref_paths, cfsite_numbers)
1623
end
1724

25+
function get_cfsite_type(i, cfsite_numbers)
26+
return get_cfsite_type(cfsite_numbers[i])
27+
end
28+
29+
function get_cfsite_type(cfsite_number)
30+
if cfsite_number in CFSITE_TYPES["shallow"]
31+
return "shallow"
32+
elseif cfsite_number in CFSITE_TYPES["deep"]
33+
return "deep"
34+
else
35+
error("cfSite number $(cfsite_number) not found in available sites.")  # throw instead of logging and returning `nothing`
36+
end
37+
end
38+
1839
"""
1940
get_LES_library
2041
@@ -25,7 +46,18 @@ and experiments.
2546
"""
2647
function get_LES_library()
2748
LES_library = get_shallow_LES_library()
28-
deep_sites = (collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100)
49+
deep_sites = deepcopy(CFSITE_TYPES["deep"])
50+
51+
52+
# remove <0 ql/cli cases
53+
# sites_07 = deepcopy(setdiff(deep_sites, [92, 99, 100]))
54+
# append!(LES_library["HadGEM2-A"]["07"]["cfsite_numbers"], sites_07)
55+
# sites_01 = deepcopy(setdiff(deep_sites, [99,]))
56+
# append!(LES_library["HadGEM2-A"]["01"]["cfsite_numbers"], sites_01)
57+
# sites_04 = deepcopy(setdiff(deep_sites, [32, 92, 94, 96, 99, 100]))
58+
# append!(LES_library["HadGEM2-A"]["04"]["cfsite_numbers"], sites_04)
59+
# sites_10 = deepcopy(setdiff(deep_sites, [92, 94, 99, 100]))
60+
# append!(LES_library["HadGEM2-A"]["10"]["cfsite_numbers"], sites_10)
2961

3062
append!(LES_library["HadGEM2-A"]["07"]["cfsite_numbers"], deep_sites)
3163
append!(LES_library["HadGEM2-A"]["01"]["cfsite_numbers"], deep_sites)
@@ -34,6 +66,7 @@ function get_LES_library()
3466
sites_10 = deepcopy(setdiff(deep_sites, [94, 100]))
3567
append!(LES_library["HadGEM2-A"]["10"]["cfsite_numbers"], sites_10)
3668

69+
3770
LES_library_full = deepcopy(LES_library)
3871
for model in keys(LES_library_full)
3972
for month in keys(LES_library_full[model])
@@ -103,8 +136,7 @@ function get_shallow_LES_library()
103136
"CNRM-CM5" => Dict(),
104137
"CNRM-CM6-1" => Dict(),
105138
)
106-
Shen_et_al_sites = collect(4:15)
107-
append!(Shen_et_al_sites, collect(17:23))
139+
Shen_et_al_sites = collect(deepcopy(CFSITE_TYPES["shallow"]))
108140

109141
# HadGEM2-A model (76 AMIP-AMIP4K pairs)
110142
LES_library["HadGEM2-A"]["10"] = Dict()
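
The way a forcing type selects case-specific settings (for example `y_t_start_sec`, which may be a scalar or a `shallow`/`deep` mapping) can be sketched with multiple dispatch. `SITE_TYPES`, `site_type`, and `resolve_by_type` are hypothetical names, and the site table is trimmed to a few sites; this is an illustration of the pattern, not the code in this commit.

```julia
# Hypothetical, trimmed-down site table (the commit's `CFSITE_TYPES` lists more sites).
SITE_TYPES = Dict("shallow" => (17, 23), "deep" => (30, 33))

# Map a cfSite number to its forcing type, throwing for unknown sites.
function site_type(cfsite_number)
    for (type, sites) in SITE_TYPES
        cfsite_number in sites && return type
    end
    throw(ArgumentError("cfSite number $cfsite_number not found in available sites"))
end

# A config entry is either a scalar (applies to every type) or a per-type mapping.
resolve_by_type(value::Real, forcing_type) = value
resolve_by_type(value::AbstractDict, forcing_type) = value[forcing_type]

# Values taken from the experiment config in this commit.
y_t_start_sec = Dict("shallow" => 475200.0, "deep" => 302400.0)
g_t_start_sec = 216000.0

ti_deep = resolve_by_type(y_t_start_sec, site_type(30))
ti_shallow = resolve_by_type(y_t_start_sec, site_type(17))
gi = resolve_by_type(g_t_start_sec, site_type(17))
```

Dispatching on `Real` rather than `AbstractFloat` means an integer-valued YAML entry such as `216000` is also treated as a scalar.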
