diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 3881a6888..50a2fe33b 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -61,6 +61,13 @@ except ImportError: HAS_JAX = False +try: + import xarray # type: ignore + + HAS_XARRAY = True +except ImportError: + HAS_XARRAY = False + def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent.""" @@ -123,6 +130,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return bool(jnp.allclose(orig, new, equal_nan=True)) + # Handle xarray objects before numpy to avoid boolean context errors + if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)): + return orig.identical(new) + if HAS_SQLALCHEMY: try: insp = sqlalchemy.inspection.inspect(orig) diff --git a/tests/test_comparator.py b/tests/test_comparator.py index f3f14b86c..06d178f95 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -787,6 +787,126 @@ def test_jax(): assert not comparator(aa, cc) +def test_xarray(): + try: + import xarray as xr + import numpy as np + except ImportError: + pytest.skip() + + # Test basic DataArray + a = xr.DataArray([1, 2, 3], dims=['x']) + b = xr.DataArray([1, 2, 3], dims=['x']) + c = xr.DataArray([1, 2, 4], dims=['x']) + assert comparator(a, b) + assert not comparator(a, c) + + # Test DataArray with coordinates + d = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x']) + e = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x']) + f = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 3]}, dims=['x']) + assert comparator(d, e) + assert not comparator(d, f) + + # Test DataArray with attributes + g = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'}) + h = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'}) + i = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'feet'}) + assert comparator(g, h) + assert not comparator(g, i) + + # Test 2D DataArray + j = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y']) + k = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y']) + l = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=['x', 'y']) + assert comparator(j, k) + assert not comparator(j, l) + + # Test DataArray with different dimensions + m = xr.DataArray([1, 2, 3], dims=['x']) + n = xr.DataArray([1, 2, 3], dims=['y']) + assert not comparator(m, n) + + # Test DataArray with NaN values + o = xr.DataArray([1.0, np.nan, 3.0], dims=['x']) + p = xr.DataArray([1.0, np.nan, 3.0], dims=['x']) + q = xr.DataArray([1.0, 2.0, 3.0], dims=['x']) + assert comparator(o, p) + assert not comparator(o, q) + + # Test Dataset + r = xr.Dataset({ + 'temp': (['x', 'y'], [[1, 2], [3, 4]]), + 'pressure': (['x', 'y'], [[5, 6], [7, 8]]) + }) + s = xr.Dataset({ + 'temp': (['x', 'y'], [[1, 2], [3, 4]]), + 'pressure': (['x', 'y'], [[5, 6], [7, 8]]) + }) + t = xr.Dataset({ + 'temp': (['x', 'y'], [[1, 2], [3, 4]]), + 'pressure': (['x', 'y'], [[5, 6], [7, 9]]) + }) + assert comparator(r, s) + assert not comparator(r, t) + + # Test Dataset with coordinates + u = xr.Dataset({ + 'temp': (['x', 'y'], [[1, 2], [3, 4]]) + }, coords={'x': [0, 1], 'y': [0, 1]}) + v = xr.Dataset({ + 'temp': (['x', 'y'], [[1, 2], [3, 4]]) + }, coords={'x': [0, 1], 'y': [0, 1]}) + w = xr.Dataset({ + 'temp': (['x', 'y'], [[1, 2], [3, 4]]) + }, coords={'x': [0, 2], 'y': [0, 1]}) + assert comparator(u, v) + assert not comparator(u, w) + + # Test Dataset with attributes + x = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'}) + y = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'}) + z = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'model'}) + assert comparator(x, y) + assert not comparator(x, z) + + # Test Dataset with different variables + aa = xr.Dataset({'temp': (['x'], [1, 2, 3])}) + bb = xr.Dataset({'temp': (['x'], [1, 2, 3])}) + cc = xr.Dataset({'pressure': (['x'], [1, 2, 3])}) + assert comparator(aa, bb) + assert not comparator(aa, cc) + + # Test empty Dataset + dd = xr.Dataset() + ee = xr.Dataset() + assert comparator(dd, ee) + + # Test DataArray with different shapes + ff = xr.DataArray([1, 2, 3], dims=['x']) + gg = xr.DataArray([[1, 2, 3]], dims=['x', 'y']) + assert not comparator(ff, gg) + + # Test DataArray with different data types + # Note: xarray.identical() considers int and float arrays with same values as identical + hh = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims=['x']) + ii = xr.DataArray(np.array([1, 2, 3], dtype='int64'), dims=['x']) + # xarray is permissive with dtype comparisons, treats these as identical + assert comparator(hh, ii) + + # Test DataArray with infinity + jj = xr.DataArray([1.0, np.inf, 3.0], dims=['x']) + kk = xr.DataArray([1.0, np.inf, 3.0], dims=['x']) + ll = xr.DataArray([1.0, -np.inf, 3.0], dims=['x']) + assert comparator(jj, kk) + assert not comparator(jj, ll) + + # Test Dataset vs DataArray (different types) + mm = xr.DataArray([1, 2, 3], dims=['x']) + nn = xr.Dataset({'data': (['x'], [1, 2, 3])}) + assert not comparator(mm, nn) + + def test_returns(): a = Success(5) b = Success(5)