88"""
99
1010from pathlib import Path
11- import collections .abc
1211import json
1312import tempfile
14- import types
15- import warnings
1613
1714from contextlib import AbstractContextManager
18- from typing import Any , Sequence , Optional , Union
15+ from typing import Any , Optional , Union
1916import datetime as dt
2017
2118import logging
@@ -161,8 +158,10 @@ def __enter__(self) -> "CmorSession":
         tmp = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False)
         json.dump(cfg, tmp)
         tmp.close()
+        if not Path(tmp.name).exists():
+            raise FileNotFoundError(f"Temporary dataset_json not found: {tmp.name}")
         cmor.dataset_json(str(tmp.name))
-
+        logger.info("CMOR dataset_json loaded from: %s", tmp.name)
         try:
             prod = cmor.get_cur_dataset_attribute("product")  # type: ignore[attr-defined]
         except Exception:  # pylint: disable=broad-except
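Note for readers: the NamedTemporaryFile is created with delete=False, so the JSON stays on disk after cmor.dataset_json() has read it. A minimal cleanup sketch, assuming CMOR parses the file during the call and keeps no handle on it afterwards (the helper name load_dataset_json is illustrative, not part of the module):

import json
import tempfile
from pathlib import Path

import cmor

def load_dataset_json(cfg: dict) -> None:
    """Write cfg to a temporary JSON, register it with CMOR, then remove the file."""
    tmp = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False)
    try:
        json.dump(cfg, tmp)
        tmp.close()
        cmor.dataset_json(str(tmp.name))  # assumption: CMOR parses the file during this call
    finally:
        Path(tmp.name).unlink(missing_ok=True)  # remove the temp file once registered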
@@ -286,7 +285,18 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
         lev_id = None
 
         logger.debug("[CMOR axis debug] var_dims: %s", var_dims)
-        if "xh" in var_dims and "yh" in var_dims:
+        has_latlon_dims = "lat" in var_dims and "lon" in var_dims
+        has_latitude_dims = "latitude" in var_dims and "longitude" in var_dims
+        has_mom6_dims = ("xh" in var_dims or "xq" in var_dims) and (
+            "yh" in var_dims or "yq" in var_dims
+        )
+        if has_latitude_dims and not has_mom6_dims:
+            raise ValueError(
+                "Found 'latitude'/'longitude' dims without MOM6 grid; expected 'lat'/'lon'."
+            )
+        if ("xh" in var_dims or "xq" in var_dims) and (
+            "yh" in var_dims or "yq" in var_dims
+        ):
             # MOM6/curvilinear grid: register xh/yh as generic axes (i/j), not as lat/lon
             # Define the native grid using the coordinate arrays
             logger.debug(
@@ -299,12 +309,22 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
             if not geo_path.exists():
                 raise FileNotFoundError(f"Expected geometry file not found: {geo_path}")
             ds_geo = xr.open_dataset(geo_path)
-            lat_vals_1d = ds_geo["lath"].values
-            lon_raw = np.mod(ds_geo["lonh"].values, 360.0)
+            if "xh" in var_dims:
+                lon_raw = np.mod(ds_geo["lonh"].values, 360.0)
+            else:
+                lon_raw = np.mod(ds_geo["lonq"].values, 360.0)
             lon_bnds_raw = bounds_from_centers_1d(lon_raw, "lon")
             lon_vals_1d, lon_bnds, shift = roll_for_monotonic_with_bounds(
                 lon_raw, lon_bnds_raw
             )
+            if "yh" in var_dims:
+                lat_vals_1d = ds_geo["lath"].values
+            elif "yq" in var_dims:
+                lat_vals_1d = ds_geo["latq"].values
+            else:
+                raise KeyError(
+                    "Expected 'yh' or 'yq' dimension for latitude not found."
+                )
             # Fix first and last bounds to wrap correctly
             if lon_bnds.shape[0] > 1:
                 # Ensure bounds are strictly increasing and wrap at dateline
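This hunk relies on bounds_from_centers_1d, whose body is not shown in the diff. A rough sketch of the midpoint construction such a helper presumably performs; this is an assumption, and the real implementation may treat the "lon"/"lat" hint and the edge cells differently:

import numpy as np

def bounds_from_centers_1d_sketch(centers: np.ndarray, kind: str) -> np.ndarray:
    """Return (n, 2) cell bounds: midpoints between neighbours, edges extrapolated."""
    centers = np.asarray(centers, dtype=float)
    mid = 0.5 * (centers[:-1] + centers[1:])
    first = centers[0] - (mid[0] - centers[0])
    last = centers[-1] + (centers[-1] - mid[-1])
    edges = np.concatenate(([first], mid, [last]))
    if kind == "lat":
        edges = np.clip(edges, -90.0, 90.0)  # keep latitude bounds physical
    return np.column_stack((edges[:-1], edges[1:]))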
@@ -314,6 +334,7 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
                 lon_bnds[-1, 1] = 360.0
                 # Also correct first upper bound to match the first cell
                 lon_bnds[0, 1] = lon_bnds[1, 0]
+            logger.info("[CMOR axis debug] corrected lon_bnds: %s", lon_bnds)
             # Print lon_bnds for a range (debug)
             i_id = cmor.axis(
                 table_entry="latitude",
@@ -329,13 +350,24 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
                 cell_bounds=lon_bnds,
             )
             axes_ids.extend([j_id, i_id])
-            ds[var_name] = var_da.roll(xh=-shift, roll_coords=True)
-            # rename dims xh and yh to longitude and latitude
-            ds[var_name] = ds[var_name].rename({"xh": "longitude", "yh": "latitude"})
+            for dim in ("xh", "xq", "yh", "yq"):
+                if dim in var_dims:
+                    # rename dims xh and yh to longitude and latitude
+                    logger.info("[CMOR axis debug] renaming dim %s", dim)
+                    if dim in ["xh", "xq"]:
+                        if dim == "xh":
+                            ds[var_name] = var_da.roll(xh=-shift, roll_coords=True)
+                        elif dim == "xq":
+                            ds[var_name] = var_da.roll(xq=-shift, roll_coords=True)
+                        ds[var_name] = ds[var_name].rename({dim: "longitude"})
+                    if dim in ["yh", "yq"]:
+                        ds[var_name] = ds[var_name].rename({dim: "latitude"})
+
             var_da = ds[var_name]
             var_dims = list(var_da.dims)
+
         # --- horizontal axes (use CMOR names) ----
-        elif "lat" in var_dims and "lon" in var_dims:
+        elif has_latlon_dims:
             logger.info("*** Define horizontal axes")
             lat_vals, lat_bnds, _ = _get_1d_with_bounds(ds, "lat", "degrees_north")
             lon_vals, lon_bnds, _ = _get_1d_with_bounds(ds, "lon", "degrees_east")
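roll_for_monotonic_with_bounds is likewise only called, not defined, in this diff. A sketch of the presumed behaviour, assuming it rolls the wrapped longitudes until they increase monotonically and returns the applied shift so the data arrays can be rolled by the same amount (matching var_da.roll(xh=-shift) above):

import numpy as np

def roll_for_monotonic_sketch(lon: np.ndarray, lon_bnds: np.ndarray):
    """Roll longitudes (already wrapped to [0, 360)) so they increase monotonically.

    Returns the rolled centres, the rolled bounds, and the shift that was applied.
    """
    # The first decreasing step marks where the 0/360 seam sits.
    breaks = np.where(np.diff(lon) < 0)[0]
    shift = int(breaks[0] + 1) if breaks.size else 0
    return np.roll(lon, -shift), np.roll(lon_bnds, -shift, axis=0), shift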
@@ -375,6 +407,14 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
             "alevel",
             "alev",
         } or "lev" in var_dims:
+            if (self.primarytable or "").lower() not in {
+                "atmos",
+                "atmoschem",
+                "aerosol",
+            }:
+                raise ValueError(
+                    "Hybrid sigma coordinates are only supported for atmospheric tables."
+                )
             # names in the native ds
             logger.info("*** Define hybrid sigma axis")
             hyam_name = levels.get("hyam", "hyam")  # A mid (dimensionless)
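For context on the hybrid sigma axis defined here: the CF atmosphere_hybrid_sigma_pressure_coordinate gives full-level pressure as p(k) = a(k)*p0 + b(k)*ps, which is what the hyam/hybm coefficients and the streamed ps feed into. A small numpy sketch, assuming hyam/hybm are dimensionless, p0 is in Pa, and ps is a 2-D (lat, lon) field:

import numpy as np

def hybrid_sigma_pressure(hyam, hybm, ps, p0=100000.0):
    """Full-level pressure (Pa): p[k, j, i] = hyam[k] * p0 + hybm[k] * ps[j, i]."""
    a = np.asarray(hyam)[:, None, None]   # (lev, 1, 1)
    b = np.asarray(hybm)[:, None, None]
    return a * p0 + b * np.asarray(ps)[None, :, :]   # (lev, lat, lon)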
@@ -481,16 +521,26 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
                 cell_bounds=pb if pb is not None else None,
             )
         elif "sdepth" in var_dims:
+            # Read sdepth values from ds as before
             values = ds["sdepth"].values
             logger.info("write sdepth axis")
-            bnds = bounds_from_centers_1d(values, "sdepth")
-            if bnds[0, 0] < 0:
-                bnds[0, 0] = 0.0  # no negative soil depth bounds
+            # Read depth_bnds from the NetCDF file in the data directory
+            depth_bnds_path = Path(__file__).parent / "data" / "depth_bnds.nc"
+            with xr.open_dataset(depth_bnds_path) as ds_bnds:
+                depth_bnds = ds_bnds["depth_bnds"].values
+            # Ensure depth_bnds matches the length of sdepth
+            if depth_bnds.shape[0] > values.shape[0]:
+                logger.warning(
+                    "Truncating depth_bnds from %d to %d levels to match sdepth",
+                    depth_bnds.shape[0],
+                    values.shape[0],
+                )
+                depth_bnds = depth_bnds[: values.shape[0], :]
             sdepth_id = cmor.axis(
                 table_entry="sdepth",
                 units="m",
                 coord_vals=np.asarray(values),
-                cell_bounds=bnds,
+                cell_bounds=depth_bnds,
             )
         elif "z_l" in var_dims:
             values = ds["z_l"].values
@@ -517,7 +567,21 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
                 coord_vals=np.asarray(values),
                 cell_bounds=bnds,
             )
-
+        elif "zi" in var_dims:
+            ds[var_name] = ds[var_name].rename({"zi": "olevel"})
+            var_dims = list(ds[var_name].dims)
+            cmor.set_cur_dataset_attribute("vertical_label", "olevel")
+            logger.info("*** Define olevel axis")
+            values = ds["olevel"].values
+            logger.info("write olevel axis")
+            zl = ds["zl"].values
+            bnds = np.column_stack((zl[:-1], zl[1:]))
+            lev_id = cmor.axis(
+                table_entry="depth_coord",
+                units="m",
+                coord_vals=np.asarray(values),
+                cell_bounds=bnds,
+            )
         # Map dimension names to axis IDs
         dim_to_axis = {
             "time": time_id,
@@ -533,10 +597,13 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
             "longitude": lon_id if lon_id is not None else j_id,
             "xh": lon_id,  # MOM6
             "yh": lat_id,  # MOM6
+            "xq": lon_id,  # MOM6
+            "yq": lat_id,  # MOM6
         }
         axes_ids = []
         for d in var_dims:
             axis_id = dim_to_axis.get(d)
+            logger.info("[CMOR axis debug] dim '%s' → axis_id: %s", d, axis_id)
             if axis_id is None:
                 raise KeyError(
                     f"No axis ID found for dimension '{d}' in variable '{var_name}' {var_dims}"
@@ -757,6 +824,8 @@ def write_variable(
         data_filled, fillv = filled_for_cmor(data)
         if "zl" in data_filled.dims:
             data_filled = data_filled.rename({"zl": "olevel"})
+        elif "zi" in data_filled.dims:
+            data_filled = data_filled.rename({"zi": "olevel"})
         self.load_table(self.tables_path, self.primarytable)
 
         var_entry = getattr(cmip_var, "branded_variable_name", varname)
@@ -795,6 +864,7 @@ def write_variable(
             np.asarray(data_filled),
             ntimes_passed=nt,
         )
+        logger.info("Finished writing CMOR variable %s", var_id)  # debug
         # ---- Hybrid ps streaming (if present) ----
         if self._pending_ps is not None:
             ps_id, ps_da = self._pending_ps
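The _pending_ps block that follows streams the surface-pressure formula term into the same file as the main variable. For readers unfamiliar with the pattern, this is roughly how CMOR's store_with keyword is used; a generic sketch, not the exact code of this module:

import numpy as np
import cmor

def write_with_ps(var_id: int, ps_id: int, data: np.ndarray, ps: np.ndarray, nt: int) -> None:
    """Write a hybrid-sigma variable, then its ps formula term into the same file."""
    cmor.write(var_id, data, ntimes_passed=nt)
    # store_with ties the ps slab to the file already opened for var_id.
    cmor.write(ps_id, ps, ntimes_passed=nt, store_with=var_id)
    cmor.close(var_id)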
@@ -816,29 +886,3 @@ def write_variable(
             self._pending_ps = None
 
         cmor.close(var_id)
-
-    def write_variables(
-        self,
-        ds: xr.Dataset,
-        cmip_vars: Sequence[str],
-        mapping: collections.abc.Mapping,
-    ) -> None:
-        """Write multiple CMIP variables from one dataset."""
-        for v in cmip_vars:
-            cfg = mapping.get_cfg(v) or {}
-            table = cfg.get("table", "Amon")
-            units = cfg.get("units", "")
-            positive = cfg.get("positive") or None
-            vdef = types.SimpleNamespace(
-                name=v,
-                table=table,
-                realm=table,
-                units=units,
-                positive=positive,
-            )
-            # pylint: disable=broad-exception-caught
-            try:
-                self.write_variable(ds, v, vdef)
-            except Exception as e:
-                warnings.warn(f"[cmor] skipping {v} due to error: {e}", RuntimeWarning)
-                # continue to next variable