Skip to content

Commit beb4e29

Browse files
committed
bleh
1 parent dd7c770 commit beb4e29

File tree

2 files changed

+211
-20
lines changed

2 files changed

+211
-20
lines changed

regional_mom6/validate.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,56 @@
66

77
from pathlib import Path
88
import xarray as xr
9+
import numpy as np
10+
import re
911
from .utils import setup_logger
1012

1113
logger = setup_logger(__name__)
1214

1315

1416
def get_file(file: Path | xr.Dataset):
1517
"""accept a filepath or xarray dataset and return the xarray dataset"""
16-
if type(file) == xr.Dataset:
18+
if isinstance(file, xr.Dataset):
1719
return file
1820
else:
1921
return xr.open_dataset(file)
2022

2123

2224
def check(condition, warning):
23-
condition or logger.warn(warning)
25+
if not condition:
26+
logger.warning(warning)
27+
return condition
2428

2529

2630
# Individual validation rule functions
2731
def _check_fill_value(da: xr.DataArray):
2832
"""Check that fill values are set correctly"""
29-
check("_FillValue" in da.attrs, f"{var_name} does not have a FillValue attribute")
30-
31-
check(
32-
not np.isnan(da.attrs["_FillValue"]),
33-
f"Fill Value for variable {var_name} is NaN (normally not wanted)",
33+
condition = check(
34+
"_FillValue" in da.attrs, f"{da.name} does not have a FillValue attribute"
3435
)
3536

37+
if condition:
38+
check(
39+
not np.isnan(da.attrs["_FillValue"]),
40+
f"Fill Value for variable {da.name} is NaN (normally not wanted)",
41+
)
42+
3643

3744
def _check_coordinates(ds: xr.Dataset, var_name: str):
38-
"""Check that missing values are set correctly"""
45+
"""Check that coordinates attribute exists and all listed coordinates are present in the dataset"""
3946

4047
assert var_name in ds
41-
check(
48+
condition = check(
4249
"coordinates" in ds[var_name].attrs,
4350
f"{var_name} does not have a coordinates attribute",
4451
)
45-
46-
coordinates = ds[var_name].attrs["coordinates"]
47-
coordinates = coordinates.strip(" ")
48-
for coord in coordinates:
49-
check(coord in ds, f"Coordinate {coord} for variable {var_name} does not exist")
52+
if condition:
53+
coordinates = ds[var_name].attrs["coordinates"].strip()
54+
for coord in coordinates.split():
55+
check(
56+
coord in ds,
57+
f"Coordinate {coord} for variable {var_name} does not exist",
58+
)
5059

5160

5261
def _check_required_dimensions(da: xr.DataArray, surface=False):
@@ -60,12 +69,13 @@ def _check_required_dimensions(da: xr.DataArray, surface=False):
6069

6170

6271
def validate_obc_file(
63-
file: Path | xr.Dataset, variable_names: list, encoding_dict={}, surface_var="eta"
72+
file: Path | xr.Dataset, variable_names: list, encoding_dict=None, surface_var="eta"
6473
):
6574
"""Validate boundary condition file specifically (requires additional segment number validation)"""
75+
if encoding_dict is None:
76+
encoding_dict = {}
6677
ds = get_file(file)
6778

68-
# Check individual data variable specifications (nothing that starts with dz)
6979
print(
7080
"This function identifies variables by if they have the word 'segment' in the name and don't start with nz,dz,lon,lat."
7181
)
@@ -79,12 +89,12 @@ def validate_obc_file(
7989
)
8090
check(
8191
"segment" in var,
82-
f"Variable {var} does not end with a 3 digit number. OBC file variables must end with a number",
92+
f"Variable {var} does not contain 'segment'. OBC file variables must include 'segment'",
8393
)
8494

8595
# Add encodings
8696
if var in encoding_dict:
87-
for key, value in encoding_dict[var].item():
97+
for key, value in encoding_dict[var].items():
8898
ds[var].attrs[key] = value
8999

90100
# Check if there is a non-NaN fill value
@@ -94,7 +104,7 @@ def validate_obc_file(
94104
_check_coordinates(ds, var_name=var)
95105

96106
# Check the correct number of dimensions
97-
_check_required_dimensions(ds[var], surface=(var == surface_var)) # just two
107+
_check_required_dimensions(ds[var], surface=(var == surface_var))
98108

99109
# Check for thickness variable
100110
if var != surface_var:
@@ -105,4 +115,4 @@ def validate_obc_file(
105115

106116

107117
def ends_with_3_digits(s: str) -> bool:
108-
return bool(re.search(r"\d{3}$", s))
118+
return bool(re.search(r"_\d{3}$", s))

tests/test_validate.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""
2+
Test suite for the validate module
3+
"""
4+
5+
import pytest
6+
import xarray as xr
7+
import numpy as np
8+
from pathlib import Path
9+
from regional_mom6.validate import (
10+
_check_fill_value,
11+
_check_coordinates,
12+
_check_required_dimensions,
13+
ends_with_3_digits,
14+
validate_obc_file,
15+
)
16+
17+
18+
# _check_fill_value tests
19+
def test_check_fill_value_valid(caplog):
20+
"""DataArray with valid fill value logs no warnings"""
21+
da = xr.DataArray(
22+
[1, 2, 3], dims="x", name="temperature", attrs={"_FillValue": -999.0}
23+
)
24+
_check_fill_value(da)
25+
assert "FillValue" not in caplog.text
26+
27+
28+
def test_check_fill_value_missing(caplog):
29+
"""DataArray without _FillValue attribute logs warning"""
30+
da = xr.DataArray([1, 2, 3], dims="x", name="temperature")
31+
_check_fill_value(da)
32+
assert "FillValue" in caplog.text
33+
34+
35+
def test_check_fill_value_nan(caplog):
36+
"""DataArray with NaN fill value logs warning"""
37+
da = xr.DataArray(
38+
[1, 2, 3], dims="x", name="temperature", attrs={"_FillValue": np.nan}
39+
)
40+
_check_fill_value(da)
41+
assert "NaN" in caplog.text
42+
43+
44+
# _check_coordinates tests
45+
def test_check_coordinates_valid(caplog):
46+
"""DataArray with valid coordinates attribute logs no warnings"""
47+
ds = xr.Dataset(
48+
{
49+
"temperature": (["x", "y"], np.random.rand(3, 4)),
50+
"lon": (["x", "y"], np.random.rand(3, 4)),
51+
"lat": (["x", "y"], np.random.rand(3, 4)),
52+
}
53+
)
54+
ds["temperature"].attrs["coordinates"] = "lon lat"
55+
_check_coordinates(ds, "temperature")
56+
assert "coordinate" not in caplog.text.lower()
57+
58+
59+
def test_check_coordinates_missing_attribute(caplog):
60+
"""DataArray without coordinates attribute logs warning"""
61+
ds = xr.Dataset({"temperature": (["x", "y"], np.random.rand(3, 4))})
62+
_check_coordinates(ds, "temperature")
63+
assert "coordinates" in caplog.text.lower()
64+
65+
66+
def test_check_coordinates_missing_in_dataset(caplog):
67+
"""Missing coordinate variable logs warning"""
68+
ds = xr.Dataset({"temperature": (["x", "y"], np.random.rand(3, 4))})
69+
ds["temperature"].attrs["coordinates"] = "lon lat"
70+
_check_coordinates(ds, "temperature")
71+
assert "does not exist" in caplog.text
72+
73+
74+
# _check_required_dimensions tests
75+
def test_check_required_dimensions_valid_4d(caplog):
76+
"""4D variable passes check when surface=False"""
77+
da = xr.DataArray(
78+
np.random.rand(2, 3, 4, 5), dims=["time", "z", "x", "y"], name="temperature"
79+
)
80+
_check_required_dimensions(da, surface=False)
81+
assert "dimension" not in caplog.text.lower()
82+
83+
84+
def test_check_required_dimensions_invalid_3d_for_4d(caplog):
85+
"""3D variable fails check when surface=False"""
86+
da = xr.DataArray(np.random.rand(3, 4, 5), dims=["x", "y", "z"], name="temperature")
87+
_check_required_dimensions(da, surface=False)
88+
assert "dimension" in caplog.text.lower()
89+
90+
91+
def test_check_required_dimensions_valid_3d_surface(caplog):
92+
"""3D variable passes check when surface=True"""
93+
da = xr.DataArray(np.random.rand(2, 3, 4), dims=["time", "x", "y"], name="eta")
94+
_check_required_dimensions(da, surface=True)
95+
assert "dimension" not in caplog.text.lower()
96+
97+
98+
def test_check_required_dimensions_invalid_4d_for_surface(caplog):
99+
"""4D variable fails check when surface=True"""
100+
da = xr.DataArray(
101+
np.random.rand(2, 3, 4, 5), dims=["time", "z", "x", "y"], name="eta"
102+
)
103+
_check_required_dimensions(da, surface=True)
104+
assert "dimension" in caplog.text.lower()
105+
106+
107+
# ends_with_3_digits tests
108+
def test_ends_with_3_digits_valid_cases():
109+
"""String ending with 3 digits returns True"""
110+
assert ends_with_3_digits("temp_001") is True
111+
assert ends_with_3_digits("var_999") is True
112+
assert ends_with_3_digits("_000") is True
113+
assert ends_with_3_digits("temp_01") is False
114+
assert ends_with_3_digits("temp_0001") is False
115+
assert ends_with_3_digits("temp_abc") is False
116+
assert ends_with_3_digits("temp") is False
117+
118+
119+
# validate_obc_file tests
120+
121+
122+
def test_validate_obc_file_valid(caplog):
123+
"""Valid OBC file with all required attributes passes"""
124+
ds = xr.Dataset(
125+
{
126+
"temp_segment_001": (["time", "z", "x", "y"], np.random.rand(2, 3, 4, 5)),
127+
"dz_temp_segment_001": (
128+
["time", "z", "x", "y"],
129+
np.random.rand(2, 3, 4, 5),
130+
),
131+
"eta_segment_001": (["time", "x", "y"], np.random.rand(2, 4, 5)),
132+
"lon": (["x", "y"], np.random.rand(4, 5)),
133+
"lat": (["x", "y"], np.random.rand(4, 5)),
134+
}
135+
)
136+
137+
for var in ds.data_vars:
138+
ds[var].attrs["_FillValue"] = -999.0
139+
ds[var].attrs["coordinates"] = "lon lat"
140+
141+
validate_obc_file(ds, ["temp_segment_001"], surface_var="eta_segment_001")
142+
143+
144+
def test_validate_obc_file_issues(caplog):
145+
"""OBC file with missing segment and thickness logs warnings"""
146+
ds = xr.Dataset(
147+
{
148+
"temp_001": (["time", "z", "x", "y"], np.random.rand(2, 3, 4, 5)),
149+
"lon": (["x", "y"], np.random.rand(4, 5)),
150+
"lat": (["x", "y"], np.random.rand(4, 5)),
151+
}
152+
)
153+
ds["temp_001"].attrs["_FillValue"] = -999.0
154+
ds["temp_001"].attrs["coordinates"] = "lon lat"
155+
156+
validate_obc_file(ds, ["temp_001"])
157+
assert "segment" in caplog.text
158+
assert "thickness" in caplog.text or "dz_temp_001" in caplog.text
159+
160+
161+
def test_validate_obc_file_encoding_dict():
162+
"""Encoding dict is applied to variables"""
163+
ds = xr.Dataset(
164+
{
165+
"temp_segment_001": (["time", "z", "x", "y"], np.random.rand(2, 3, 4, 5)),
166+
"dz_temp_segment_001": (
167+
["time", "z", "x", "y"],
168+
np.random.rand(2, 3, 4, 5),
169+
),
170+
"lon": (["x", "y"], np.random.rand(4, 5)),
171+
"lat": (["x", "y"], np.random.rand(4, 5)),
172+
}
173+
)
174+
ds["temp_segment_001"].attrs["_FillValue"] = -999.0
175+
ds["temp_segment_001"].attrs["coordinates"] = "lon lat"
176+
ds["dz_temp_segment_001"].attrs["_FillValue"] = -999.0
177+
178+
encoding_dict = {"temp_segment_001": {"units": "celsius"}}
179+
validate_obc_file(ds, ["temp_segment_001"], encoding_dict=encoding_dict)
180+
181+
assert ds["temp_segment_001"].attrs["units"] == "celsius"

0 commit comments

Comments
 (0)