Skip to content

Commit 3f0cc34

Browse files
Pint Addition (reopened with respect to main) (COSIMA#261)
* Retry * black formatting * more black formatting * new black formatting * ForOM2 * Black * New black grr * add a comment to explain why we keep manual conversion --------- Co-authored-by: Ashley Barnes <53282288+ashjbarnes@users.noreply.github.com> Co-authored-by: ashjbarnes <ashjbarnes97@gmail.com>
1 parent 81e3489 commit 3f0cc34

File tree

9 files changed

+151
-18
lines changed

9 files changed

+151
-18
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ dependencies = [
1717
"xarray",
1818
"xesmf >= 0.8.4",
1919
"f90nml >= 1.4.1",
20-
"copernicusmarine >= 2.0.0,<2.1.0"
20+
"copernicusmarine >= 2.0.0,<2.1.0",
21+
"pint_xarray"
2122
]
2223

2324
[build-system]
@@ -31,6 +32,7 @@ packages = ["regional_mom6", "regional_mom6.demos.premade_run_directories"]
3132
"regional_mom6.demos" = "demos"
3233

3334
[tool.setuptools.package-data]
35+
"regional_mom6" = ["*.txt"] # include rm6_unit_defs.txt here
3436
"regional_mom6.demos.premade_run_directories" = ["**/*"]
3537

3638
[tool.setuptools_scm]

regional_mom6/MOM_parameter_tools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def change_MOM_parameter(
3535
MOM_override_dict = read_MOM_file_as_dict("MOM_override", directory)
3636
original_val = "No original val"
3737
if not delete:
38-
3938
if param_name in MOM_override_dict:
4039
original_val = MOM_override_dict[param_name]["value"]
4140
print(

regional_mom6/regional_mom6.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
ep2ap,
2626
rotate,
2727
find_files_by_pattern,
28+
try_pint_convert,
2829
)
2930

30-
3131
warnings.filterwarnings("ignore")
3232

3333
__all__ = [
@@ -39,6 +39,14 @@
3939
"get_glorys_data",
4040
]
4141

42+
# If the input array can be quantified with pint, ensure the main fields (eta, u, v, temp) are in the expected units.
43+
# Salinity and BGC tracers are more abstract and are assumed to already be in the correct units. TODO: add functionality to convert these tracers.
44+
main_field_target_units = {
45+
"eta": "m",
46+
"u": "m/s",
47+
"v": "m/s",
48+
"temp": "degC",
49+
}
4250

4351
## Mapping Functions
4452

@@ -135,7 +143,6 @@ def longitude_slicer(data, longitude_extent, longitude_coords):
135143

136144
for i in range(-1, 2, 1):
137145
if data[lon][0] <= central_longitude + 360 * i <= data[lon][-1]:
138-
139146
## Shifted version of target midpoint; e.g., could be -90 vs 270
140147
## integer i keeps track of how many multiples of 360 we need to shift entire
141148
## grid by to match central_longitude
@@ -231,11 +238,9 @@ def get_glorys_data(
231238

232239
file = open(Path(path / "get_glorys_data.sh"), "w")
233240

234-
lines.append(
235-
f"""
241+
lines.append(f"""
236242
copernicusmarine subset --dataset-id cmems_mod_glo_phy_my_0.083deg_P1D-m --variable so --variable thetao --variable uo --variable vo --variable zos --start-datetime {str(timerange[0]).replace(" ","T")} --end-datetime {str(timerange[1]).replace(" ","T")} --minimum-longitude {longitude_extent[0] - buffer} --maximum-longitude {longitude_extent[1] + buffer} --minimum-latitude {latitude_extent[0] - buffer} --maximum-latitude {latitude_extent[1] + buffer} --minimum-depth 0 --maximum-depth 6000 -o {str(path)} -f {segment_name}.nc\n
237-
"""
238-
)
243+
""")
239244
file.writelines(lines)
240245
file.close()
241246
return Path(path / "get_glorys_data.sh")
@@ -607,7 +612,6 @@ def __init__(
607612
regridding_method="bilinear",
608613
fill_method=rgd.fill_missing_data,
609614
):
610-
611615
# Creates an empty experiment object for testing and experienced user manipulation.
612616
if create_empty:
613617
return
@@ -830,7 +834,6 @@ def bathymetry_path(self):
830834
return "Not Found"
831835

832836
def __getattr__(self, name):
833-
834837
## First, check whether the attribute is an input file
835838
if "segment" in name:
836839
try:
@@ -904,7 +907,6 @@ def _make_hgrid(self):
904907
), "only even_spacing grid type is implemented"
905908

906909
if self.hgrid_type == "even_spacing":
907-
908910
# longitudes are evenly spaced based on resolution and bounds
909911
nx = int(
910912
(self.longitude_extent[1] - self.longitude_extent[0])
@@ -1023,6 +1025,22 @@ def setup_initial_condition(
10231025
if type(reprocessed_var_map["depth_coord"]) == list:
10241026
reprocessed_var_map["depth_coord"] = reprocessed_var_map["depth_coord"][0]
10251027

1028+
# Convert zdim if possible & needed
1029+
ic_raw[reprocessed_var_map["depth_coord"]] = try_pint_convert(
1030+
ic_raw[reprocessed_var_map["depth_coord"]],
1031+
"m",
1032+
reprocessed_var_map["depth_coord"],
1033+
)
1034+
1035+
# Convert values
1036+
for var in main_field_target_units:
1037+
if var == "temp" or var == "salt":
1038+
value_name = reprocessed_var_map["tracer_var_names"][var]
1039+
else:
1040+
value_name = reprocessed_var_map[var + "_var_name"]
1041+
ic_raw[value_name] = try_pint_convert(
1042+
ic_raw[value_name], main_field_target_units[var], var
1043+
)
10261044
# Remove time dimension if present in the IC.
10271045
# Assume that the first time dim is the intended one if more than one is present
10281046

@@ -1060,6 +1078,8 @@ def setup_initial_condition(
10601078

10611079
## if min(temperature) > 100 then assume that units must be degrees K
10621080
## (otherwise we can't be on Earth) and convert to degrees C
1081+
## Although we now attempt a pint convert, we're leaving this manual conversion in for now
1082+
## just in case, as K->C is absolutely necessary, and for some inputs pint may fail where this won't.
10631083
if np.nanmin(ic_raw[reprocessed_var_map["tracer_var_names"]["temp"]]) > 100:
10641084
ic_raw[reprocessed_var_map["tracer_var_names"]["temp"]] -= 273.15
10651085
ic_raw[reprocessed_var_map["tracer_var_names"]["temp"]].attrs[
@@ -2377,7 +2397,6 @@ def setup_run_directory(
23772397
)
23782398
# Tides OBC adjustments
23792399
if with_tides:
2380-
23812400
# Include internal tide forcing
23822401
MOM_override_dict["TIDES"]["value"] = "True"
23832402

@@ -2655,6 +2674,14 @@ def regrid_velocity_tracers(
26552674

26562675
coords = rgd.coords(self.hgrid, self.orientation, self.segment_name)
26572676

2677+
# Convert z coordinates to meters if pint-enabled
2678+
if type(reprocessed_var_map["depth_coord"]) != list:
2679+
dc_list = [reprocessed_var_map["depth_coord"]]
2680+
else:
2681+
dc_list = reprocessed_var_map["depth_coord"]
2682+
for dc in dc_list:
2683+
rawseg[dc] = try_pint_convert(rawseg[dc], "m", dc)
2684+
26582685
regridders = create_vt_regridders(
26592686
reprocessed_var_map,
26602687
rawseg,
@@ -2789,6 +2816,17 @@ def regrid_velocity_tracers(
27892816
## Rename each variable in dataset
27902817
segment_out = segment_out.rename({allfields[var]: v})
27912818

2819+
# Try Pint Conversion
2820+
if var in main_field_target_units:
2821+
# Apply raw data units if they exist
2822+
units = rawseg[allfields[var]].attrs.get("units")
2823+
if units is not None:
2824+
segment_out[v].attrs["units"] = units
2825+
2826+
segment_out[v] = try_pint_convert(
2827+
segment_out[v], main_field_target_units[var], var
2828+
)
2829+
27922830
# Find out if the tracer has depth, and if so, what its z dimension is (z dimension being a list is an edge case for MARBL BGC)
27932831
variable_has_depth = False
27942832
depth_coord = None
@@ -3093,7 +3131,6 @@ def encode_tidal_files_and_output(self, ds, filename):
30933131
## Expand Tidal Dimensions ##
30943132

30953133
for var in ds:
3096-
30973134
ds = rgd.add_secondary_dimension(ds, str(var), coords, self.segment_name)
30983135

30993136
## Rename Tidal Dimensions ##

regional_mom6/regridding.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,6 @@ def add_secondary_dimension(
381381
)
382382
insert_behind_by = 0
383383
if not to_beginning:
384-
385384
if any(
386385
coord.startswith("nz") or coord == "constituent" for coord in ds[var].dims
387386
):
@@ -579,7 +578,6 @@ def mask_dataset(
579578
mask = mask[np.newaxis, :]
580579

581580
for var in ds.data_vars.keys():
582-
583581
# Drop to just the Boundary Dim
584582
da = ds[var].isel({dim: 0 for dim in list(ds[var].dims)[:-2]}).squeeze()
585583

regional_mom6/rm6_unit_defs.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Insert Custom NetCDF Units Here
2+
3+
# Units defs for CESM output (degrees are dimensionless, so 1)
4+
degrees_east = 1
5+
degrees_north = 1
6+
degrees_N = 1

regional_mom6/rotation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def modulo_around_point(x, x0, L):
121121
if L <= 0:
122122
return x
123123
else:
124-
125124
# Find that boundary point x0 + L/2
126125
edge_indexes = np.where((x == x0 + L / 2))
127126

regional_mom6/utils.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,97 @@
44
import xarray as xr
55
from regional_mom6 import regridding as rgd
66
from pathlib import Path
7+
import pint
8+
import pint_xarray
9+
import importlib.resources
10+
11+
# from pint_xarray.errors import PintExceptionGroup # This is only supported when pint_xarray is 0.6.0, which is not currently supported in the CI
12+
13+
14+
# Handle Unit Registry (only done once)
15+
ureg = pint.UnitRegistry(
16+
force_ndarray_like=True
17+
) # The force option is required for pint_xarray
18+
19+
unit_path = Path(importlib.resources.files("regional_mom6") / "rm6_unit_defs.txt")
20+
ureg.load_definitions(unit_path)
21+
22+
23+
def try_pint_convert(da, target_units, var_name=None, debug=False):
    """
    Best-effort conversion of an xarray DataArray to ``target_units`` using Pint.

    Steps:
    1. Determine the source units.
        - If the DataArray already wraps a Pint Quantity, it is used as-is and
          its units are read from the Quantity (no second quantification).
        - Otherwise the 'units' attribute is required, and the DataArray is
          quantified with the module-level unit registry.
    2. If the source and target unit strings differ, convert with Pint's
       ``.to()`` and dequantify afterwards to return a plain xarray DataArray.
    3. If any step fails (missing units, invalid units, etc.), log a warning
       and return the original DataArray unchanged. No exception is propagated
       to the caller.

    Parameters
    ----------
    da : xarray.DataArray
        The DataArray to quantify and/or convert.
    target_units : str
        Units to convert the DataArray to.
    var_name : str, optional
        Name of the variable (used for logging messages).
    debug : bool, optional
        If True, print debug information about the error with subexceptions.

    Returns
    -------
    xarray.DataArray
        A DataArray with units converted if successful; otherwise the original.
    """
    try:
        if isinstance(da.data, pint.Quantity):
            # Already quantified: reuse the array instead of quantifying again.
            # Previously this path left `da_quantified` unbound, so the conversion
            # below failed with a NameError that the broad `except` swallowed.
            da_quantified = da
            source_units = str(da.pint.units)
        else:
            # Get the units string from the DataArray attributes
            source_units = da.attrs.get("units", None)
            if not source_units:
                raise ValueError(
                    f"DataArray '{var_name}' has no units; cannot quantify."
                )
            da_quantified = da.pint.quantify(unit_registry=ureg)

        # NOTE: once pint_xarray >= 0.6.0 is available in the CI, PintExceptionGroup
        # could be caught around the quantify step to report every quantification
        # error at once. Until then, the `debug` branch below walks `e.exceptions`
        # when the raised error carries sub-exceptions.

        # Convert to the target units if they differ
        if source_units != target_units:
            da_converted = da_quantified.pint.to(target_units).pint.dequantify()
            utils_logger.warning(
                f"Converted {var_name} from {source_units} to {target_units}"
            )
            return da_converted

        utils_logger.info(f"Units for {var_name} did not need to be converted")

    except Exception as e:
        # Deliberate best-effort: any failure (bad units, missing units, ...)
        # falls back to the unconverted input rather than aborting the caller.
        utils_logger.warning(
            f"regional_mom6 could not use pint for data array {var_name}, assuming it's in the correct units"
        )
        if debug:
            if hasattr(e, "exceptions"):
                for i, exc in enumerate(e.exceptions):
                    print(f"\n--- Sub-exception {i} ---")
                    print(type(exc).__name__, exc)
            else:
                print(e)

    # Return the original DataArray if no conversion was needed or it failed
    return da
798

899

9100
def vecdot(v1, v2):
@@ -432,3 +523,6 @@ def get_edge(ds, edge, x_name=None, y_name=None):
432523
return ds.isel({x_name: -1})
433524
if edge == "west":
434525
return ds.isel({x_name: 0})
526+
527+
528+
utils_logger = setup_logger(__name__, set_handler=False)

tests/test_config_class.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def test_write_config(create_expt, tmp_path):
101101

102102

103103
def test_read_config(create_expt, tmp_path):
104-
105104
expt = create_expt
106105
path = tmp_path / "testing_config.json"
107106
Config.save_to_json(expt, path)

tests/test_expt_class.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ def test_rectangular_boundaries(
249249
hgrid_type,
250250
tmp_path,
251251
):
252-
253252
eastern_boundary = xr.Dataset(
254253
{
255254
"temp": xr.DataArray(

0 commit comments

Comments
 (0)