Skip to content

Commit ecf5302

Browse files
committed
Add external forcing types to gcm-driven scm calibration, allowing for both shallow and deep convection. Includes option for defining stretched calibration grid.
Add cal grid to edmf_ensemble_stats gcm driven calibration updates: restart, plotting support for variable grids, add plot prior and normalize var scripts, add microphysics cal option Add NN mixing length closure, add load pretrain weights logic in prior Add clippings for NN inputs, dev calibration Add scm_runner tools to run SCMs (in parallel) across cases for a given parameter set scm runner updates limit nn mixing length by smag length + (1/z), and add new dz input variable. Remove problematic deep convective cases from library. Update minibatcher to specify cases for each epoch. Add serialize_std_model with default in run_cal to allow different priors for NN weights and biases. Increase request limits for top-level slurm sbatch script, preventing minor memory leak from killing calibration. Increase noise by an order of mag. Update, improve restart calibration + allow to work with NNs. Limit number of processes per ensemble member by 5 to avoid queue waits . Add back deep deep conv cases, clip ql, qi above 0 in les cases. Fix get_optimal_particle nearest neighbor mean getter. Simplify nn prior creation at beginning of run_calibration. Request less time on nodes for scm runs. Limit num_cpu to 5 in restart script. Increase noise, use batch size = 20. And cfsite info to plot_ensemble plots. Add radiation metrics to runner model config. Add nn_helper functions. Add precal prior toml. Increase noise and batch size (30). Increase prior sigma for NN (0.05 for weights). Lower prior std for mixing_length_diss_coeff. Turn off accelerator. Make shallow library default in get_les_metadata. Make leaky relu NN default. Change default noise with leaky relu (std_weight = 0.1, std_bias = 0.00001). Increase default runtime for restart for ens members to 180 mins. Increase resource requests for scm_runner. Make leaky relu NN default (and write to scratch). Pass args to scm_runner in sbatch script. Add CLI option for scm_runner, more flexible parsing of cfsite (for different AMIP experiments). Switch between NN and standard priors in restart calibration script Add batch size logic in edmf_ensemble_stats Add forcing parameters as additional tomls
1 parent 1981794 commit ecf5302

33 files changed

+2505
-253
lines changed

calibration/experiments/gcm_driven_scm/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
1111
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1212
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1313
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
14+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1415
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
1516
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
1617
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
@@ -19,5 +20,5 @@ NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1920
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
2021

2122
[compat]
22-
ClimaCalibrate = "=0.0.3"
23-
EnsembleKalmanProcesses = "2"
23+
EnsembleKalmanProcesses = "=2.1.2"
24+
ClimaCalibrate = "=0.0.13"

calibration/experiments/gcm_driven_scm/README.md

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
11
# Overview of Calibration Pipeline for GCM-Driven Single Column Model (EDMF)
22

3-
This setup provide tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibrations statistics are computed only on the lower 4km (`z_max`), where LES output is available.
3+
## Pipeline Components
4+
5+
### Configuration Files
6+
- `experiment_config.yml` - Configuration of calibration settings
7+
- Defines spatiotemporal calibration window
8+
- Specifies required pipeline file paths
9+
- Controls batch processing parameters
10+
11+
- `model_config_**.yml` - Configuration for ClimaAtmos single column model
12+
- Defines model-specific parameters
13+
- Allows for multiple model configuration variants
14+
15+
## Best Practices
16+
- Ensure `batch_size` matches available LES configurations
17+
- Verify normalization factors for each variable
18+
- Monitor ensemble convergence using provided plotting tools
19+
20+
This setup provide tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibration statistics are only computed on the calibration grid `z_cal_grid`.
421

522
LES profiles are available for different geolocations ("cfsites"), spanning seasons, forcing host models, and climates (AMIP, AMIP4K). A given LES simulation is referred to as a "configuration". Calibrations employ batching by default and stack multiple configurations (a number equal to the `batch_size`) in a given iteration. The observation vector for a single configuration is formed by concatenating profiles across calibration variables, where each geophysical variable is normalized to have approximately unit variance and zero mean. These variable-by-variable normalization factors are precomputed (`norm_factors_dict`) and applied to all observations. Following this operation, the spatiotemporal calibration window is applied and temporal means are computed to form the observation vector `y`. Because variables are normalized to have 0 mean and unit variance, a constant diagonal noise matrix is used (configurable as `const_noise`).
623

