Skip to content

Commit 49e86bd

Browse files
authored
Merge pull request #12 from ESMCI/dev010726
Dev010726. Adds back _normalize_land_field.
2 parents 638dd77 + d78b134 commit 49e86bd

File tree

8 files changed

+210
-128
lines changed

8 files changed

+210
-128
lines changed

cmip7_prep/cmor_writer.py

Lines changed: 87 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,11 @@
88
"""
99

1010
from pathlib import Path
11-
import collections.abc
1211
import json
1312
import tempfile
14-
import types
15-
import warnings
1613

1714
from contextlib import AbstractContextManager
18-
from typing import Any, Sequence, Optional, Union
15+
from typing import Any, Optional, Union
1916
import datetime as dt
2017

2118
import logging
@@ -161,8 +158,10 @@ def __enter__(self) -> "CmorSession":
161158
tmp = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False)
162159
json.dump(cfg, tmp)
163160
tmp.close()
161+
if not Path(tmp.name).exists():
162+
raise FileNotFoundError(f"Temporary dataset_json not found: {tmp.name}")
164163
cmor.dataset_json(str(tmp.name))
165-
164+
logger.info("CMOR dataset_json loaded from: %s", tmp.name)
166165
try:
167166
prod = cmor.get_cur_dataset_attribute("product") # type: ignore[attr-defined]
168167
except Exception: # pylint: disable=broad-except
@@ -286,7 +285,18 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
286285
lev_id = None
287286

288287
logger.debug("[CMOR axis debug] var_dims: %s", var_dims)
289-
if "xh" in var_dims and "yh" in var_dims:
288+
has_latlon_dims = "lat" in var_dims and "lon" in var_dims
289+
has_latitude_dims = "latitude" in var_dims and "longitude" in var_dims
290+
has_mom6_dims = ("xh" in var_dims or "xq" in var_dims) and (
291+
"yh" in var_dims or "yq" in var_dims
292+
)
293+
if has_latitude_dims and not has_mom6_dims:
294+
raise ValueError(
295+
"Found 'latitude'/'longitude' dims without MOM6 grid; expected 'lat'/'lon'."
296+
)
297+
if ("xh" in var_dims or "xq" in var_dims) and (
298+
"yh" in var_dims or "yq" in var_dims
299+
):
290300
# MOM6/curvilinear grid: register xh/yh as generic axes (i/j), not as lat/lon
291301
# Define the native grid using the coordinate arrays
292302
logger.debug(
@@ -299,12 +309,22 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
299309
if not geo_path.exists():
300310
raise FileNotFoundError(f"Expected geometry file not found: {geo_path}")
301311
ds_geo = xr.open_dataset(geo_path)
302-
lat_vals_1d = ds_geo["lath"].values
303-
lon_raw = np.mod(ds_geo["lonh"].values, 360.0)
312+
if "xh" in var_dims:
313+
lon_raw = np.mod(ds_geo["lonh"].values, 360.0)
314+
else:
315+
lon_raw = np.mod(ds_geo["lonq"].values, 360.0)
304316
lon_bnds_raw = bounds_from_centers_1d(lon_raw, "lon")
305317
lon_vals_1d, lon_bnds, shift = roll_for_monotonic_with_bounds(
306318
lon_raw, lon_bnds_raw
307319
)
320+
if "yh" in var_dims:
321+
lat_vals_1d = ds_geo["lath"].values
322+
elif "yq" in var_dims:
323+
lat_vals_1d = ds_geo["latq"].values
324+
else:
325+
raise KeyError(
326+
"Expected 'yh' or 'yq' dimension for latitude not found."
327+
)
308328
# Fix first and last bounds to wrap correctly
309329
if lon_bnds.shape[0] > 1:
310330
# Ensure bounds are strictly increasing and wrap at dateline
@@ -314,6 +334,7 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
314334
lon_bnds[-1, 1] = 360.0
315335
# Also correct first upper bound to match the first cell
316336
lon_bnds[0, 1] = lon_bnds[1, 0]
337+
logger.info("[CMOR axis debug] corrected lon_bnds: %s", lon_bnds)
317338
# Print lon_bnds for a range (debug)
318339
i_id = cmor.axis(
319340
table_entry="latitude",
@@ -329,13 +350,24 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
329350
cell_bounds=lon_bnds,
330351
)
331352
axes_ids.extend([j_id, i_id])
332-
ds[var_name] = var_da.roll(xh=-shift, roll_coords=True)
333-
# rename dims xh and yh to longitude and latitude
334-
ds[var_name] = ds[var_name].rename({"xh": "longitude", "yh": "latitude"})
353+
for dim in ("xh", "xq", "yh", "yq"):
354+
if dim in var_dims:
355+
# rename dims xh and yh to longitude and latitude
356+
logger.info("[CMOR axis debug] renaming dim %s", dim)
357+
if dim in ["xh", "xq"]:
358+
if dim == "xh":
359+
ds[var_name] = var_da.roll(xh=-shift, roll_coords=True)
360+
elif dim == "xq":
361+
ds[var_name] = var_da.roll(xq=-shift, roll_coords=True)
362+
ds[var_name] = ds[var_name].rename({dim: "longitude"})
363+
if dim in ["yh", "yq"]:
364+
ds[var_name] = ds[var_name].rename({dim: "latitude"})
365+
335366
var_da = ds[var_name]
336367
var_dims = list(var_da.dims)
368+
337369
# --- horizontal axes (use CMOR names) ----
338-
elif "lat" in var_dims and "lon" in var_dims:
370+
elif has_latlon_dims:
339371
logger.info("*** Define horizontal axes")
340372
lat_vals, lat_bnds, _ = _get_1d_with_bounds(ds, "lat", "degrees_north")
341373
lon_vals, lon_bnds, _ = _get_1d_with_bounds(ds, "lon", "degrees_east")
@@ -375,6 +407,14 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
375407
"alevel",
376408
"alev",
377409
} or "lev" in var_dims:
410+
if (self.primarytable or "").lower() not in {
411+
"atmos",
412+
"atmoschem",
413+
"aerosol",
414+
}:
415+
raise ValueError(
416+
"Hybrid sigma coordinates are only supported for atmospheric tables."
417+
)
378418
# names in the native ds
379419
logger.info("*** Define hybrid sigma axis")
380420
hyam_name = levels.get("hyam", "hyam") # A mid (dimensionless)
@@ -481,16 +521,26 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
481521
cell_bounds=pb if pb is not None else None,
482522
)
483523
elif "sdepth" in var_dims:
524+
# Read sdepth values from ds as before
484525
values = ds["sdepth"].values
485526
logger.info("write sdepth axis")
486-
bnds = bounds_from_centers_1d(values, "sdepth")
487-
if bnds[0, 0] < 0:
488-
bnds[0, 0] = 0.0 # no negative soil depth bounds
527+
# Read depth_bnds from the NetCDF file in the data directory
528+
depth_bnds_path = Path(__file__).parent / "data" / "depth_bnds.nc"
529+
with xr.open_dataset(depth_bnds_path) as ds_bnds:
530+
depth_bnds = ds_bnds["depth_bnds"].values
531+
# Ensure depth_bnds matches the length of sdepth
532+
if depth_bnds.shape[0] > values.shape[0]:
533+
logger.warning(
534+
"Truncating depth_bnds from %d to %d levels to match sdepth",
535+
depth_bnds.shape[0],
536+
values.shape[0],
537+
)
538+
depth_bnds = depth_bnds[: values.shape[0], :]
489539
sdepth_id = cmor.axis(
490540
table_entry="sdepth",
491541
units="m",
492542
coord_vals=np.asarray(values),
493-
cell_bounds=bnds,
543+
cell_bounds=depth_bnds,
494544
)
495545
elif "z_l" in var_dims:
496546
values = ds["z_l"].values
@@ -517,7 +567,21 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
517567
coord_vals=np.asarray(values),
518568
cell_bounds=bnds,
519569
)
520-
570+
elif "zi" in var_dims:
571+
ds[var_name] = ds[var_name].rename({"zi": "olevel"})
572+
var_dims = list(ds[var_name].dims)
573+
cmor.set_cur_dataset_attribute("vertical_label", "olevel")
574+
logger.info("*** Define olevel axis")
575+
values = ds["olevel"].values
576+
logger.info("write olevel axis")
577+
zl = ds["zl"].values
578+
bnds = np.column_stack((zl[:-1], zl[1:]))
579+
lev_id = cmor.axis(
580+
table_entry="depth_coord",
581+
units="m",
582+
coord_vals=np.asarray(values),
583+
cell_bounds=bnds,
584+
)
521585
# Map dimension names to axis IDs
522586
dim_to_axis = {
523587
"time": time_id,
@@ -533,10 +597,13 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
533597
"longitude": lon_id if lon_id is not None else j_id,
534598
"xh": lon_id, # MOM6
535599
"yh": lat_id, # MOM6
600+
"xq": lon_id, # MOM6
601+
"yq": lat_id, # MOM6
536602
}
537603
axes_ids = []
538604
for d in var_dims:
539605
axis_id = dim_to_axis.get(d)
606+
logger.info("[CMOR axis debug] dim '%s' → axis_id: %s", d, axis_id)
540607
if axis_id is None:
541608
raise KeyError(
542609
f"No axis ID found for dimension '{d}' in variable '{var_name}' {var_dims}"
@@ -757,6 +824,8 @@ def write_variable(
757824
data_filled, fillv = filled_for_cmor(data)
758825
if "zl" in data_filled.dims:
759826
data_filled = data_filled.rename({"zl": "olevel"})
827+
elif "zi" in data_filled.dims:
828+
data_filled = data_filled.rename({"zi": "olevel"})
760829
self.load_table(self.tables_path, self.primarytable)
761830

762831
var_entry = getattr(cmip_var, "branded_variable_name", varname)
@@ -795,6 +864,7 @@ def write_variable(
795864
np.asarray(data_filled),
796865
ntimes_passed=nt,
797866
)
867+
logger.info("Finished writing CMOR variable %s", var_id) # debug
798868
# ---- Hybrid ps streaming (if present) ----
799869
if self._pending_ps is not None:
800870
ps_id, ps_da = self._pending_ps
@@ -816,29 +886,3 @@ def write_variable(
816886
self._pending_ps = None
817887

818888
cmor.close(var_id)
819-
820-
def write_variables(
821-
self,
822-
ds: xr.Dataset,
823-
cmip_vars: Sequence[str],
824-
mapping: collections.abc.Mapping,
825-
) -> None:
826-
"""Write multiple CMIP variables from one dataset."""
827-
for v in cmip_vars:
828-
cfg = mapping.get_cfg(v) or {}
829-
table = cfg.get("table", "Amon")
830-
units = cfg.get("units", "")
831-
positive = cfg.get("positive") or None
832-
vdef = types.SimpleNamespace(
833-
name=v,
834-
table=table,
835-
realm=table,
836-
units=units,
837-
positive=positive,
838-
)
839-
# pylint: disable=broad-exception-caught
840-
try:
841-
self.write_variable(ds, v, vdef)
842-
except Exception as e:
843-
warnings.warn(f"[cmor] skipping {v} due to error: {e}", RuntimeWarning)
844-
# continue to next variable

0 commit comments

Comments
 (0)