48 changes: 47 additions & 1 deletion src/plotting/colormap_defaults.py
@@ -4,12 +4,23 @@
from matplotlib import pyplot as plt
import warnings
from .colormap_loader import load_ncl_colormap

from matplotlib.colors import BoundaryNorm
import numpy as np

def _fallback():
    warnings.warn("No colormap found for this parameter, using fallback.", UserWarning)
    return {"cmap": plt.get_cmap("viridis"), "norm": None, "units": ""}

def symmetric_boundary_norm(nlevels):
    """
    Return a callable that builds a BoundaryNorm symmetric about zero
    with `nlevels` discrete colors. Used for bias maps, where the data
    is signed and the colormap should be centred on zero.
    """
    def _norm(data):
        vmax = np.nanmax(np.abs(data))
        boundaries = np.linspace(-vmax, vmax, nlevels + 1)
        return BoundaryNorm(boundaries=boundaries, ncolors=nlevels)
    return _norm

_CMAP_DEFAULTS = {
    "SP": {"cmap": plt.get_cmap("coolwarm", 11), "vmin": 800 * 100, "vmax": 1100 * 100},
@@ -44,6 +55,41 @@ def _fallback():
"units": "mm",
"levels": [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100],
},

    # Hard-coded for now; this can still be made smarter later on.
    # RMSE and MAE first (they share the same styling). A sequential colour map
    # reflects the nature of the data (error, all positive), and red is
    # suggestive of 'bad' (high error). A limited number of levels keeps
    # absolute error values readable from the map, and vmin is always 0 so
    # that colour saturation corresponds to error magnitude.

    # RMSE:
    "U_10M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "m/s"},
    "V_10M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "m/s"},
    "TD_2M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "°C"},
    "T_2M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "°C"},
    "PMSL.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "Pa"},
    "PS.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "Pa"},
    "TOT_PREC.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "mm"},

    # MAE:
    "U_10M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "m/s"},
    "V_10M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "m/s"},
    "TD_2M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "°C"},
    "T_2M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "°C"},
    "PMSL.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "Pa"},
    "PS.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "Pa"},
    "TOT_PREC.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0, "units": "mm"},

    # Bias:
    # A diverging colour scheme reflects the nature of the data (signed and
    # symmetric about zero). Red-Blue for all variables except precipitation,
    # where a Brown-Green scheme is more suggestive.
    "U_10M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "m/s"},
    "V_10M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "m/s"},
    "TD_2M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "°C"},
    "T_2M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "°C"},
    "PMSL.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "Pa"},
    "PS.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "Pa"},
    "TOT_PREC.BIAS.spatial": {"cmap": plt.get_cmap("BrBG", 11), "norm": symmetric_boundary_norm(nlevels=11), "units": "mm"},
}

CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS)
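
For orientation, a hedged usage sketch of how these defaults fit together at plot time: look up a parameter key, call the stored norm factory on the field data, and pass the result to matplotlib. The toy bias field and the pcolormesh call are illustrative assumptions, not code from this PR.

import numpy as np
from matplotlib import pyplot as plt

entry = CMAP_DEFAULTS["T_2M.BIAS.spatial"]
bias_field = np.random.randn(50, 50)  # stand-in 2 m temperature bias field
norm = entry["norm"](bias_field)      # factory -> BoundaryNorm centred on zero
plt.pcolormesh(bias_field, cmap=entry["cmap"], norm=norm)
plt.colorbar(label=entry["units"])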
77 changes: 77 additions & 0 deletions src/plotting/compat.py
@@ -5,6 +5,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from meteodatalab import data_source
from meteodatalab import grib_decoder
from shapely.geometry import MultiPoint
@@ -90,3 +91,79 @@ def load_state_from_raw(
        if key.startswith("field_"):
            state["fields"][key.removeprefix("field_")] = value
    return state


def load_state_from_netcdf(
    file: Path,
    paramlist: list[str],
    *,
    season: str = "all",
    init_hour: int = -999,
    lead_time: int | None = None,  # hours
) -> dict:
    """
    NetCDF analogue of load_state_from_grib(), restricted to spatial variables.
    """

    ds = xr.open_dataset(file)

    # --- normalize lead_time to hours (float) ---
    if ds["lead_time"].dtype.kind == "m":
        ds = ds.assign_coords(
            lead_time=ds["lead_time"].dt.total_seconds() / 3600
        )

    # --- select season / init_hour / lead_time ---
    ds = ds.sel(season=season, init_hour=init_hour)
    if lead_time is not None:
        ds = ds.sel(lead_time=lead_time)

    # --- infer reference + valid time ---
    # Assumption: forecast_reference_time is not explicitly stored in the file,
    # so a true valid time cannot be reconstructed. As a placeholder, anchor
    # the lead time to the Unix epoch, consistent with how GRIB-derived states
    # are used downstream.
    forecast_reference_time = None
    valid_time = None
    if lead_time is not None:
        valid_time = pd.to_datetime(lead_time, unit="h", origin="unix")

    # --- get lat / lon (assumed present as coordinates) ---
    lat = ds["lat"].values if "lat" in ds.coords else ds["latitude"].values
    lon = ds["lon"].values if "lon" in ds.coords else ds["longitude"].values

    lon2d, lat2d = np.meshgrid(lon, lat)
    lats = lat2d.flatten()
    lons = lon2d.flatten()

    state = {
        "forecast_reference_time": forecast_reference_time,
        "valid_time": valid_time,
        "latitudes": lats,
        "longitudes": lons,
        "fields": {},
    }

    # --- LAM envelope (convex hull) ---
    lam_hull = MultiPoint(list(zip(lons.tolist(), lats.tolist()))).convex_hull
    state["lam_envelope"] = gpd.GeoSeries([lam_hull], crs="EPSG:4326")

    # --- extract spatial fields ---
    for param in paramlist:
        # match variables like "U_10M.MAE.spatial" for param "U_10M"
        matching_vars = [
            v for v in ds.data_vars
            if v.startswith(f"{param}.") and v.endswith(".spatial")
        ]

        if not matching_vars:
            state["fields"][param] = np.full(lats.size, np.nan, dtype=float)
            continue

        # If multiple metrics exist, concatenate them
        arrays = []
        for var in matching_vars:
            arr = ds[var].values  # (lead_time, y, x) or (y, x)
            arrays.append(arr.reshape(-1))

        state["fields"][param] = np.concatenate(arrays)

    return state
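
A hedged calling sketch, to make the contract concrete. The file path and parameter names are assumptions based on paths mentioned elsewhere in this PR, and the input file is assumed to carry season/init_hour/lead_time coordinates as selected above.

from pathlib import Path

state = load_state_from_netcdf(
    Path("output/data/runs/runID/verif_aggregated.nc"),
    ["T_2M", "TOT_PREC"],  # prefixes; matches e.g. "T_2M.BIAS.spatial"
    season="all",
    init_hour=-999,
    lead_time=24,
)
print(sorted(state["fields"]), state["latitudes"].shape)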
42 changes: 30 additions & 12 deletions src/verification/__init__.py
@@ -84,17 +84,26 @@ def _compute_scores(
    Returns an xarray Dataset with the computed metrics.
    """
    error = fcst - obs
-    scores = xr.Dataset(
-        {
-            f"{prefix}BIAS{suffix}": error.mean(dim=dim, skipna=True),
-            f"{prefix}MSE{suffix}": (error**2).mean(dim=dim, skipna=True),
-            f"{prefix}MAE{suffix}": abs(error).mean(dim=dim, skipna=True),
-            f"{prefix}VAR{suffix}": error.var(dim=dim, skipna=True),
-            f"{prefix}CORR{suffix}": xr.corr(fcst, obs, dim=dim),
-            f"{prefix}R2{suffix}": xr.corr(fcst, obs, dim=dim) ** 2,
-        }
-    )
-    scores = scores.expand_dims({"source": [source]})
+    if dim == []:
+        scores = xr.Dataset(
+            {
+                f"{prefix}BIAS{suffix}": error,
+                f"{prefix}MSE{suffix}": (error**2),
+                f"{prefix}MAE{suffix}": abs(error),
+            }
+        )
+    else:
+        scores = xr.Dataset(
+            {
+                f"{prefix}BIAS{suffix}": error.mean(dim=dim, skipna=True),
+                f"{prefix}MSE{suffix}": (error**2).mean(dim=dim, skipna=True),
+                f"{prefix}MAE{suffix}": abs(error).mean(dim=dim, skipna=True),
+                f"{prefix}VAR{suffix}": error.var(dim=dim, skipna=True),
+                f"{prefix}CORR{suffix}": xr.corr(fcst, obs, dim=dim),
+                f"{prefix}R2{suffix}": xr.corr(fcst, obs, dim=dim) ** 2,
+            }
+        )
+    # scores = scores.expand_dims({"source": [source]})
    return scores
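
To make the `dim == []` convention concrete, a minimal sketch with toy xarray data (shapes and dimension names are assumptions): an empty dimension list keeps the pointwise error fields unreduced, which is what the new `.spatial` scores rely on, and it is also why VAR/CORR/R2 are omitted in that branch, since they require a reduction dimension.

import numpy as np
import xarray as xr

fcst = xr.DataArray(np.random.rand(4, 3, 3), dims=("time", "y", "x"))
obs = xr.DataArray(np.random.rand(4, 3, 3), dims=("time", "y", "x"))
error = fcst - obs

pointwise_bias = error                                   # dim=[]: per-gridpoint map
aggregated_bias = error.mean(dim=["time"], skipna=True)  # reduced over time
assert pointwise_bias.dims == ("time", "y", "x")
assert aggregated_bias.dims == ("y", "x")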


@@ -198,8 +207,17 @@ def verify(
        score = xr.concat(score, dim="region")
        fcst_statistics = xr.concat(fcst_statistics, dim="region")
        obs_statistics = xr.concat(obs_statistics, dim="region")
+        score_spatial = _compute_scores(
+            fcst_aligned[param],
+            obs_aligned[param],
+            prefix=param + ".",
+            suffix=".spatial",
+            dim=[],
+        )
        statistics.append(xr.concat([fcst_statistics, obs_statistics], dim="source"))
-        scores.append(score)
+        scores.append(
+            xr.merge([score, score_spatial], join="outer", compat="no_conflicts")
+        )

    scores = _merge_metrics(scores)
    statistics = _merge_metrics(statistics)
31 changes: 31 additions & 0 deletions workflow/rules/plot.smk
@@ -67,3 +67,34 @@ rule make_forecast_animation:
"""
convert -delay {params.delay} -loop 0 {input} {output}
"""


rule plot_summary_stat_maps:
    input:
        script="workflow/scripts/plot_summary_stat_maps.mo.py",
        inference_okfile=rules.execute_inference.output.okfile,
    output:
        OUT_ROOT / "results/summary_stats/maps/{run_id}/{leadtime}/{metric}_{param}_{region}.png",
    wildcard_constraints:
        leadtime=r"\d+",  # only digits
    resources:
        slurm_partition="postproc",
        cpus_per_task=1,
        runtime="10m",
    params:
        nc_out_dir=lambda wc: (
            Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib"
            # TODO: this path logic is unresolved: baselines live in, e.g.,
            # output/data/baselines/COSMO-E/verif_aggregated.nc while runs live
            # in output/data/runs/runID/verif_aggregated.nc. Note also that
            # init_time is not a wildcard of this rule's output.
        ).resolve(),
    shell:
        """
        export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions)
        python {input.script} \
            --input {params.nc_out_dir} --date {wildcards.init_time} --outfn {output[0]} \
            --param {wildcards.param} --leadtime {wildcards.leadtime} --region {wildcards.region}
        # For interactive editing (set localrule: True and use only one core):
        # marimo edit {input.script} -- \
        #     --input {params.nc_out_dir} --date {wildcards.init_time} --outfn {output[0]} \
        #     --param {wildcards.param} --leadtime {wildcards.leadtime} --region {wildcards.region}
        """
4 changes: 2 additions & 2 deletions workflow/rules/verif.smk
@@ -27,7 +27,7 @@ rule verif_metrics_baseline:
        analysis_label=config["analysis"].get("label"),
        regions=REGION_TXT,
    output:
-        OUT_ROOT / "data/baselines/{baseline_id}/{init_time}/verif.nc",
+        temp(OUT_ROOT / "data/baselines/{baseline_id}/{init_time}/verif.nc"),
    log:
        OUT_ROOT / "logs/verif_metrics_baseline/{baseline_id}-{init_time}.log",
    resources:
@@ -63,7 +63,7 @@ rule verif_metrics:
        inference_okfile=rules.execute_inference.output.okfile,
        analysis_zarr=config["analysis"].get("analysis_zarr"),
    output:
-        OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc",
+        temp(OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc"),
    # wildcard_constraints:
    #     run_id="^"  # to avoid ambiguity with run_baseline_verif
    # TODO: implement logic to use experiment name instead of run_id as wildcard