724

25+
### Observation Map
26+
1. **Time-mean**: Time-mean of profiles taken between [`y_t_start_sec`, `y_t_end_sec`] for `y` and [`g_t_start_sec`, `g_t_end_sec`] for `G`.
27+
2. **Interpolation**: Case-specific (i.e., "shallow", "deep") interpolation to the calibration grid, defined with stretch-grid parameters in `z_cal_grid`.
28+
3. **Normalization**: Variable-specific normalization using the mean and standard deviation defined in `norm_factors_by_var`. Optionally, take log of variables using `log_vars` before normalization.
29+
4. **Concatenation**: Stack across cases in a batch, forming `y`, `G`.
30+
831
## Getting Started
932

1033
### Define calibration and model configurations:
@@ -20,6 +43,11 @@ LES profiles are available for different geolocations ("cfsites"), spanning seas
2043
### Analyze output with:
2144
- `julia --project plot_ensemble.jl` - plots vertical profiles of all ensemble members in a given iteration, given path to calibration output
2245
- `julia --project edmf_ensemble_stats.jl` - computes and plots metrics offline [i.e., root mean squared error (RMSE)] as a function of iteration, given path to calibration output.
23-
- `julia --project plot_eki.jl` - plot eki metrics [loss, var-weighted loss] and `y`, `g` vectors vs iteration, display best particles
46+
- `julia --project plot_eki.jl` - plot eki metrics [loss, variance-weighted loss] and `y`, `G` vectors vs iteration, display best particles
2447

48+
## Troubleshooting
2549

50+
- **Memory Issues**: If you encounter out of memory errors, increase the memory allocation in both `run_calibration.sbatch` and the `experiment_config.yml` file. This is particularly important when working with larger batch sizes. Example error message:
51+
```
52+
srun: error: hpc-92-10: task 9: Out Of Memory
53+
```

calibration/experiments/gcm_driven_scm/edmf_ensemble_stats.jl

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
#!/usr/bin/env julia
22

3-
import ClimaComms
4-
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
5-
63
using ArgParse
74
using Distributed
8-
addprocs()
9-
5+
addprocs(1)
106

117
@everywhere begin
128
using EnsembleKalmanProcesses: TOMLInterface
@@ -60,6 +56,13 @@ function parse_args()
6056
return parse_with_settings(s)
6157
end
6258

59+
@everywhere function validate_ensemble_member(iteration_dir, batch_size)
60+
config_dirs =
61+
filter(x -> isdir(joinpath(iteration_dir, x)), readdir(iteration_dir))
62+
num_configs = count(x -> startswith(x, "config_"), config_dirs)
63+
return num_configs == batch_size
64+
end
65+
6366
function main()
6467
args = parse_args()
6568

@@ -87,6 +90,7 @@ function main()
8790
cal_vars = config_dict["y_var_names"]
8891
const_noise_by_var = config_dict["const_noise_by_var"]
8992
n_iterations = config_dict["n_iterations"]
93+
batch_size = config_dict["batch_size"]
9094
model_config_dict =
9195
YAML.load_file(joinpath(output_dir, "configs", "model_config.yml"))
9296

@@ -95,9 +99,6 @@ function main()
9599
end
96100

97101
ref_paths, _ = get_les_calibration_library()
98-
comms_ctx = ClimaComms.SingletonCommsContext()
99-
atmos_config = CA.AtmosConfig(model_config_dict; comms_ctx)
100-
zc_model = get_z_grid(atmos_config, z_max = z_max)
101102

