Skip to content

Commit 5ee8ad9

Browse files
Merge pull request #49 from codeflash-ai/pytorch-comparator
Write a comparator for PyTorch
2 parents 0f81492 + 89ca648 commit 5ee8ad9

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@
4343
HAS_PYRSISTENT = True
4444
except ImportError:
4545
HAS_PYRSISTENT = False
46+
try:
47+
import torch
4648

49+
HAS_TORCH = True
50+
except ImportError:
51+
HAS_TORCH = False
4752

4853
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
4954
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
@@ -163,6 +168,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
163168
except Exception:
164169
pass
165170

171+
if HAS_TORCH and isinstance(orig, torch.Tensor):
172+
if orig.dtype != new.dtype:
173+
return False
174+
if orig.shape != new.shape:
175+
return False
176+
if orig.requires_grad != new.requires_grad:
177+
return False
178+
if orig.device != new.device:
179+
return False
180+
return torch.allclose(orig, new, equal_nan=True)
181+
182+
166183
if HAS_PYRSISTENT and isinstance(
167184
orig,
168185
(

tests/test_comparator.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,94 @@ class TestClass(PClass):
532532
assert not comparator(v, x)
533533

534534

535+
def test_torch():
536+
try:
537+
import torch # type: ignore
538+
except ImportError:
539+
pytest.skip()
540+
541+
a = torch.tensor([1, 2, 3])
542+
b = torch.tensor([1, 2, 3])
543+
c = torch.tensor([1, 2, 4])
544+
assert comparator(a, b)
545+
assert not comparator(a, c)
546+
547+
d = torch.tensor([[1, 2, 3], [4, 5, 6]])
548+
e = torch.tensor([[1, 2, 3], [4, 5, 6]])
549+
f = torch.tensor([[1, 2, 3], [4, 5, 7]])
550+
assert comparator(d, e)
551+
assert not comparator(d, f)
552+
553+
# Test tensors with different data types
554+
g = torch.tensor([1, 2, 3], dtype=torch.float32)
555+
h = torch.tensor([1, 2, 3], dtype=torch.float32)
556+
i = torch.tensor([1, 2, 3], dtype=torch.int64)
557+
assert comparator(g, h)
558+
assert not comparator(g, i)
559+
560+
# Test 3D tensors
561+
j = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
562+
k = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
563+
l = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
564+
assert comparator(j, k)
565+
assert not comparator(j, l)
566+
567+
# Test tensors with different shapes
568+
m = torch.tensor([1, 2, 3])
569+
n = torch.tensor([[1, 2, 3]])
570+
assert not comparator(m, n)
571+
572+
# Test empty tensors
573+
o = torch.tensor([])
574+
p = torch.tensor([])
575+
q = torch.tensor([1])
576+
assert comparator(o, p)
577+
assert not comparator(o, q)
578+
579+
# Test tensors with NaN values
580+
r = torch.tensor([1.0, float('nan'), 3.0])
581+
s = torch.tensor([1.0, float('nan'), 3.0])
582+
t = torch.tensor([1.0, 2.0, 3.0])
583+
assert comparator(r, s) # NaN == NaN
584+
assert not comparator(r, t)
585+
586+
# Test tensors with infinity values
587+
u = torch.tensor([1.0, float('inf'), 3.0])
588+
v = torch.tensor([1.0, float('inf'), 3.0])
589+
w = torch.tensor([1.0, float('-inf'), 3.0])
590+
assert comparator(u, v)
591+
assert not comparator(u, w)
592+
593+
# Test tensors with different devices (if CUDA is available)
594+
if torch.cuda.is_available():
595+
x = torch.tensor([1, 2, 3]).cuda()
596+
y = torch.tensor([1, 2, 3]).cuda()
597+
z = torch.tensor([1, 2, 3])
598+
assert comparator(x, y)
599+
assert not comparator(x, z)
600+
601+
# Test tensors with requires_grad
602+
aa = torch.tensor([1., 2., 3.], requires_grad=True)
603+
bb = torch.tensor([1., 2., 3.], requires_grad=True)
604+
cc = torch.tensor([1., 2., 3.], requires_grad=False)
605+
assert comparator(aa, bb)
606+
assert not comparator(aa, cc)
607+
608+
# Test complex tensors
609+
dd = torch.tensor([1+2j, 3+4j])
610+
ee = torch.tensor([1+2j, 3+4j])
611+
ff = torch.tensor([1+2j, 3+5j])
612+
assert comparator(dd, ee)
613+
assert not comparator(dd, ff)
614+
615+
# Test boolean tensors
616+
gg = torch.tensor([True, False, True])
617+
hh = torch.tensor([True, False, True])
618+
ii = torch.tensor([True, True, True])
619+
assert comparator(gg, hh)
620+
assert not comparator(gg, ii)
621+
622+
535623
def test_returns():
536624
a = Success(5)
537625
b = Success(5)

0 commit comments

Comments
 (0)