Skip to content

Commit 47e8dc5

Browse files
committed
Write a comparator for PyTorch
1 parent a651fbf commit 47e8dc5

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@
4343
HAS_PYRSISTENT = True
4444
except ImportError:
4545
HAS_PYRSISTENT = False
46+
try:
47+
import torch
4648

49+
HAS_TORCH = True
50+
except ImportError:
51+
HAS_TORCH = False
4752

4853
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
4954
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
@@ -163,6 +168,13 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
163168
except Exception:
164169
pass
165170

171+
if HAS_TORCH and isinstance(orig, torch.Tensor):
172+
if orig.dtype != new.dtype:
173+
return False
174+
if orig.shape != new.shape:
175+
return False
176+
return torch.allclose(orig, new, equal_nan=True)
177+
166178
if HAS_PYRSISTENT and isinstance(
167179
orig,
168180
(

tests/test_comparator.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,57 @@ class TestClass(PClass):
532532
assert not comparator(v, x)
533533

534534

535+
def test_torch():
536+
try:
537+
import torch # type: ignore
538+
except ImportError:
539+
pytest.skip()
540+
541+
a = torch.tensor([1, 2, 3])
542+
b = torch.tensor([1, 2, 3])
543+
c = torch.tensor([1, 2, 4])
544+
assert comparator(a, b)
545+
assert not comparator(a, c)
546+
547+
d = torch.tensor([[1, 2, 3], [4, 5, 6]])
548+
e = torch.tensor([[1, 2, 3], [4, 5, 6]])
549+
f = torch.tensor([[1, 2, 3], [4, 5, 7]])
550+
assert comparator(d, e)
551+
assert not comparator(d, f)
552+
553+
# Test tensors with different data types
554+
g = torch.tensor([1, 2, 3], dtype=torch.float32)
555+
h = torch.tensor([1, 2, 3], dtype=torch.float32)
556+
i = torch.tensor([1, 2, 3], dtype=torch.int64)
557+
assert comparator(g, h)
558+
assert not comparator(g, i)
559+
560+
# Test 3D tensors
561+
j = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
562+
k = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
563+
l = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
564+
assert comparator(j, k)
565+
assert not comparator(j, l)
566+
567+
# Test tensors with different shapes
568+
m = torch.tensor([1, 2, 3])
569+
n = torch.tensor([[1, 2, 3]])
570+
assert not comparator(m, n)
571+
572+
# Test empty tensors
573+
o = torch.tensor([])
574+
p = torch.tensor([])
575+
q = torch.tensor([1])
576+
assert comparator(o, p)
577+
assert not comparator(o, q)
578+
579+
# Test tensors with NaN values
580+
r = torch.tensor([1.0, float('nan'), 3.0])
581+
s = torch.tensor([1.0, float('nan'), 3.0])
582+
t = torch.tensor([1.0, 2.0, 3.0])
583+
assert comparator(r, s) # NaN == NaN
584+
assert not comparator(r, t)
585+
535586
def test_returns():
536587
a = Success(5)
537588
b = Success(5)

0 commit comments

Comments
 (0)