102103
@everywhere function calculate_statistics(y_var)
103104
non_nan_values = y_var[.!isnan.(y_var)]
@@ -124,9 +125,9 @@ function main()
124125
cal_vars,
125126
const_noise_by_var,
126127
ref_paths,
127-
zc_model,
128128
reduction,
129129
ensemble_size,
130+
batch_size,
130131
)
131132
println("Processing Iteration: $iteration")
132133
stats_df = DataFrame(
@@ -141,13 +142,25 @@ function main()
141142
rmse_std = Union{Missing, Float64}[],
142143
)
143144
config_indices = get_batch_indicies_in_iteration(iteration, output_dir)
145+
iteration_dir =
146+
joinpath(output_dir, "iteration_$(lpad(iteration, 3, '0'))")
147+
148+
valid_ensemble_members = filter(
149+
config_i -> validate_ensemble_member(
150+
joinpath(iteration_dir, "member_$(lpad(config_i, 3, '0'))"),
151+
batch_size,
152+
),
153+
config_indices,
154+
)
155+
144156
for var_name in var_names
145157
means = Float64[]
146158
maxs = Float64[]
147159
mins = Float64[]
148160
sum_squared_errors = zeros(Float64, ensemble_size)
149-
for config_i in config_indices
150-
data = ensemble_data(
161+
162+
for config_i in valid_ensemble_members
163+
data, zc_model = ensemble_data(
151164
process_profile_variable,
152165
iteration,
153166
config_i,
@@ -157,6 +170,7 @@ function main()
157170
output_dir = output_dir,
158171
z_max = z_max,
159172
n_vert_levels = n_vert_levels,
173+
return_z_interp = true,
160174
)
161175
for i in 1:size(data, 2)
162176
y_var = data[:, i]
@@ -166,25 +180,32 @@ function main()
166180
push!(mins, col_min)
167181
end
168182
if in(var_name, cal_vars)
183+
ref_path = ref_paths[config_i]
184+
cfsite_number, _, _, _ = parse_les_path(ref_path)
185+
forcing_type = get_cfsite_type(cfsite_number)
186+
187+
ti = config_dict["y_t_start_sec"]
188+
ti = isa(ti, AbstractFloat) ? ti : ti[forcing_type]
189+
tf = config_dict["y_t_end_sec"]
190+
tf = isa(tf, AbstractFloat) ? tf : tf[forcing_type]
191+
169192
y_true, Σ_obs, norm_vec_obs = get_obs(
170-
ref_paths[config_i],
193+
ref_path,
171194
[var_name],
172195
zc_model;
173-
ti = config_dict["y_t_start_sec"],
174-
tf = config_dict["y_t_end_sec"],
196+
ti = ti,
197+
tf = tf,
175198
Σ_const = const_noise_by_var,
176199
z_score_norm = false,
177200
)
178201
sum_squared_errors +=
179202
compute_ensemble_squared_error(data, y_true)
180203
end
181204
end
205+
182206
if in(var_name, cal_vars)
183-
# Compute RMSE per ensemble member
184207
rmse_per_member = sqrt.(sum_squared_errors / n_vert_levels)
185-
# Filter out NaNs (failed simulations)
186208
valid_rmse = rmse_per_member[.!isnan.(rmse_per_member)]
187-
non_nan_simulation_count = length(valid_rmse)
188209
mean_rmse = mean(valid_rmse)
189210
min_rmse = minimum(valid_rmse)
190211
max_rmse = maximum(valid_rmse)
@@ -226,9 +247,9 @@ function main()
226247
cal_vars,
227248
const_noise_by_var,
228249
ref_paths,
229-
zc_model,
230250
reduction,
231251
ensemble_size,
252+
batch_size,
232253
),
233254
iterations_list,
234255
)
Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,61 @@
1-
prior_path: prior_prognostic_pi_entr.toml
1+
prior_path: prior_prognostic_pi_entr_smooth_entr_detr_impl_0M_v1.toml
22
ensemble_size: 100
3-
n_iterations: 12
4-
batch_size: 2 # number of cases per iteration
5-
model_config : model_config_prognostic.yml # options {model_config_prognostic.yml, model_config_diagnostic.yml}
6-
output_dir : output/exp_1 # output dir
7-
y_var_names: [thetaa, hus, clw] # calibration variables clw
3+
n_iterations: 8
4+
batch_size: 5 # number of cases per iteration
5+
# model_config : model_config_prognostic.yml # options {model_config_prognostic.yml, model_config_diagnostic.yml}
6+
model_config : model_config_prognostic_impl.yml
7+
output_dir : /central/scratch/cchristo/debug/exp16
8+
9+
# Slurm resource configuration
10+
slurm_time: "02:00:00"
11+
slurm_mem_per_cpu: "25G"
12+
slurm_cpus_per_task: 1
13+
14+
y_var_names: [thetaa, hus, clw] # calibration variables clw clw]
815
log_vars: ["clw"] # take log(var) when forming y, g
9-
z_max : 4000 # spatial subsetting: use statistics from [0, z_max] (in [m]) for calibration
10-
dims_per_var : 29 # num dimensions per variable (num cells in vertical profile below z_max)
16+
# log_vars: []
17+
18+
nice_loc_ug: 0.01
19+
nice_loc_gg: 0.5
20+
21+
z_max: null
22+
z_cal_grid: # calibration grid (stretch-grid parameters). In general, `z_elem` should be the same for all types
23+
shallow:
24+
z_max: 4000.0
25+
z_elem: 30
26+
dz_bottom: 30
27+
deep:
28+
z_max: 15000.0
29+
z_elem: 30
30+
dz_bottom: 30
31+
dims_per_var : 30 # num dimensions per variable (num cells in vertical profile below z_max)
1132
# eki_timestep: 0.1 # timestep of eki, if using default
12-
y_t_start_sec : 475200.0 # start time of LES averaging window [s] : 5.5 days
13-
y_t_end_sec : 518400.0 # end time of LES averaging window [s] : 6 days (LES length = 6 days)
14-
g_t_start_sec : 216000.0 # start time of SCM averaging window [s] : 2.5 days
15-
g_t_end_sec : 259200.0 # end time of SCM averaging window [s] : 3 days (SCM length = 3 days)
33+
34+
y_t_start_sec: # start time of LES averaging window [s]
35+
shallow: 475200.0 # 5.5 days
36+
deep: 302400.0 # 3.5 days
37+
y_t_end_sec: # end time of LES averaging window [s]
38+
shallow: 518400.0 # 6 days (LES length = 6 days)
39+
deep: 345600.0 # 4 days (LES length = 4 days)
40+
g_t_start_sec: 216000.0 # start time of SCM averaging window [s] : 2.5 days
41+
g_t_end_sec: 259200.0 # end time of SCM averaging window [s] : 3 days (SCM length = 3 days)
1642

