diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index f522815a5..46b9669f9 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -43,7 +43,12 @@ HAS_PYRSISTENT = True except ImportError: HAS_PYRSISTENT = False +try: + import torch + HAS_TORCH = True +except ImportError: + HAS_TORCH = False def comparator(orig: Any, new: Any, superset_obj=False) -> bool: """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent.""" @@ -163,6 +168,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: except Exception: pass + if HAS_TORCH and isinstance(orig, torch.Tensor): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + if orig.requires_grad != new.requires_grad: + return False + if orig.device != new.device: + return False + return torch.allclose(orig, new, equal_nan=True) + + if HAS_PYRSISTENT and isinstance( orig, ( diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 0fc292f09..d10a48d58 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -532,6 +532,94 @@ class TestClass(PClass): assert not comparator(v, x) +def test_torch(): + try: + import torch # type: ignore + except ImportError: + pytest.skip() + + a = torch.tensor([1, 2, 3]) + b = torch.tensor([1, 2, 3]) + c = torch.tensor([1, 2, 4]) + assert comparator(a, b) + assert not comparator(a, c) + + d = torch.tensor([[1, 2, 3], [4, 5, 6]]) + e = torch.tensor([[1, 2, 3], [4, 5, 6]]) + f = torch.tensor([[1, 2, 3], [4, 5, 7]]) + assert comparator(d, e) + assert not comparator(d, f) + + # Test tensors with different data types + g = torch.tensor([1, 2, 3], dtype=torch.float32) + h = torch.tensor([1, 2, 3], dtype=torch.float32) + i = torch.tensor([1, 2, 3], dtype=torch.int64) + assert comparator(g, h) + assert not comparator(g, i) + + # Test 3D tensors + j = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + k = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + l = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]) + assert comparator(j, k) + assert not comparator(j, l) + + # Test tensors with different shapes + m = torch.tensor([1, 2, 3]) + n = torch.tensor([[1, 2, 3]]) + assert not comparator(m, n) + + # Test empty tensors + o = torch.tensor([]) + p = torch.tensor([]) + q = torch.tensor([1]) + assert comparator(o, p) + assert not comparator(o, q) + + # Test tensors with NaN values + r = torch.tensor([1.0, float('nan'), 3.0]) + s = torch.tensor([1.0, float('nan'), 3.0]) + t = torch.tensor([1.0, 2.0, 3.0]) + assert comparator(r, s) # NaN == NaN + assert not comparator(r, t) + + # Test tensors with infinity values + u = torch.tensor([1.0, float('inf'), 3.0]) + v = torch.tensor([1.0, float('inf'), 3.0]) + w = torch.tensor([1.0, float('-inf'), 3.0]) + assert comparator(u, v) + assert not comparator(u, w) + + # Test tensors with different devices (if CUDA is available) + if torch.cuda.is_available(): + x = torch.tensor([1, 2, 3]).cuda() + y = torch.tensor([1, 2, 3]).cuda() + z = torch.tensor([1, 2, 3]) + assert comparator(x, y) + assert not comparator(x, z) + + # Test tensors with requires_grad + aa = torch.tensor([1., 2., 3.], requires_grad=True) + bb = torch.tensor([1., 2., 3.], requires_grad=True) + cc = torch.tensor([1., 2., 3.], requires_grad=False) + assert comparator(aa, bb) + assert not comparator(aa, cc) + + # Test complex tensors + dd = torch.tensor([1+2j, 3+4j]) + ee = torch.tensor([1+2j, 3+4j]) + ff = torch.tensor([1+2j, 3+5j]) + assert comparator(dd, ee) + assert not comparator(dd, ff) + + # Test boolean tensors + gg = torch.tensor([True, False, True]) + hh = torch.tensor([True, False, True]) + ii = torch.tensor([True, True, True]) + assert comparator(gg, hh) + assert not comparator(gg, ii) + + def test_returns(): a = Success(5) b = Success(5)