Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 34 additions & 75 deletions cmip7_prep/cache_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import xarray as xr
import numpy as np

from cmip7_prep.cmor_utils import bounds_from_centers_1d

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -48,81 +46,42 @@ def _make_dummy_grids(mapfile: Path) -> tuple[xr.Dataset, xr.Dataset]:
"""Construct minimal ds_in/ds_out satisfying xESMF when reusing weights.
Adds CF-style bounds for both lat and lon so conservative methods don’t
trigger cf-xarray’s bounds inference on size-1 dimensions."""
with open_nc(mapfile) as m:
nlon_in, nlat_in = _get_src_shape(m)
lat_out_1d, lon_out_1d = _get_dst_latlon_1d(mapfile=mapfile)

# --- Dummy INPUT grid (unstructured → represent as 2D with length-1 lat) ---
lat_in = np.arange(
-90.0, 90.0, 180.0 / nlat_in, dtype="f8"
) # e.g., [0], length can be 1
lon_in = np.arange(0.5, 360.5, 360.0 / nlon_in, dtype="f8")
ds_in = xr.Dataset(
data_vars={
"lat_bnds": (
("lat", "nbnds"),
bounds_from_centers_1d(lat_in, "lat"),
),
"lon_bnds": (
("lon", "nbnds"),
bounds_from_centers_1d(lon_in, "lon"),
),
},
coords={
"lat": (
"lat",
lat_in,
{
"units": "degrees_north",
"standard_name": "latitude",
"bounds": "lat_bnds",
},
),
"lon": (
"lon",
lon_in,
{
"units": "degrees_east",
"standard_name": "longitude",
"bounds": "lon_bnds",
},
),
"nbnds": ("nbnds", np.array([0, 1], dtype="i4")),
},
)

# --- OUTPUT grid from weights (canonical 1° lat/lon) ---
lat_out_bnds = bounds_from_centers_1d(lat_out_1d, "lat")
lon_out_bnds = bounds_from_centers_1d(lon_out_1d, "lon")

ds_out = xr.Dataset(
data_vars={
"lat_bnds": (("lat", "nbnds"), lat_out_bnds),
"lon_bnds": (("lon", "nbnds"), lon_out_bnds),
},
coords={
"lat": (
"lat",
lat_out_1d,
{
"units": "degrees_north",
"standard_name": "latitude",
"bounds": "lat_bnds",
},
),
"lon": (
"lon",
lon_out_1d,
{
"units": "degrees_east",
"standard_name": "longitude",
"bounds": "lon_bnds",
},
),
"nbnds": ("nbnds", np.array([0, 1], dtype="i4")),
},
weights = xr.open_dataset(mapfile)
in_shape = weights.src_grid_dims.load().data

# Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1
if len(in_shape) == 1:
in_shape = [1, in_shape.item()]

# output variable shape
out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]

# Some prep to get the bounds:
# Note that bounds are needed for conservative regridding and not for bilinear
lat_b_out = np.zeros(out_shape[0] + 1)
lon_b_out = weights.xv_b.data[: out_shape[1] + 1, 0]
lat_b_out[:-1] = weights.yv_b.data[np.arange(out_shape[0]) * out_shape[1], 0]
lat_b_out[-1] = weights.yv_b.data[-1, -1]

dummy_in = xr.Dataset(
{
"lat": ("lat", np.empty((in_shape[0],))),
"lon": ("lon", np.empty((in_shape[1],))),
"lat_b": ("lat_b", np.empty((in_shape[0] + 1,))),
"lon_b": ("lon_b", np.empty((in_shape[1] + 1,))),
}
)
return ds_in, ds_out
dummy_out = xr.Dataset(
{
"lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
"lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
"lat_b": ("lat_b", lat_b_out),
"lon_b": ("lon_b", lon_b_out),
}
)

return dummy_in, dummy_out


# -------------------------
Expand Down
42 changes: 32 additions & 10 deletions cmip7_prep/cmor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,15 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
vals, bnds, _ = roll_for_monotonic_with_bounds(vals, bnds)
return vals, bnds, units