1743
norm_factors_by_var:
18-
thetaa: [298.828, 8.617]
19-
hus: [0.00676, 0.00423]
20-
clw: [-9.808, 3.116] # log norm factors
44+
thetaa: [301.218, 15.235]
45+
hus: [0.00672, 0.00477]
46+
clw: [-9.579, 3.164] # log norm factors
47+
# cli: [-11.697, 1.304] # log norm factors
2148

2249
const_noise_by_var:
23-
thetaa: 0.00005
24-
hus: 0.00005
25-
clw: 0.00005
50+
thetaa: 0.0016
51+
hus: 0.0016
52+
clw: 0.0045
53+
# clw: 0.0016
54+
# cli: 0.01
55+
56+
pretrained_nn_path: "/home/cchristo/ml_mixing_length/nn_666p_leaky_relu.jld2"
57+
58+
# Config files for deep and shallow cases
59+
forcing_toml_files:
60+
shallow: "scm_tomls/gcmdriven_relaxation_shallow_forcing.toml"
61+
deep: "scm_tomls/gcmdriven_relaxation_deep_forcing.toml"

calibration/experiments/gcm_driven_scm/get_les_metadata.jl

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,22 @@ using Glob
33
"""
44
"""
55

6+
# cfSite numbers
7+
CFSITE_TYPES = Dict(
8+
"shallow" => (collect(4:15)..., collect(17:23)...),
9+
"deep" =>
10+
(collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100),
11+
)
12+
613
function get_les_calibration_library()
714
les_library = get_shallow_LES_library()
8-
# AMIP4K data: July, NE Pacific
9-
cfsite_numbers = (17, 23)
15+
# AMIP data: July, NE Pacific
16+
# cfsite_numbers = (17, 18, 22, 23, 30, 94)
17+
# cfsite_numbers = (17, 22, 23, 30, 33, 94)
18+
cfsite_numbers = (17, 21, 23, 30, 33)# 94)
19+
# cfsite_numbers = (30, 33,)# 94)
20+
21+
# cfsite_numbers = (17, 30,)# 94)
1022
les_kwargs = (forcing_model = "HadGEM2-A", month = 7, experiment = "amip")
1123
ref_paths = [
1224
get_stats_path(get_cfsite_les_dir(cfsite_number; les_kwargs...)) for
@@ -15,6 +27,20 @@ function get_les_calibration_library()
1527
return (ref_paths, cfsite_numbers)
1628
end
1729

30+
function get_cfsite_type(i, cfsite_numbers)
31+
return get_cfsite_type(cfsite_numbers[i])
32+
end
33+
34+
function get_cfsite_type(cfsite_number::Int)
35+
if cfsite_number in CFSITE_TYPES["shallow"]
36+
return "shallow"
37+
elseif cfsite_number in CFSITE_TYPES["deep"]
38+
return "deep"
39+
else
40+
@error "cfSite number $(cfsite_number) not found in available sites."
41+
end
42+
end
43+
1844
"""
1945
get_LES_library
2046
@@ -25,7 +51,18 @@ and experiments.
2551
"""
2652
function get_LES_library()
2753
LES_library = get_shallow_LES_library()
28-
deep_sites = (collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100)
54+
deep_sites = deepcopy(CFSITE_TYPES["deep"])
55+
56+
57+
# remove <0 ql/cli cases
58+
# sites_07 = deepcopy(setdiff(deep_sites, [92, 99, 100]))
59+
# append!(LES_library["HadGEM2-A"]["07"]["cfsite_numbers"], sites_07)
60+
# sites_01 = deepcopy(setdiff(deep_sites, [99,]))
61+
# append!(LES_library["HadGEM2-A"]["01"]["cfsite_numbers"], sites_01)
62+
# sites_04 = deepcopy(setdiff(deep_sites, [32, 92, 94, 96, 99, 100]))
63+
# append!(LES_library["HadGEM2-A"]["04"]["cfsite_numbers"], sites_04)
64+
# sites_10 = deepcopy(setdiff(deep_sites, [92, 94, 99, 100]))
65+
# append!(LES_library["HadGEM2-A"]["10"]["cfsite_numbers"], sites_10)
2966

3067
append!(LES_library["HadGEM2-A"]["07"]["cfsite_numbers"], deep_sites)
3168
append!(LES_library["HadGEM2-A"]["01"]["cfsite_numbers"], deep_sites)
@@ -34,6 +71,7 @@ function get_LES_library()
3471
sites_10 = deepcopy(setdiff(deep_sites, [94, 100]))
3572
append!(LES_library["HadGEM2-A"]["10"]["cfsite_numbers"], sites_10)
3673

74+
3775
LES_library_full = deepcopy(LES_library)
3876
for model in keys(LES_library_full)
3977
for month in keys(LES_library_full[model])
@@ -103,8 +141,7 @@ function get_shallow_LES_library()
103141
"CNRM-CM5" => Dict(),
104142
"CNRM-CM6-1" => Dict(),
105143
)
106-
Shen_et_al_sites = collect(4:15)
107-
append!(Shen_et_al_sites, collect(17:23))
144+
Shen_et_al_sites = collect(deepcopy(CFSITE_TYPES["shallow"]))
108145

109146
# HadGEM2-A model (76 AMIP-AMIP4K pairs)
110147
LES_library["HadGEM2-A"]["10"] = Dict()

0 commit comments

Comments
 (0)