diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 7737900df..704d19b3c 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -223,6 +223,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + # Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.) + if isinstance(orig, np.dtype): + return orig == new + if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: return False diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06cc1180c..aa556db32 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -2422,4 +2422,64 @@ def test_numpy_0d_array() -> None: x = np.array(5) y = np.array([5]) # Different shapes - assert not comparator(x, y) \ No newline at end of file + assert not comparator(x, y) + +def test_numpy_dtypes() -> None: + """Test comparator for numpy.dtypes types like Float64DType, Int64DType, etc.""" + try: + import numpy as np + import numpy.dtypes as dtypes + except ImportError: + pytest.skip("numpy not available") + + # Test Float64DType + a = dtypes.Float64DType() + b = dtypes.Float64DType() + assert comparator(a, b) + + # Test Int64DType + c = dtypes.Int64DType() + d = dtypes.Int64DType() + assert comparator(c, d) + + # Test different DType classes should not be equal + assert not comparator(a, c) # Float64DType vs Int64DType + + # Test various numeric DType classes + assert comparator(dtypes.Int8DType(), dtypes.Int8DType()) + assert comparator(dtypes.Int16DType(), dtypes.Int16DType()) + assert comparator(dtypes.Int32DType(), dtypes.Int32DType()) + assert comparator(dtypes.UInt8DType(), dtypes.UInt8DType()) + assert comparator(dtypes.UInt16DType(), dtypes.UInt16DType()) + assert comparator(dtypes.UInt32DType(), dtypes.UInt32DType()) + assert comparator(dtypes.UInt64DType(), dtypes.UInt64DType()) + assert comparator(dtypes.Float32DType(), dtypes.Float32DType()) + assert comparator(dtypes.Complex64DType(), dtypes.Complex64DType()) + assert comparator(dtypes.Complex128DType(), dtypes.Complex128DType()) + assert comparator(dtypes.BoolDType(), dtypes.BoolDType()) + + # Test cross-type comparisons should be False + assert not comparator(dtypes.Int32DType(), dtypes.Int64DType()) + assert not comparator(dtypes.Float32DType(), dtypes.Float64DType()) + assert not comparator(dtypes.UInt32DType(), dtypes.Int32DType()) + + # Test regular np.dtype instances + e = np.dtype('float64') + f = np.dtype('float64') + assert comparator(e, f) + + g = np.dtype('int64') + h = np.dtype('int64') + assert comparator(g, h) + + assert not comparator(e, g) # float64 vs int64 + + # Test DType class instances vs regular np.dtype (they should be equal if same underlying type) + assert comparator(dtypes.Float64DType(), np.dtype('float64')) + assert comparator(dtypes.Int64DType(), np.dtype('int64')) + assert comparator(dtypes.Int32DType(), np.dtype('int32')) + assert comparator(dtypes.BoolDType(), np.dtype('bool')) + + # Test that DType and np.dtype of different types are not equal + assert not comparator(dtypes.Float64DType(), np.dtype('int64')) + assert not comparator(dtypes.Int32DType(), np.dtype('float32'))