Skip to content

Commit 8c77424

Browse files
Merge pull request #788 from codeflash-ai/xarray-comparator
xarray comparator
2 parents f26df43 + bfa6eac commit 8c77424

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@
6161
except ImportError:
6262
HAS_JAX = False
6363

64+
try:
65+
import xarray # type: ignore
66+
67+
HAS_XARRAY = True
68+
except ImportError:
69+
HAS_XARRAY = False
70+
6471

6572
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
6673
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
@@ -123,6 +130,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
123130
return False
124131
return bool(jnp.allclose(orig, new, equal_nan=True))
125132

133+
# Handle xarray objects before numpy to avoid boolean context errors
134+
if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
135+
return orig.identical(new)
136+
126137
if HAS_SQLALCHEMY:
127138
try:
128139
insp = sqlalchemy.inspection.inspect(orig)

tests/test_comparator.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,126 @@ def test_jax():
787787
assert not comparator(aa, cc)
788788

789789

790+
def test_xarray():
791+
try:
792+
import xarray as xr
793+
import numpy as np
794+
except ImportError:
795+
pytest.skip()
796+
797+
# Test basic DataArray
798+
a = xr.DataArray([1, 2, 3], dims=['x'])
799+
b = xr.DataArray([1, 2, 3], dims=['x'])
800+
c = xr.DataArray([1, 2, 4], dims=['x'])
801+
assert comparator(a, b)
802+
assert not comparator(a, c)
803+
804+
# Test DataArray with coordinates
805+
d = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
806+
e = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
807+
f = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 3]}, dims=['x'])
808+
assert comparator(d, e)
809+
assert not comparator(d, f)
810+
811+
# Test DataArray with attributes
812+
g = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
813+
h = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
814+
i = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'feet'})
815+
assert comparator(g, h)
816+
assert not comparator(g, i)
817+
818+
# Test 2D DataArray
819+
j = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
820+
k = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
821+
l = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=['x', 'y'])
822+
assert comparator(j, k)
823+
assert not comparator(j, l)
824+
825+
# Test DataArray with different dimensions
826+
m = xr.DataArray([1, 2, 3], dims=['x'])
827+
n = xr.DataArray([1, 2, 3], dims=['y'])
828+
assert not comparator(m, n)
829+
830+
# Test DataArray with NaN values
831+
o = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
832+
p = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
833+
q = xr.DataArray([1.0, 2.0, 3.0], dims=['x'])
834+
assert comparator(o, p)
835+
assert not comparator(o, q)
836+
837+
# Test Dataset
838+
r = xr.Dataset({
839+
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
840+
'pressure': (['x', 'y'], [[5, 6], [7, 8]])
841+
})
842+
s = xr.Dataset({
843+
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
844+
'pressure': (['x', 'y'], [[5, 6], [7, 8]])
845+
})
846+
t = xr.Dataset({
847+
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
848+
'pressure': (['x', 'y'], [[5, 6], [7, 9]])
849+
})
850+
assert comparator(r, s)
851+
assert not comparator(r, t)
852+
853+
# Test Dataset with coordinates
854+
u = xr.Dataset({
855+
'temp': (['x', 'y'], [[1, 2], [3, 4]])
856+
}, coords={'x': [0, 1], 'y': [0, 1]})
857+
v = xr.Dataset({
858+
'temp': (['x', 'y'], [[1, 2], [3, 4]])
859+
}, coords={'x': [0, 1], 'y': [0, 1]})
860+
w = xr.Dataset({
861+
'temp': (['x', 'y'], [[1, 2], [3, 4]])
862+
}, coords={'x': [0, 2], 'y': [0, 1]})
863+
assert comparator(u, v)
864+
assert not comparator(u, w)
865+
866+
# Test Dataset with attributes
867+
x = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
868+
y = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
869+
z = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'model'})
870+
assert comparator(x, y)
871+
assert not comparator(x, z)
872+
873+
# Test Dataset with different variables
874+
aa = xr.Dataset({'temp': (['x'], [1, 2, 3])})
875+
bb = xr.Dataset({'temp': (['x'], [1, 2, 3])})
876+
cc = xr.Dataset({'pressure': (['x'], [1, 2, 3])})
877+
assert comparator(aa, bb)
878+
assert not comparator(aa, cc)
879+
880+
# Test empty Dataset
881+
dd = xr.Dataset()
882+
ee = xr.Dataset()
883+
assert comparator(dd, ee)
884+
885+
# Test DataArray with different shapes
886+
ff = xr.DataArray([1, 2, 3], dims=['x'])
887+
gg = xr.DataArray([[1, 2, 3]], dims=['x', 'y'])
888+
assert not comparator(ff, gg)
889+
890+
# Test DataArray with different data types
891+
# Note: xarray.identical() considers int and float arrays with same values as identical
892+
hh = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims=['x'])
893+
ii = xr.DataArray(np.array([1, 2, 3], dtype='int64'), dims=['x'])
894+
# xarray is permissive with dtype comparisons, treats these as identical
895+
assert comparator(hh, ii)
896+
897+
# Test DataArray with infinity
898+
jj = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
899+
kk = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
900+
ll = xr.DataArray([1.0, -np.inf, 3.0], dims=['x'])
901+
assert comparator(jj, kk)
902+
assert not comparator(jj, ll)
903+
904+
# Test Dataset vs DataArray (different types)
905+
mm = xr.DataArray([1, 2, 3], dims=['x'])
906+
nn = xr.Dataset({'data': (['x'], [1, 2, 3])})
907+
assert not comparator(mm, nn)
908+
909+
790910
def test_returns():
791911
a = Success(5)
792912
b = Success(5)

0 commit comments

Comments
 (0)