GCM-Driven SCM Calibration Pipeline: v1 #3765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 1 commit into base: main
5 changes: 3 additions & 2 deletions calibration/experiments/gcm_driven_scm/Project.toml
@@ -11,6 +11,7 @@ ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
@@ -19,5 +20,5 @@ NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
ClimaCalibrate = "=0.0.3"
EnsembleKalmanProcesses = "2"
EnsembleKalmanProcesses = "=2.1.2"
ClimaCalibrate = "=0.0.13"
32 changes: 30 additions & 2 deletions calibration/experiments/gcm_driven_scm/README.md
@@ -1,10 +1,33 @@
# Overview of Calibration Pipeline for GCM-Driven Single Column Model (EDMF)

This setup provides tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibration statistics are computed only on the lower 4km (`z_max`), where LES output is available.
## Pipeline Components

### Configuration Files
- `experiment_config.yml` - Configuration of calibration settings
- Defines spatiotemporal calibration window
- Specifies required pipeline file paths
- Controls batch processing parameters

- `model_config_**.yml` - Configuration for ClimaAtmos single column model
- Defines model-specific parameters
- Allows for multiple model configuration variants

## Best Practices
- Ensure `batch_size` matches available LES configurations
- Verify normalization factors for each variable
- Monitor ensemble convergence using provided plotting tools

This setup provides tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibration statistics are computed only on the calibration grid `z_cal_grid`.

LES profiles are available for different geolocations ("cfsites"), spanning seasons, forcing host models, and climates (AMIP, AMIP4K). A given LES simulation is referred to as a "configuration". Calibrations employ batching by default and stack multiple configurations (a number equal to the `batch_size`) in a given iteration. The observation vector for a single configuration is formed by concatenating profiles across calibration variables, where each geophysical variable is normalized to have approximately unit variance and zero mean. These variable-by-variable normalization factors are precomputed (`norm_factors_dict`) and applied to all observations. Following this operation, the spatiotemporal calibration window is applied and temporal means are computed to form the observation vector `y`. Because variables are normalized to have 0 mean and unit variance, a constant diagonal noise matrix is used (configurable as `const_noise`).
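The normalization and concatenation described above can be sketched as follows. This is a minimal illustration: the dictionary name mirrors `norm_factors_dict` from the text, but the factors, variable names, and profile data here are synthetic, not the pipeline's actual values.

```julia
# Per-variable normalization factors: variable name => (mean, std).
# Values are illustrative, not the precomputed pipeline factors.
norm_factors_dict = Dict("thetaa" => (301.2, 15.2), "hus" => (0.0067, 0.0048))

# Normalize a profile to approximately zero mean and unit variance.
normalize(profile, v) = (profile .- norm_factors_dict[v][1]) ./ norm_factors_dict[v][2]

# Time-mean profiles for one configuration (synthetic data, 30 levels each).
profiles = Dict("thetaa" => fill(300.0, 30), "hus" => fill(0.006, 30))

# Observation vector for this configuration: concatenate the normalized
# profiles across calibration variables.
y = vcat((normalize(profiles[v], v) for v in ["thetaa", "hus"])...)
```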


### Observation Map
1. **Time-mean**: Time-mean of profiles taken between [`y_t_start_sec`, `y_t_end_sec`] for `y` and [`g_t_start_sec`, `g_t_end_sec`] for `G`.
2. **Interpolation**: Case-specific (i.e., "shallow", "deep") interpolation to the calibration grid, defined with stretch-grid parameters in `z_cal_grid`.
3. **Normalization**: Variable-specific normalization using the mean and standard deviation defined in `norm_factors_by_var`. Optionally, take log of variables using `log_vars` before normalization.
4. **Concatenation**: Stack across cases in a batch, forming `y`, `G`.
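The four steps above can be sketched for a single variable as below. This is a hedged sketch: `log_vars` and `norm_factors_by_var` mirror the config keys, the interpolation step is a placeholder index selection rather than the case-specific stretch-grid interpolation, and the data are synthetic.

```julia
using Statistics

log_vars = ["clw"]                          # variables logged before normalization
norm_factors_by_var = Dict("clw" => (-9.6, 3.2))  # illustrative (mean, std)

# Apply the observation-map steps to one variable's (levels x time) series.
function obs_map_var(timeseries::Matrix, v; t_idx = :, z_cal = 1:size(timeseries, 1))
    prof = vec(mean(timeseries[:, t_idx]; dims = 2))  # 1. time-mean over the window
    prof = prof[z_cal]                                # 2. placeholder for grid interpolation
    v in log_vars && (prof = log.(prof))              # 3a. optional log transform
    μ, σ = norm_factors_by_var[v]
    return (prof .- μ) ./ σ                           # 3b. normalize
end

# 4. Concatenate across variables/cases (a single variable here).
y = vcat(obs_map_var(fill(1e-4, 30, 10), "clw"))
```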

## Getting Started

### Define calibration and model configurations:
@@ -20,6 +43,11 @@ LES profiles are available for different geolocations ("cfsites"), spanning seas
### Analyze output with:
- `julia --project plot_ensemble.jl` - plots vertical profiles of all ensemble members in a given iteration, given path to calibration output
- `julia --project edmf_ensemble_stats.jl` - computes and plots metrics offline [e.g., root-mean-squared error (RMSE)] as a function of iteration, given path to calibration output.
- `julia --project plot_eki.jl` - plot eki metrics [loss, var-weighted loss] and `y`, `g` vectors vs iteration, display best particles
- `julia --project plot_eki.jl` - plot eki metrics [loss, variance-weighted loss] and `y`, `G` vectors vs iteration, display best particles

## Troubleshooting

- **Memory Issues**: If you encounter out-of-memory errors, increase the memory allocation in both `run_calibration.sbatch` and `experiment_config.yml`. This is particularly important when working with larger batch sizes. Example error message:
```
srun: error: hpc-92-10: task 9: Out Of Memory
```
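As a sketch, the memory request lives in two places that must agree. The directive below uses standard Slurm syntax; the `25G` value is illustrative, not a recommendation:

```shell
# In run_calibration.sbatch (standard Slurm directive; value illustrative):
#SBATCH --mem-per-cpu=25G

# Keep experiment_config.yml consistent:
#   slurm_mem_per_cpu: "25G"
```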
57 changes: 39 additions & 18 deletions calibration/experiments/gcm_driven_scm/edmf_ensemble_stats.jl
@@ -1,12 +1,8 @@
#!/usr/bin/env julia

import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

using ArgParse
using Distributed
addprocs()

addprocs(1)

@everywhere begin
using EnsembleKalmanProcesses: TOMLInterface
@@ -60,6 +56,13 @@ function parse_args()
return parse_with_settings(s)
end

@everywhere function validate_ensemble_member(iteration_dir, batch_size)
config_dirs =
filter(x -> isdir(joinpath(iteration_dir, x)), readdir(iteration_dir))
num_configs = count(x -> startswith(x, "config_"), config_dirs)
return num_configs == batch_size
end

function main()
args = parse_args()

Expand Down Expand Up @@ -87,6 +90,7 @@ function main()
cal_vars = config_dict["y_var_names"]
const_noise_by_var = config_dict["const_noise_by_var"]
n_iterations = config_dict["n_iterations"]
batch_size = config_dict["batch_size"]
model_config_dict =
YAML.load_file(joinpath(output_dir, "configs", "model_config.yml"))

@@ -95,9 +99,6 @@
end

ref_paths, _ = get_les_calibration_library()
comms_ctx = ClimaComms.SingletonCommsContext()
atmos_config = CA.AtmosConfig(model_config_dict; comms_ctx)
zc_model = get_z_grid(atmos_config, z_max = z_max)

@everywhere function calculate_statistics(y_var)
non_nan_values = y_var[.!isnan.(y_var)]
@@ -124,9 +125,9 @@
cal_vars,
const_noise_by_var,
ref_paths,
zc_model,
reduction,
ensemble_size,
batch_size,
)
println("Processing Iteration: $iteration")
stats_df = DataFrame(
@@ -141,13 +142,25 @@
rmse_std = Union{Missing, Float64}[],
)
config_indices = get_batch_indicies_in_iteration(iteration, output_dir)
iteration_dir =
joinpath(output_dir, "iteration_$(lpad(iteration, 3, '0'))")

valid_ensemble_members = filter(
config_i -> validate_ensemble_member(
joinpath(iteration_dir, "member_$(lpad(config_i, 3, '0'))"),
batch_size,
),
config_indices,
)

for var_name in var_names
means = Float64[]
maxs = Float64[]
mins = Float64[]
sum_squared_errors = zeros(Float64, ensemble_size)
for config_i in config_indices
data = ensemble_data(

for config_i in valid_ensemble_members
data, zc_model = ensemble_data(
process_profile_variable,
iteration,
config_i,
@@ -157,6 +170,7 @@
output_dir = output_dir,
z_max = z_max,
n_vert_levels = n_vert_levels,
return_z_interp = true,
)
for i in 1:size(data, 2)
y_var = data[:, i]
@@ -166,25 +180,32 @@
push!(mins, col_min)
end
if in(var_name, cal_vars)
ref_path = ref_paths[config_i]
cfsite_number, _, _, _ = parse_les_path(ref_path)
forcing_type = get_cfsite_type(cfsite_number)

ti = config_dict["y_t_start_sec"]
ti = isa(ti, AbstractFloat) ? ti : ti[forcing_type]
tf = config_dict["y_t_end_sec"]
tf = isa(tf, AbstractFloat) ? tf : tf[forcing_type]

y_true, Σ_obs, norm_vec_obs = get_obs(
ref_paths[config_i],
ref_path,
[var_name],
zc_model;
ti = config_dict["y_t_start_sec"],
tf = config_dict["y_t_end_sec"],
ti = ti,
tf = tf,
Σ_const = const_noise_by_var,
z_score_norm = false,
)
sum_squared_errors +=
compute_ensemble_squared_error(data, y_true)
end
end

if in(var_name, cal_vars)
# Compute RMSE per ensemble member
rmse_per_member = sqrt.(sum_squared_errors / n_vert_levels)
# Filter out NaNs (failed simulations)
valid_rmse = rmse_per_member[.!isnan.(rmse_per_member)]
non_nan_simulation_count = length(valid_rmse)
mean_rmse = mean(valid_rmse)
min_rmse = minimum(valid_rmse)
max_rmse = maximum(valid_rmse)
@@ -226,9 +247,9 @@
cal_vars,
const_noise_by_var,
ref_paths,
zc_model,
reduction,
ensemble_size,
batch_size,
),
iterations_list,
)
72 changes: 54 additions & 18 deletions calibration/experiments/gcm_driven_scm/experiment_config.yml
@@ -1,25 +1,61 @@
prior_path: prior_prognostic_pi_entr.toml
prior_path: prior_prognostic_pi_entr_smooth_entr_detr_impl_0M_v1.toml
ensemble_size: 100
n_iterations: 12
batch_size: 2 # number of cases per iteration
model_config : model_config_prognostic.yml # options {model_config_prognostic.yml, model_config_diagnostic.yml}
output_dir : output/exp_1 # output dir
y_var_names: [thetaa, hus, clw] # calibration variables clw
n_iterations: 8
batch_size: 5 # number of cases per iteration
# model_config : model_config_prognostic.yml # options {model_config_prognostic.yml, model_config_diagnostic.yml}
model_config : model_config_prognostic_impl.yml
output_dir : /central/scratch/cchristo/debug/exp16

# Slurm resource configuration
slurm_time: "02:00:00"
slurm_mem_per_cpu: "25G"
slurm_cpus_per_task: 1

y_var_names: [thetaa, hus, clw] # calibration variables clw clw]
log_vars: ["clw"] # take log(var) when forming y, g
z_max : 4000 # spatial subsetting: use statistics from [0, z_max] (in [m]) for calibration
dims_per_var : 29 # num dimensions per variable (num cells in vertical profile below z_max)
# log_vars: []

nice_loc_ug: 0.01
nice_loc_gg: 0.5

z_max: null
z_cal_grid: # calibration grid (stretch-grid parameters). In general, `z_elem` should be the same for all types
shallow:
z_max: 4000.0
z_elem: 30
dz_bottom: 30
deep:
z_max: 15000.0
z_elem: 30
dz_bottom: 30
dims_per_var : 30 # num dimensions per variable (num cells in vertical profile below z_max)
# eki_timestep: 0.1 # timestep of eki, if using default
y_t_start_sec : 475200.0 # start time of LES averaging window [s] : 5.5 days
y_t_end_sec : 518400.0 # end time of LES averaging window [s] : 6 days (LES length = 6 days)
g_t_start_sec : 216000.0 # start time of SCM averaging window [s] : 2.5 days
g_t_end_sec : 259200.0 # end time of SCM averaging window [s] : 3 days (SCM length = 3 days)

y_t_start_sec: # start time of LES averaging window [s]
shallow: 475200.0 # 5.5 days
deep: 302400.0 # 3.5 days
y_t_end_sec: # end time of LES averaging window [s]
shallow: 518400.0 # 6 days (LES length = 6 days)
deep: 345600.0 # 4 days (LES length = 4 days)
g_t_start_sec: 216000.0 # start time of SCM averaging window [s] : 2.5 days
g_t_end_sec: 259200.0 # end time of SCM averaging window [s] : 3 days (SCM length = 3 days)

norm_factors_by_var:
thetaa: [298.828, 8.617]
hus: [0.00676, 0.00423]
clw: [-9.808, 3.116] # log norm factors
thetaa: [301.218, 15.235]
hus: [0.00672, 0.00477]
clw: [-9.579, 3.164] # log norm factors
# cli: [-11.697, 1.304] # log norm factors

const_noise_by_var:
thetaa: 0.00005
hus: 0.00005
clw: 0.00005
thetaa: 0.0016
hus: 0.0016
clw: 0.0045
# clw: 0.0016
# cli: 0.01

pretrained_nn_path: "/home/cchristo/ml_mixing_length/nn_666p_leaky_relu.jld2"

# Config files for deep and shallow cases
forcing_toml_files:
shallow: "scm_tomls/gcmdriven_relaxation_shallow_forcing.toml"
deep: "scm_tomls/gcmdriven_relaxation_deep_forcing.toml"
47 changes: 42 additions & 5 deletions calibration/experiments/gcm_driven_scm/get_les_metadata.jl
@@ -3,10 +3,22 @@ using Glob
"""
"""

# cfSite numbers
CFSITE_TYPES = Dict(
"shallow" => (collect(4:15)..., collect(17:23)...),
"deep" =>
(collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100),
)

function get_les_calibration_library()
les_library = get_shallow_LES_library()
# AMIP4K data: July, NE Pacific
cfsite_numbers = (17, 23)
# AMIP data: July, NE Pacific
# cfsite_numbers = (17, 18, 22, 23, 30, 94)
# cfsite_numbers = (17, 22, 23, 30, 33, 94)
cfsite_numbers = (17, 21, 23, 30, 33)# 94)
# cfsite_numbers = (30, 33,)# 94)

# cfsite_numbers = (17, 30,)# 94)
les_kwargs = (forcing_model = "HadGEM2-A", month = 7, experiment = "amip")
ref_paths = [
get_stats_path(get_cfsite_les_dir(cfsite_number; les_kwargs...)) for
@@ -15,6 +27,20 @@ function get_les_calibration_library()
return (ref_paths, cfsite_numbers)
end

function get_cfsite_type(i, cfsite_numbers)
return get_cfsite_type(cfsite_numbers[i])
end

function get_cfsite_type(cfsite_number::Int)
if cfsite_number in CFSITE_TYPES["shallow"]
return "shallow"
elseif cfsite_number in CFSITE_TYPES["deep"]
return "deep"
else
@error "cfSite number $(cfsite_number) not found in available sites."
end
end
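A self-contained usage sketch of the classification above (the dictionary and function are restated here so the snippet runs on its own; behavior mirrors the diff, with `error` in place of `@error` so the unknown-site case throws):

```julia
# cfSite numbers, as in the diff above.
CFSITE_TYPES = Dict(
    "shallow" => (collect(4:15)..., collect(17:23)...),
    "deep" => (collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100),
)

# Classify a cfSite number as "shallow" or "deep".
function get_cfsite_type(cfsite_number::Int)
    cfsite_number in CFSITE_TYPES["shallow"] && return "shallow"
    cfsite_number in CFSITE_TYPES["deep"] && return "deep"
    error("cfSite number $(cfsite_number) not found in available sites.")
end

get_cfsite_type(17)  # "shallow"
get_cfsite_type(30)  # "deep"
```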

"""
get_LES_library
@@ -25,7 +51,18 @@ and experiments.
"""
function get_LES_library()
LES_library = get_shallow_LES_library()
deep_sites = (collect(30:33)..., collect(66:70)..., 82, 92, 94, 96, 99, 100)
deep_sites = deepcopy(CFSITE_TYPES["deep"])


# remove <0 ql/cli cases
# sites_07 = deepcopy(setdiff(deep_sites, [92, 99, 100]))
# append!(LES_library["HadGEM2-A"]["07"]["cfsite_numbers"], sites_07)
# sites_01 = deepcopy(setdiff(deep_sites, [99,]))
# append!(LES_library["HadGEM2-A"]["01"]["cfsite_numbers"], sites_01)
# sites_04 = deepcopy(setdiff(deep_sites, [32, 92, 94, 96, 99, 100]))
# append!(LES_library["HadGEM2-A"]["04"]["cfsite_numbers"], sites_04)
# sites_10 = deepcopy(setdiff(deep_sites, [92, 94, 99, 100]))
# append!(LES_library["HadGEM2-A"]["10"]["cfsite_numbers"], sites_10)

append!(LES_library["HadGEM2-A"]["07"]["cfsite_numbers"], deep_sites)
append!(LES_library["HadGEM2-A"]["01"]["cfsite_numbers"], deep_sites)
@@ -34,6 +71,7 @@
sites_10 = deepcopy(setdiff(deep_sites, [94, 100]))
append!(LES_library["HadGEM2-A"]["10"]["cfsite_numbers"], sites_10)


LES_library_full = deepcopy(LES_library)
for model in keys(LES_library_full)
for month in keys(LES_library_full[model])
@@ -103,8 +141,7 @@ function get_shallow_LES_library()
"CNRM-CM5" => Dict(),
"CNRM-CM6-1" => Dict(),
)
Shen_et_al_sites = collect(4:15)
append!(Shen_et_al_sites, collect(17:23))
Shen_et_al_sites = collect(deepcopy(CFSITE_TYPES["shallow"]))

# HadGEM2-A model (76 AMIP-AMIP4K pairs)
LES_library["HadGEM2-A"]["10"] = Dict()