Skip to content

Commit 05d94ae

Browse files
authored
Merge pull request #14 from ESMCI/dev011526
Dev011526
2 parents 9bf685f + 6523cb5 commit 05d94ae

File tree

9 files changed

+430
-275
lines changed

9 files changed

+430
-275
lines changed

cmip7_prep/cache_tools.py

Lines changed: 34 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import xarray as xr
88
import numpy as np
99

10-
from cmip7_prep.cmor_utils import bounds_from_centers_1d
11-
1210
logger = logging.getLogger(__name__)
1311

1412

@@ -48,81 +46,42 @@ def _make_dummy_grids(mapfile: Path) -> tuple[xr.Dataset, xr.Dataset]:
4846
"""Construct minimal ds_in/ds_out satisfying xESMF when reusing weights.
4947
Adds CF-style bounds for both lat and lon so conservative methods don’t
5048
trigger cf-xarray’s bounds inference on size-1 dimensions."""
51-
with open_nc(mapfile) as m:
52-
nlon_in, nlat_in = _get_src_shape(m)
53-
lat_out_1d, lon_out_1d = _get_dst_latlon_1d(mapfile=mapfile)
54-
55-
# --- Dummy INPUT grid (unstructured → represent as 2D with length-1 lat) ---
56-
lat_in = np.arange(
57-
-90.0, 90.0, 180.0 / nlat_in, dtype="f8"
58-
) # e.g., [0], length can be 1
59-
lon_in = np.arange(0.5, 360.5, 360.0 / nlon_in, dtype="f8")
60-
ds_in = xr.Dataset(
61-
data_vars={
62-
"lat_bnds": (
63-
("lat", "nbnds"),
64-
bounds_from_centers_1d(lat_in, "lat"),
65-
),
66-
"lon_bnds": (
67-
("lon", "nbnds"),
68-
bounds_from_centers_1d(lon_in, "lon"),
69-
),
70-
},
71-
coords={
72-
"lat": (
73-
"lat",
74-
lat_in,
75-
{
76-
"units": "degrees_north",
77-
"standard_name": "latitude",
78-
"bounds": "lat_bnds",
79-
},
80-
),
81-
"lon": (
82-
"lon",
83-
lon_in,
84-
{
85-
"units": "degrees_east",
86-
"standard_name": "longitude",
87-
"bounds": "lon_bnds",
88-
},
89-
),
90-
"nbnds": ("nbnds", np.array([0, 1], dtype="i4")),
91-
},
92-
)
9349

94-
# --- OUTPUT grid from weights (canonical 1° lat/lon) ---
95-
lat_out_bnds = bounds_from_centers_1d(lat_out_1d, "lat")
96-
lon_out_bnds = bounds_from_centers_1d(lon_out_1d, "lon")
97-
98-
ds_out = xr.Dataset(
99-
data_vars={
100-
"lat_bnds": (("lat", "nbnds"), lat_out_bnds),
101-
"lon_bnds": (("lon", "nbnds"), lon_out_bnds),
102-
},
103-
coords={
104-
"lat": (
105-
"lat",
106-
lat_out_1d,
107-
{
108-
"units": "degrees_north",
109-
"standard_name": "latitude",
110-
"bounds": "lat_bnds",
111-
},
112-
),
113-
"lon": (
114-
"lon",
115-
lon_out_1d,
116-
{
117-
"units": "degrees_east",
118-
"standard_name": "longitude",
119-
"bounds": "lon_bnds",
120-
},
121-
),
122-
"nbnds": ("nbnds", np.array([0, 1], dtype="i4")),
123-
},
50+
weights = xr.open_dataset(mapfile)
51+
in_shape = weights.src_grid_dims.load().data
52+
53+
# Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1
54+
if len(in_shape) == 1:
55+
in_shape = [1, in_shape.item()]
56+
57+
# output variable shape
58+
out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]
59+
60+
# Some prep to get the bounds:
61+
# Note that bounds are needed for conservative regridding and not for bilinear
62+
lat_b_out = np.zeros(out_shape[0] + 1)
63+
lon_b_out = weights.xv_b.data[: out_shape[1] + 1, 0]
64+
lat_b_out[:-1] = weights.yv_b.data[np.arange(out_shape[0]) * out_shape[1], 0]
65+
lat_b_out[-1] = weights.yv_b.data[-1, -1]
66+
67+
dummy_in = xr.Dataset(
68+
{
69+
"lat": ("lat", np.empty((in_shape[0],))),
70+
"lon": ("lon", np.empty((in_shape[1],))),
71+
"lat_b": ("lat_b", np.empty((in_shape[0] + 1,))),
72+
"lon_b": ("lon_b", np.empty((in_shape[1] + 1,))),
73+
}
12474
)
125-
return ds_in, ds_out
75+
dummy_out = xr.Dataset(
76+
{
77+
"lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
78+
"lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
79+
"lat_b": ("lat_b", lat_b_out),
80+
"lon_b": ("lon_b", lon_b_out),
81+
}
82+
)
83+
84+
return dummy_in, dummy_out
12685

