Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@
except ImportError:
HAS_JAX = False

try:
import xarray # type: ignore

HAS_XARRAY = True
except ImportError:
HAS_XARRAY = False


def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
Expand Down Expand Up @@ -123,6 +130,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
return bool(jnp.allclose(orig, new, equal_nan=True))

# Handle xarray objects before numpy to avoid boolean context errors
if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
return orig.identical(new)

if HAS_SQLALCHEMY:
try:
insp = sqlalchemy.inspection.inspect(orig)
Expand Down
120 changes: 120 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,126 @@ def test_jax():
assert not comparator(aa, cc)


def test_xarray():
    """Check that comparator() handles xarray DataArray and Dataset objects.

    Each section builds a reference object, an equal copy, and (where
    relevant) a variant that differs in exactly one aspect -- values,
    coordinates, attributes, dimension names, shape, or variable names --
    and asserts that comparator() accepts the copy and rejects the variant.

    The test is skipped when xarray or numpy cannot be imported.
    """
    try:
        import xarray as xr
        import numpy as np
    except ImportError:
        # Give the skip a reason so test reports show why nothing ran.
        pytest.skip("xarray and numpy are required for this test")

    # Basic 1-D DataArray: equal values vs. one differing value.
    arr = xr.DataArray([1, 2, 3], dims=['x'])
    arr_same = xr.DataArray([1, 2, 3], dims=['x'])
    arr_diff = xr.DataArray([1, 2, 4], dims=['x'])
    assert comparator(arr, arr_same)
    assert not comparator(arr, arr_diff)

    # DataArray with coordinates: same data, one coordinate label differs.
    coord_arr = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
    coord_arr_same = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
    coord_arr_diff = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 3]}, dims=['x'])
    assert comparator(coord_arr, coord_arr_same)
    assert not comparator(coord_arr, coord_arr_diff)

    # DataArray with attributes: same data, attribute value differs.
    attr_arr = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
    attr_arr_same = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
    attr_arr_diff = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'feet'})
    assert comparator(attr_arr, attr_arr_same)
    assert not comparator(attr_arr, attr_arr_diff)

    # 2-D DataArray: one element differs.
    grid = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
    grid_same = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
    grid_diff = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=['x', 'y'])
    assert comparator(grid, grid_same)
    assert not comparator(grid, grid_diff)

    # Same values, different dimension name.
    along_x = xr.DataArray([1, 2, 3], dims=['x'])
    along_y = xr.DataArray([1, 2, 3], dims=['y'])
    assert not comparator(along_x, along_y)

    # NaN values: NaNs in matching positions must compare equal.
    nan_arr = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
    nan_arr_same = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
    no_nan_arr = xr.DataArray([1.0, 2.0, 3.0], dims=['x'])
    assert comparator(nan_arr, nan_arr_same)
    assert not comparator(nan_arr, no_nan_arr)

    # Dataset with two variables: one value of 'pressure' differs.
    ds = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]]),
        'pressure': (['x', 'y'], [[5, 6], [7, 8]]),
    })
    ds_same = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]]),
        'pressure': (['x', 'y'], [[5, 6], [7, 8]]),
    })
    ds_diff = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]]),
        'pressure': (['x', 'y'], [[5, 6], [7, 9]]),
    })
    assert comparator(ds, ds_same)
    assert not comparator(ds, ds_diff)

    # Dataset with coordinates: one 'x' coordinate label differs.
    coord_ds = xr.Dataset(
        {'temp': (['x', 'y'], [[1, 2], [3, 4]])},
        coords={'x': [0, 1], 'y': [0, 1]},
    )
    coord_ds_same = xr.Dataset(
        {'temp': (['x', 'y'], [[1, 2], [3, 4]])},
        coords={'x': [0, 1], 'y': [0, 1]},
    )
    coord_ds_diff = xr.Dataset(
        {'temp': (['x', 'y'], [[1, 2], [3, 4]])},
        coords={'x': [0, 2], 'y': [0, 1]},
    )
    assert comparator(coord_ds, coord_ds_same)
    assert not comparator(coord_ds, coord_ds_diff)

    # Dataset with attributes: attribute value differs.
    attr_ds = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
    attr_ds_same = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
    attr_ds_diff = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'model'})
    assert comparator(attr_ds, attr_ds_same)
    assert not comparator(attr_ds, attr_ds_diff)

    # Dataset with a differently named variable holding the same values.
    temp_ds = xr.Dataset({'temp': (['x'], [1, 2, 3])})
    temp_ds_same = xr.Dataset({'temp': (['x'], [1, 2, 3])})
    pressure_ds = xr.Dataset({'pressure': (['x'], [1, 2, 3])})
    assert comparator(temp_ds, temp_ds_same)
    assert not comparator(temp_ds, pressure_ds)

    # Two empty Datasets are equal.
    empty_ds = xr.Dataset()
    other_empty_ds = xr.Dataset()
    assert comparator(empty_ds, other_empty_ds)

    # Same values, different shape (1-D vs. 2-D).
    flat = xr.DataArray([1, 2, 3], dims=['x'])
    nested = xr.DataArray([[1, 2, 3]], dims=['x', 'y'])
    assert not comparator(flat, nested)

    # Different integer dtypes with the same values.
    # This pins the current behaviour: the comparison does not take dtype
    # into account, so int32 and int64 arrays holding equal values are
    # reported as equal.
    int32_arr = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims=['x'])
    int64_arr = xr.DataArray(np.array([1, 2, 3], dtype='int64'), dims=['x'])
    assert comparator(int32_arr, int64_arr)

    # Infinity: the sign of the infinite value must match.
    inf_arr = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
    inf_arr_same = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
    neg_inf_arr = xr.DataArray([1.0, -np.inf, 3.0], dims=['x'])
    assert comparator(inf_arr, inf_arr_same)
    assert not comparator(inf_arr, neg_inf_arr)

    # A DataArray never equals a Dataset, even when the values match.
    data_array = xr.DataArray([1, 2, 3], dims=['x'])
    dataset = xr.Dataset({'data': (['x'], [1, 2, 3])})
    assert not comparator(data_array, dataset)


def test_returns():
a = Success(5)
b = Success(5)
Expand Down
Loading