Skip to content

Commit c2da811

Browse files
committed
more rigorous
1 parent 47e8dc5 commit c2da811

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,13 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
173173
return False
174174
if orig.shape != new.shape:
175175
return False
176+
if orig.requires_grad != new.requires_grad:
177+
return False
178+
if orig.device != new.device:
179+
return False
176180
return torch.allclose(orig, new, equal_nan=True)
177181

182+
178183
if HAS_PYRSISTENT and isinstance(
179184
orig,
180185
(

tests/test_comparator.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,43 @@ def test_torch():
583583
assert comparator(r, s) # NaN == NaN
584584
assert not comparator(r, t)
585585

586+
# Test tensors with infinity values
587+
u = torch.tensor([1.0, float('inf'), 3.0])
588+
v = torch.tensor([1.0, float('inf'), 3.0])
589+
w = torch.tensor([1.0, float('-inf'), 3.0])
590+
assert comparator(u, v)
591+
assert not comparator(u, w)
592+
593+
# Test tensors with different devices (if CUDA is available)
594+
if torch.cuda.is_available():
595+
x = torch.tensor([1, 2, 3]).cuda()
596+
y = torch.tensor([1, 2, 3]).cuda()
597+
z = torch.tensor([1, 2, 3])
598+
assert comparator(x, y)
599+
assert not comparator(x, z)
600+
601+
# Test tensors with requires_grad
602+
aa = torch.tensor([1., 2., 3.], requires_grad=True)
603+
bb = torch.tensor([1., 2., 3.], requires_grad=True)
604+
cc = torch.tensor([1., 2., 3.], requires_grad=False)
605+
assert comparator(aa, bb)
606+
assert not comparator(aa, cc)
607+
608+
# Test complex tensors
609+
dd = torch.tensor([1+2j, 3+4j])
610+
ee = torch.tensor([1+2j, 3+4j])
611+
ff = torch.tensor([1+2j, 3+5j])
612+
assert comparator(dd, ee)
613+
assert not comparator(dd, ff)
614+
615+
# Test boolean tensors
616+
gg = torch.tensor([True, False, True])
617+
hh = torch.tensor([True, False, True])
618+
ii = torch.tensor([True, True, True])
619+
assert comparator(gg, hh)
620+
assert not comparator(gg, ii)
621+
622+
586623
def test_returns():
587624
a = Success(5)
588625
b = Success(5)

0 commit comments

Comments
 (0)