logger.info("Defining CMOR axes for variable: %s", vdef.name)
axes_ids = []
var_name = getattr(vdef, "name", None)
if var_name is None or var_name not in ds:
raise KeyError(f"Variable to write not found in dataset: {var_name!r}")
var_name = getattr(vdef, "branded_variable_name", None)
if var_name is None or var_name not in ds:
raise KeyError(f"Variable to write not found in dataset: {var_name!r}")
var_da = ds[str(var_name)]
logger.info("found var_da with name: %s", var_da.name)
var_dims = list(var_da.dims)
alev_id = None
plev_id = None
Expand Down Expand Up @@ -398,10 +402,10 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
coord_vals=tvals,
cell_bounds=tbnds if tbnds is not None else None,
)
logger.info("time axis id: %s", time_id)
logger.info("time axis id: %s var_dims=%s", time_id, var_dims)
# --- vertical: standard_hybrid_sigma ---
levels = getattr(vdef, "levels", {}) or {}

logger.info("levels dict: %s", levels)
if (levels.get("name") or "").lower() in {
"standard_hybrid_sigma",
"alevel",
Expand Down Expand Up @@ -555,6 +559,9 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
elif "zl" in var_dims:
ds[var_name] = ds[var_name].rename({"zl": "olevel"})
var_dims = list(ds[var_name].dims)
logger.info(
"rename zl to olevel var_name %s var_dims=%s", var_name, var_dims
)
cmor.set_cur_dataset_attribute("vertical_label", "olevel")
logger.info("*** Define olevel axis")
values = ds["olevel"].values
Expand All @@ -567,9 +574,13 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
coord_vals=np.asarray(values),
cell_bounds=bnds,
)
elif "zi" in var_dims:
ds[var_name] = ds[var_name].rename({"zi": "olevel"})
elif "zl" in var_dims:
logger.info("found zl axis in var_dims for variable %s", var_name)
ds[var_name] = ds[var_name].rename({"zl": "olevel"})
var_dims = list(ds[var_name].dims)
logger.info(
"rename zl to olevel var_name %s var_dims=%s", var_name, var_dims
)
cmor.set_cur_dataset_attribute("vertical_label", "olevel")
logger.info("*** Define olevel axis")
values = ds["olevel"].values
Expand Down Expand Up @@ -809,16 +820,27 @@ def write_variable(
"Using CMOR table key: %s %s", self.tables_path, self.primarytable
) # debug
self.load_table(self.tables_path, self.primarytable)
varname = getattr(cmip_var, "physical_parameter").name
logger.info("Preparing to write variable: %s", varname) # debug
data = ds[str(varname)]
bvn_attr = getattr(cmip_var, "branded_variable_name", None)
bvn = bvn_attr.name if bvn_attr is not None else None
if bvn is None:
# Fall back to vdef.name when no branded variable name is provided
bvn = getattr(vdef, "name", None)
if bvn is None:
raise ValueError(
"Cannot determine branded variable name: both "
"`cmip_var.branded_variable_name` and `vdef.name` are missing."
)
if bvn not in ds:
ds = ds.rename({vdef.name: bvn})
logger.info("Preparing to write variable: %s", bvn) # debug
data = ds[str(bvn)]

logger.info("Ensure fx variables are written and cached") # debug
self.ensure_fx_written_and_cached(ds)

units = getattr(vdef, "units", "") or ""
self.load_table(self.tables_path, self.primarytable)
logger.info("Define CMOR axes for variable %s", vdef.name) # debug
logger.info("Define CMOR axes for variable %s", bvn) # debug
axes_ids = self._define_axes(ds, vdef)
logger.info("Prepare data for CMOR %s", data.dtype) # debug
data_filled, fillv = filled_for_cmor(data)
Expand All @@ -828,7 +850,7 @@ def write_variable(
data_filled = data_filled.rename({"zi": "olevel"})
self.load_table(self.tables_path, self.primarytable)

var_entry = getattr(cmip_var, "branded_variable_name", varname)
var_entry = getattr(cmip_var, "branded_variable_name", bvn)
if hasattr(var_entry, "name"):
var_entry = var_entry.name
elif hasattr(var_entry, "value"):
Expand Down
Loading