diff --git a/src/plotting/colormap_defaults.py b/src/plotting/colormap_defaults.py
index 52ff050..5e22f72 100644
--- a/src/plotting/colormap_defaults.py
+++ b/src/plotting/colormap_defaults.py
@@ -4,12 +4,23 @@ from matplotlib import pyplot as plt
 import warnings
 
 from .colormap_loader import load_ncl_colormap
-
+from matplotlib.colors import BoundaryNorm
+import numpy as np
 
 
 def _fallback():
     warnings.warn("No colormap found for this parameter, using fallback.", UserWarning)
     return {"cmap": plt.get_cmap("viridis"), "norm": None, "units": ""}
 
+
+def symmetric_boundary_norm(nlevels):
+    """
+    Return a callable that creates a BoundaryNorm symmetric around zero
+    with `nlevels` discrete colors; used to build the bias colormaps.
+    """
+    def _norm(data):
+        vmax = np.nanmax(np.abs(data))
+        boundaries = np.linspace(-vmax, vmax, nlevels + 1)
+        return BoundaryNorm(boundaries=boundaries, ncolors=nlevels)
+    return _norm
+
 
 _CMAP_DEFAULTS = {
     "SP": {"cmap": plt.get_cmap("coolwarm", 11), "vmin": 800 * 100, "vmax": 1100 * 100},
@@ -44,6 +55,41 @@
         "units": "mm",
         "levels": [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100],
     },
+
+    # Hard-coded for the moment; this can still be made smarter later on.
+    # RMSE and MAE first (they share the same settings). A sequential colour map reflects the nature of the data (error, all positive),
+    # and red is suggestive of 'bad' (high error).
+    # Use a limited number of levels so that absolute error values can be read from the map.
+    # Always start at 0 so that the colour saturation corresponds to the error magnitude.
+
+    # RMSE:
+    "U_10M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "m/s"},
+    "V_10M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "m/s"},
+    "TD_2M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "°C"},
+    "T_2M.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "°C"},
+    "PMSL.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "Pa"},
+    "PS.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "Pa"},
+    "TOT_PREC.RMSE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "mm"},
+
+    # MAE:
+    "U_10M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "m/s"},
+    "V_10M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "m/s"},
+    "TD_2M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "°C"},
+    "T_2M.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "°C"},
+    "PMSL.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "Pa"},
+    "PS.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "Pa"},
+    "TOT_PREC.MAE.spatial": {"cmap": plt.get_cmap("Reds", 11), "vmin": 0} | {"units": "mm"},
+
+    # Bias:
+    # A diverging colour scheme reflects the nature of the data (can be positive or negative, symmetric around zero).
+    # Red-Blue for all variables except precipitation, where a Brown-Green scheme is more suggestive.
+    "U_10M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "m/s"},
+    "V_10M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "m/s"},
+    "TD_2M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "°C"},
+    "T_2M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "°C"},
+    "PMSL.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "Pa"},
+    "PS.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "Pa"},
+    "TOT_PREC.BIAS.spatial": {"cmap": plt.get_cmap("BrBG", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "mm"},
 }
 
 CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS)
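Note: the BIAS entries store a callable under "norm" rather than a ready-made Norm, so the boundaries can adapt to whatever data is plotted. A minimal sketch of how consuming code might resolve it (an assumption, since plot_field() is not shown in this diff):

    import numpy as np
    from plotting.colormap_defaults import CMAP_DEFAULTS

    cfg = CMAP_DEFAULTS["T_2M.BIAS.spatial"]
    data = np.array([-2.0, 0.5, 3.0])
    # resolve the factory into a concrete BoundaryNorm once the data is known
    norm = cfg["norm"](data) if callable(cfg.get("norm")) else cfg.get("norm")
    # norm.boundaries now spans [-3, 3] symmetrically across 11 colors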
+ "U_10M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "m/s"}, + "V_10M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "m/s"}, + "TD_2M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "°C"}, + "T_2M.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "°C"}, + "PMSL.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "Pa"}, + "PS.BIAS.spatial": {"cmap": plt.get_cmap("RdBu", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "Pa"}, + "TOT_PREC.BIAS.spatial": {"cmap": plt.get_cmap("BrBG", 11), "norm": symmetric_boundary_norm(nlevels=11)} | {"units": "mm"} } CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS) diff --git a/src/plotting/compat.py b/src/plotting/compat.py index 665287e..06609d1 100644 --- a/src/plotting/compat.py +++ b/src/plotting/compat.py @@ -5,6 +5,7 @@ import geopandas as gpd import numpy as np import pandas as pd +import xarray as xr from meteodatalab import data_source from meteodatalab import grib_decoder from shapely.geometry import MultiPoint @@ -90,3 +91,79 @@ def load_state_from_raw( if key.startswith("field_"): state["fields"][key.removeprefix("field_")] = value return state + + +def load_state_from_netcdf( + file: Path, + paramlist: list[str], + *, + season: str = "all", + init_hour: int = -999, + lead_time: int | None = None, # hours +) -> dict: + """ + NetCDF analogue of load_state_from_grib(), restricted to spatial variables. + """ + + ds = xr.open_dataset(file) + + # --- normalize lead_time to hours (float) --- + if ds["lead_time"].dtype.kind == "m": + ds = ds.assign_coords( + lead_time=ds["lead_time"].dt.total_seconds() / 3600 + ) + + # --- select season / init_hour / lead_time --- + ds = ds.sel(season=season, init_hour=init_hour) + if lead_time is not None: + ds = ds.sel(lead_time=lead_time) + + # --- infer reference + valid time --- + # Assumption: forecast_reference_time is not explicitly stored + # We reconstruct something consistent with GRIB usage + forecast_reference_time = None + valid_time = None + if lead_time is not None: + valid_time = pd.to_datetime(lead_time, unit="h", origin="unix") + + # --- get lat / lon (assumed present as coordinates) --- + lat = ds["lat"].values if "lat" in ds.coords else ds["latitude"].values + lon = ds["lon"].values if "lon" in ds.coords else ds["longitude"].values + + lon2d, lat2d = np.meshgrid(lon, lat) + lats = lat2d.flatten() + lons = lon2d.flatten() + + state = { + "forecast_reference_time": forecast_reference_time, + "valid_time": valid_time, + "latitudes": lats, + "longitudes": lons, + "fields": {}, + } + + # --- LAM envelope (convex hull) --- + lam_hull = MultiPoint(list(zip(lons.tolist(), lats.tolist()))).convex_hull + state["lam_envelope"] = gpd.GeoSeries([lam_hull], crs="EPSG:4326") + + # --- extract spatial fields --- + for param in paramlist: + # e.g. 
diff --git a/src/verification/__init__.py b/src/verification/__init__.py
index 6273f51..2edcd8f 100644
--- a/src/verification/__init__.py
+++ b/src/verification/__init__.py
@@ -84,17 +84,26 @@ def _compute_scores(
     Returns a xarray Dataset with the computed metrics.
     """
     error = fcst - obs
-    scores = xr.Dataset(
-        {
-            f"{prefix}BIAS{suffix}": error.mean(dim=dim, skipna=True),
-            f"{prefix}MSE{suffix}": (error**2).mean(dim=dim, skipna=True),
-            f"{prefix}MAE{suffix}": abs(error).mean(dim=dim, skipna=True),
-            f"{prefix}VAR{suffix}": error.var(dim=dim, skipna=True),
-            f"{prefix}CORR{suffix}": xr.corr(fcst, obs, dim=dim),
-            f"{prefix}R2{suffix}": xr.corr(fcst, obs, dim=dim) ** 2,
-        }
-    )
-    scores = scores.expand_dims({"source": [source]})
+    if dim == []:
+        scores = xr.Dataset(
+            {
+                f"{prefix}BIAS{suffix}": error,
+                f"{prefix}MSE{suffix}": (error**2),
+                f"{prefix}MAE{suffix}": abs(error),
+            }
+        )
+    else:
+        scores = xr.Dataset(
+            {
+                f"{prefix}BIAS{suffix}": error.mean(dim=dim, skipna=True),
+                f"{prefix}MSE{suffix}": (error**2).mean(dim=dim, skipna=True),
+                f"{prefix}MAE{suffix}": abs(error).mean(dim=dim, skipna=True),
+                f"{prefix}VAR{suffix}": error.var(dim=dim, skipna=True),
+                f"{prefix}CORR{suffix}": xr.corr(fcst, obs, dim=dim),
+                f"{prefix}R2{suffix}": xr.corr(fcst, obs, dim=dim) ** 2,
+            }
+        )
+    # scores = scores.expand_dims({"source": [source]})
 
     return scores
 
@@ -198,8 +207,17 @@
         score = xr.concat(score, dim="region")
         fcst_statistics = xr.concat(fcst_statistics, dim="region")
         obs_statistics = xr.concat(obs_statistics, dim="region")
+        score_spatial = _compute_scores(
+            fcst_aligned[param],
+            obs_aligned[param],
+            prefix=param + ".",
+            suffix=".spatial",
+            dim=[],
+        )
         statistics.append(xr.concat([fcst_statistics, obs_statistics], dim="source"))
-        scores.append(score)
+        scores.append(
+            xr.merge([score, score_spatial], join="outer", compat="no_conflicts")
+        )
 
     scores = _merge_metrics(scores)
     statistics = _merge_metrics(statistics)
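Note: the dim=[] branch above skips the reduction entirely, so BIAS/MSE/MAE keep one value per grid cell, which is exactly what the *.spatial map plots consume; VAR/CORR/R2 are omitted because they only make sense after aggregation. A small self-contained illustration with toy data, not from this repo:

    import numpy as np
    import xarray as xr

    fcst = xr.DataArray(np.array([[1.0, 2.0], [3.0, 4.0]]), dims=("y", "x"))
    obs = xr.DataArray(np.ones((2, 2)), dims=("y", "x"))
    error = fcst - obs
    pointwise_mae = abs(error)  # keeps dims (y, x): one value per grid cell
    aggregated_mae = abs(error).mean(dim=["y", "x"], skipna=True)  # 0-d scalar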
diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk
index d4def1b..2014e00 100644
--- a/workflow/rules/plot.smk
+++ b/workflow/rules/plot.smk
@@ -67,3 +67,34 @@ rule make_forecast_animation:
         """
         convert -delay {params.delay} -loop 0 {input} {output}
         """
+
+
+rule plot_summary_stat_maps:
+    input:
+        script="workflow/scripts/plot_summary_stat_maps.mo.py",
+        inference_okfile=rules.execute_inference.output.okfile,
+    output:
+        OUT_ROOT / "results/summary_stats/maps/{run_id}/{leadtime}/{metric}_{param}_{region}.png",
+    wildcard_constraints:
+        leadtime=r"\d+",  # only digits
+    resources:
+        slurm_partition="postproc",
+        cpus_per_task=1,
+        runtime="10m",
+    params:
+        nc_out_dir=lambda wc: (
+            # runs live in output/data/runs/{run_id}/verif_aggregated.nc; baselines live in
+            # output/data/baselines/<id>/verif_aggregated.nc and still need a lookup (see sketch below)
+            Path(OUT_ROOT) / f"data/runs/{wc.run_id}"
+        ).resolve(),
+    shell:
+        """
+        export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions)
+        python {input.script} \
+            --input {params.nc_out_dir} --outfn {output[0]} \
+            --param {wildcards.param} --leadtime {wildcards.leadtime} --region {wildcards.region}
+        # interactive editing (set localrule: True and use only one core):
+        # marimo edit {input.script} -- \
+        #     --input {params.nc_out_dir} --outfn {output[0]} \
+        #     --param {wildcards.param} --leadtime {wildcards.leadtime} --region {wildcards.region}
+        """
\ No newline at end of file
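Note: the run-vs-baseline path question flagged in the params comment above is still open. One possible shape for the lookup, as a hypothetical sketch (resolve_verif_dir and the baselines argument are assumptions, not existing code):

    from pathlib import Path

    def resolve_verif_dir(out_root: Path, source_id: str, baselines: set[str]) -> Path:
        # baselines live under data/baselines/<id>/, runs under data/runs/<id>/
        subdir = "baselines" if source_id in baselines else "runs"
        return out_root / "data" / subdir / source_id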
diff --git a/workflow/rules/verif.smk b/workflow/rules/verif.smk
index cba4a30..0566ccf 100644
--- a/workflow/rules/verif.smk
+++ b/workflow/rules/verif.smk
@@ -27,7 +27,7 @@ rule verif_metrics_baseline:
         analysis_label=config["analysis"].get("label"),
         regions=REGION_TXT,
     output:
-        OUT_ROOT / "data/baselines/{baseline_id}/{init_time}/verif.nc",
+        temp(OUT_ROOT / "data/baselines/{baseline_id}/{init_time}/verif.nc"),
     log:
         OUT_ROOT / "logs/verif_metrics_baseline/{baseline_id}-{init_time}.log",
     resources:
@@ -63,7 +63,7 @@ rule verif_metrics:
         inference_okfile=rules.execute_inference.output.okfile,
         analysis_zarr=config["analysis"].get("analysis_zarr"),
     output:
-        OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc",
+        temp(OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc"),
     # wildcard_constraints:
     #     run_id="^"  # to avoid ambiguity with run_baseline_verif
     # TODO: implement logic to use experiment name instead of run_id as wildcard
diff --git a/workflow/scripts/plot_summary_stat_maps.mo.py b/workflow/scripts/plot_summary_stat_maps.mo.py
new file mode 100644
index 0000000..27cc957
--- /dev/null
+++ b/workflow/scripts/plot_summary_stat_maps.mo.py
@@ -0,0 +1,255 @@
+import marimo
+
+__generated_with = "0.16.5"
+app = marimo.App(width="medium")
+
+
+@app.cell
+def _():
+    # stdlib imports, unchanged from the GRIB-based script.
+    import logging
+    from argparse import ArgumentParser
+    from pathlib import Path
+
+    # plotting stack, unchanged from the GRIB-based script.
+    import cartopy.crs as ccrs
+    import earthkit.plots as ekp
+    import numpy as np
+
+    # domain definitions, unchanged as well.
+    from plotting import DOMAINS
+
+    # StatePlotter itself needs no changes.
+    from plotting import StatePlotter
+
+    # Added some new colour maps for the Bias / MAE / RMSE map plots.
+    from plotting.colormap_defaults import CMAP_DEFAULTS
+
+    # Load .nc files via the NetCDF analogue of load_state_from_grib.
+    from plotting.compat import load_state_from_netcdf
+
+    return (
+        ArgumentParser,
+        CMAP_DEFAULTS,
+        Path,
+        StatePlotter,
+        ekp,
+        load_state_from_netcdf,
+        logging,
+        np,
+        DOMAINS,
+        ccrs,
+    )
+
+
+@app.cell
+def _(logging):
+    LOG = logging.getLogger(__name__)
+    LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    logging.basicConfig(level=logging.INFO, format=LOG_FMT)
+    return (LOG,)
+
+
+@app.cell
+def _(ArgumentParser, Path):
+    parser = ArgumentParser()
+
+    parser.add_argument(
+        "--input", type=str, default=None, help="directory with the .nc files containing the error fields"
+    )
+    parser.add_argument("--outfn", type=str, help="output filename")
+    parser.add_argument("--leadtime", type=int, help="lead time in hours")
+    parser.add_argument("--param", type=str, help="parameter")
+    parser.add_argument("--region", type=str, help="name of region")
+
+    args = parser.parse_args()
+    nc_dir = Path(args.input)
+    # Assumption: the aggregated verification file keeps the name used in the
+    # verif rules (see the comment in plot.smk above).
+    nc_file = nc_dir / "verif_aggregated.nc"
+    outfn = Path(args.outfn)
+    lead_time = args.leadtime
+    param = args.param
+    region = args.region
+    return (
+        args,
+        nc_dir,
+        nc_file,
+        lead_time,
+        outfn,
+        param,
+        region,
+    )
+
+
+@app.cell
+def _(nc_file, param, lead_time, load_state_from_netcdf):
+    # load the .nc verification file:
+    if param == "SP_10M":
+        paramlist = ["U_10M", "V_10M"]
+    elif param == "SP":
+        paramlist = ["U", "V"]
+    else:
+        paramlist = [param]
+
+    state = load_state_from_netcdf(
+        nc_file,
+        paramlist=paramlist,
+        lead_time=lead_time,
+    )
+
+    return (state,)
+
+
+@app.cell
+def _(CMAP_DEFAULTS, ekp):
+    def get_style(param, units_override=None):
+        """Get style and colormap settings for the plot.
+
+        Needed because cmap/norm cannot be passed via Style(colors=cmap);
+        they still need to be passed as arguments to tripcolor()/tricontourf().
+        """
+        cfg = CMAP_DEFAULTS[param]
+        units = units_override if units_override is not None else cfg.get("units", "")
+        return {
+            "style": ekp.styles.Style(
+                levels=cfg.get("bounds", cfg.get("levels", None)),
+                extend="both",
+                units=units,
+                colors=cfg.get("colors", None),
+            ),
+            "norm": cfg.get("norm", None),
+            "cmap": cfg.get("cmap", None),
+            "levels": cfg.get("levels", None),
+            "vmin": cfg.get("vmin", None),
+            "vmax": cfg.get("vmax", None),
+            "colors": cfg.get("colors", None),
+        }
+
+    return (get_style,)
+
+
+@app.cell
+def _(LOG, np):
+    """Preprocess fields with pint-based unit conversion and derived quantities."""
+    try:
+        import pint  # type: ignore
+
+        _ureg = pint.UnitRegistry()
+
+        def _k_to_c(arr):
+            # robust conversion with pint, fallback if dtype unsupported
+            try:
+                return (_ureg.Quantity(arr, _ureg.kelvin).to(_ureg.degC)).magnitude
+            except Exception:
+                return arr - 273.15
+
+        def _ms_to_knots(arr):
+            # robust conversion with pint, fallback if dtype unsupported
+            try:
+                return (
+                    _ureg.Quantity(arr, _ureg.meter / _ureg.second).to(_ureg.knot)
+                ).magnitude
+            except Exception:
+                return arr * 1.943844
+
+        def _m_to_mm(arr):
+            # robust conversion with pint, fallback if dtype unsupported
+            try:
+                return (_ureg.Quantity(arr, _ureg.meter).to(_ureg.millimeter)).magnitude
+            except Exception:
+                return arr * 1000
+
+    except Exception:
+        LOG.warning("pint not available; falling back to hardcoded conversions")
+
+        def _k_to_c(arr):
+            return arr - 273.15
+
+        def _ms_to_knots(arr):
+            return arr * 1.943844
+
+        def _m_to_mm(arr):
+            return arr * 1000
+
+    def preprocess_field(param: str, state: dict):
+        """
+        - Temperatures: K -> °C
+        - Wind speed: sqrt(u^2 + v^2)
+        - Precipitation: m -> mm
+        Returns: (field_array, units_override or None)
+        """
+        fields = state["fields"]
+        # temperature variables
+        if param in ("T_2M", "TD_2M", "T", "TD"):
+            return _k_to_c(fields[param]), "°C"
+        # 10 m wind speed from components
+        if param == "SP_10M":
+            u = fields["U_10M"]
+            v = fields["V_10M"]
+            return np.sqrt(u**2 + v**2), "m/s"
+        # wind speed from standard-level components
+        if param == "SP":
+            u = fields["U"]
+            v = fields["V"]
+            return np.sqrt(u**2 + v**2), "m/s"
+        if param == "TOT_PREC":
+            return _m_to_mm(fields[param]), "mm"
+        # default: passthrough
+        return fields[param], None
+
+    return (preprocess_field,)
+
+
+@app.cell
+def _(
+    LOG,
+    StatePlotter,
+    get_style,
+    outfn,
+    param,
+    preprocess_field,
+    region,
+    state,
+    DOMAINS,
+    ccrs,
+):
+    # plot individual fields
+    plotter = StatePlotter(
+        state["longitudes"],
+        state["latitudes"],
+        outfn.parent,
+    )
+    fig = plotter.init_geoaxes(
+        nrows=1,
+        ncols=1,
+        projection=DOMAINS[region]["projection"],
+        bbox=DOMAINS[region]["extent"],
+        name=region,
+        size=(6, 6),
+    )
+    subplot = fig.add_map(row=0, column=0)
+
+    # preprocess field (unit conversion, derived quantities)
+    field, units_override = preprocess_field(param, state)
+
+    plotter.plot_field(subplot, field, **get_style(param, units_override))
+    subplot.ax.add_geometries(
+        state["lam_envelope"],
+        edgecolor="black",
+        facecolor="none",
+        crs=ccrs.PlateCarree(),
+    )
+    validtime = (
+        state["valid_time"].strftime("%Y%m%d%H%M")
+        if state["valid_time"] is not None
+        else "n/a"
+    )
+
+    fig.title(f"{param}, time: {validtime}")
+
+    fig.save(outfn, bbox_inches="tight", dpi=200)
+    LOG.info(f"saved: {outfn}")
+    return
+
+
+if __name__ == "__main__":
+    app.run()
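Note: a quick sanity check of the conversion helpers in the script above (assumes pint is installed; the values follow from the unit definitions, nothing repo-specific):

    import numpy as np
    import pint

    ureg = pint.UnitRegistry()
    t_c = ureg.Quantity(np.array([273.15, 300.0]), ureg.kelvin).to(ureg.degC).magnitude
    # -> [0.0, 26.85], matching _k_to_c()
    u, v = np.array([3.0]), np.array([4.0])
    speed = np.sqrt(u**2 + v**2)  # -> [5.0], the SP_10M derivation in preprocess_field()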
diff --git a/workflow/scripts/verif_single_init.py b/workflow/scripts/verif_single_init.py
index 61d24a0..5ded2ba 100644
--- a/workflow/scripts/verif_single_init.py
+++ b/workflow/scripts/verif_single_init.py
@@ -110,6 +110,7 @@ def main(args: ScriptConfig):
 
     # compute metrics and statistics
     results = verify(fcst, analysis, args.label, args.analysis_label, args.regions)
+    LOG.info("Verification results:\n%s", results)
 
     # save results to NetCDF
     args.output.parent.mkdir(parents=True, exist_ok=True)