Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@
HAS_PYRSISTENT = True
except ImportError:
HAS_PYRSISTENT = False
try:
import torch

HAS_TORCH = True
except ImportError:
HAS_TORCH = False

def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
Expand Down Expand Up @@ -163,6 +168,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
except Exception:
pass

if HAS_TORCH and isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
if orig.requires_grad != new.requires_grad:
return False
if orig.device != new.device:
return False
return torch.allclose(orig, new, equal_nan=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.allclose with default params doesn't work as expected for tensors with low magnitude - might want to default to just using rtol?



if HAS_PYRSISTENT and isinstance(
orig,
(
Expand Down
88 changes: 88 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,94 @@ class TestClass(PClass):
assert not comparator(v, x)


def test_torch():
try:
import torch # type: ignore
except ImportError:
pytest.skip()

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 2, 4])
assert comparator(a, b)
assert not comparator(a, c)

d = torch.tensor([[1, 2, 3], [4, 5, 6]])
e = torch.tensor([[1, 2, 3], [4, 5, 6]])
f = torch.tensor([[1, 2, 3], [4, 5, 7]])
assert comparator(d, e)
assert not comparator(d, f)

# Test tensors with different data types
g = torch.tensor([1, 2, 3], dtype=torch.float32)
h = torch.tensor([1, 2, 3], dtype=torch.float32)
i = torch.tensor([1, 2, 3], dtype=torch.int64)
assert comparator(g, h)
assert not comparator(g, i)

# Test 3D tensors
j = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
k = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
l = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
assert comparator(j, k)
assert not comparator(j, l)

# Test tensors with different shapes
m = torch.tensor([1, 2, 3])
n = torch.tensor([[1, 2, 3]])
assert not comparator(m, n)

# Test empty tensors
o = torch.tensor([])
p = torch.tensor([])
q = torch.tensor([1])
assert comparator(o, p)
assert not comparator(o, q)

# Test tensors with NaN values
r = torch.tensor([1.0, float('nan'), 3.0])
s = torch.tensor([1.0, float('nan'), 3.0])
t = torch.tensor([1.0, 2.0, 3.0])
assert comparator(r, s) # NaN == NaN
assert not comparator(r, t)

# Test tensors with infinity values
u = torch.tensor([1.0, float('inf'), 3.0])
v = torch.tensor([1.0, float('inf'), 3.0])
w = torch.tensor([1.0, float('-inf'), 3.0])
assert comparator(u, v)
assert not comparator(u, w)

# Test tensors with different devices (if CUDA is available)
if torch.cuda.is_available():
x = torch.tensor([1, 2, 3]).cuda()
y = torch.tensor([1, 2, 3]).cuda()
z = torch.tensor([1, 2, 3])
assert comparator(x, y)
assert not comparator(x, z)

# Test tensors with requires_grad
aa = torch.tensor([1., 2., 3.], requires_grad=True)
bb = torch.tensor([1., 2., 3.], requires_grad=True)
cc = torch.tensor([1., 2., 3.], requires_grad=False)
assert comparator(aa, bb)
assert not comparator(aa, cc)

# Test complex tensors
dd = torch.tensor([1+2j, 3+4j])
ee = torch.tensor([1+2j, 3+4j])
ff = torch.tensor([1+2j, 3+5j])
assert comparator(dd, ee)
assert not comparator(dd, ff)

# Test boolean tensors
gg = torch.tensor([True, False, True])
hh = torch.tensor([True, False, True])
ii = torch.tensor([True, True, True])
assert comparator(gg, hh)
assert not comparator(gg, ii)


def test_returns():
a = Success(5)
b = Success(5)
Expand Down
Loading