12786

12887
# -------------------------

cmip7_prep/cmor_writer.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,15 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
271271
vals, bnds, _ = roll_for_monotonic_with_bounds(vals, bnds)
272272
return vals, bnds, units
273273

274+
logger.info("Defining CMOR axes for variable: %s", vdef.name)
274275
axes_ids = []
275276
var_name = getattr(vdef, "name", None)
276277
if var_name is None or var_name not in ds:
277-
raise KeyError(f"Variable to write not found in dataset: {var_name!r}")
278+
var_name = getattr(vdef, "branded_variable_name", None)
279+
if var_name is None or var_name not in ds:
280+
raise KeyError(f"Variable to write not found in dataset: {var_name!r}")
278281
var_da = ds[str(var_name)]
282+
logger.info("found var_da with name: %s", var_da.name)
279283
var_dims = list(var_da.dims)
280284
alev_id = None
281285
plev_id = None
@@ -398,10 +402,10 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
398402
coord_vals=tvals,
399403
cell_bounds=tbnds if tbnds is not None else None,
400404
)
401-
logger.info("time axis id: %s", time_id)
405+
logger.info("time axis id: %s var_dims=%s", time_id, var_dims)
402406
# --- vertical: standard_hybrid_sigma ---
403407
levels = getattr(vdef, "levels", {}) or {}
404-
408+
logger.info("levels dict: %s", levels)
405409
if (levels.get("name") or "").lower() in {
406410
"standard_hybrid_sigma",
407411
"alevel",
@@ -555,6 +559,9 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
555559
elif "zl" in var_dims:
556560
ds[var_name] = ds[var_name].rename({"zl": "olevel"})
557561
var_dims = list(ds[var_name].dims)
562+
logger.info(
563+
"rename zl to olevel var_name %s var_dims=%s", var_name, var_dims
564+
)
558565
cmor.set_cur_dataset_attribute("vertical_label", "olevel")
559566
logger.info("*** Define olevel axis")
560567
values = ds["olevel"].values
@@ -567,9 +574,13 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
567574
coord_vals=np.asarray(values),
568575
cell_bounds=bnds,
569576
)
570-
elif "zi" in var_dims:
571-
ds[var_name] = ds[var_name].rename({"zi": "olevel"})
577+
elif "zl" in var_dims:
578+
logger.info("found zl axis in var_dims for variable %s", var_name)
579+
ds[var_name] = ds[var_name].rename({"zl": "olevel"})
572580
var_dims = list(ds[var_name].dims)
581+
logger.info(
582+
"rename zl to olevel var_name %s var_dims=%s", var_name, var_dims
583+
)
573584
cmor.set_cur_dataset_attribute("vertical_label", "olevel")
574585
logger.info("*** Define olevel axis")
575586
values = ds["olevel"].values
@@ -809,16 +820,27 @@ def write_variable(
809820
"Using CMOR table key: %s %s", self.tables_path, self.primarytable
810821
) # debug
811822
self.load_table(self.tables_path, self.primarytable)
812-
varname = getattr(cmip_var, "physical_parameter").name
813-
logger.info("Preparing to write variable: %s", varname) # debug
814-
data = ds[str(varname)]
823+
bvn_attr = getattr(cmip_var, "branded_variable_name", None)
824+
bvn = bvn_attr.name if bvn_attr is not None else None
825+
if bvn is None:
826+
# Fall back to vdef.name when no branded variable name is provided
827+
bvn = getattr(vdef, "name", None)
828+
if bvn is None:
829+
raise ValueError(
830+
"Cannot determine branded variable name: both "
831+
"`cmip_var.branded_variable_name` and `vdef.name` are missing."
832+
)
833+
if bvn not in ds:
834+
ds = ds.rename({vdef.name: bvn})
835+
logger.info("Preparing to write variable: %s", bvn) # debug
836+
data = ds[str(bvn)]
815837

816838
logger.info("Ensure fx variables are written and cached") # debug
817839
self.ensure_fx_written_and_cached(ds)
818840

819841
units = getattr(vdef, "units", "") or ""
820842
self.load_table(self.tables_path, self.primarytable)
821-
logger.info("Define CMOR axes for variable %s", vdef.name) # debug
843+
logger.info("Define CMOR axes for variable %s", bvn) # debug
822844
axes_ids = self._define_axes(ds, vdef)
823845
logger.info("Prepare data for CMOR %s", data.dtype) # debug
824846
data_filled, fillv = filled_for_cmor(data)
@@ -828,7 +850,7 @@ def write_variable(
828850
data_filled = data_filled.rename({"zi": "olevel"})
829851
self.load_table(self.tables_path, self.primarytable)
830852

831-
var_entry = getattr(cmip_var, "branded_variable_name", varname)
853+
var_entry = getattr(cmip_var, "branded_variable_name", bvn)
832854
if hasattr(var_entry, "name"):
833855
var_entry = var_entry.name
834856
elif hasattr(var_entry, "value"):

0 commit comments

Comments
 (0)