Skip to content

Commit b1c43da

Browse files
authored
Merge pull request #1007 from codeflash-ai/comparator-numpy
Fix: add support for numpy.dtypes for the comparator
2 parents 50a9d07 + 6dc53c1 commit b1c43da

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

codeflash/verification/comparator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
223223
return False
224224
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
225225

226+
# Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.)
227+
if isinstance(orig, np.dtype):
228+
return orig == new
229+
226230
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
227231
if orig.dtype != new.dtype:
228232
return False

tests/test_comparator.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2422,4 +2422,64 @@ def test_numpy_0d_array() -> None:
24222422
x = np.array(5)
24232423
y = np.array([5])
24242424
# Different shapes
2425-
assert not comparator(x, y)
2425+
assert not comparator(x, y)
2426+
2427+
def test_numpy_dtypes() -> None:
2428+
"""Test comparator for numpy.dtypes types like Float64DType, Int64DType, etc."""
2429+
try:
2430+
import numpy as np
2431+
import numpy.dtypes as dtypes
2432+
except ImportError:
2433+
pytest.skip("numpy not available")
2434+
2435+
# Test Float64DType
2436+
a = dtypes.Float64DType()
2437+
b = dtypes.Float64DType()
2438+
assert comparator(a, b)
2439+
2440+
# Test Int64DType
2441+
c = dtypes.Int64DType()
2442+
d = dtypes.Int64DType()
2443+
assert comparator(c, d)
2444+
2445+
# Test different DType classes should not be equal
2446+
assert not comparator(a, c) # Float64DType vs Int64DType
2447+
2448+
# Test various numeric DType classes
2449+
assert comparator(dtypes.Int8DType(), dtypes.Int8DType())
2450+
assert comparator(dtypes.Int16DType(), dtypes.Int16DType())
2451+
assert comparator(dtypes.Int32DType(), dtypes.Int32DType())
2452+
assert comparator(dtypes.UInt8DType(), dtypes.UInt8DType())
2453+
assert comparator(dtypes.UInt16DType(), dtypes.UInt16DType())
2454+
assert comparator(dtypes.UInt32DType(), dtypes.UInt32DType())
2455+
assert comparator(dtypes.UInt64DType(), dtypes.UInt64DType())
2456+
assert comparator(dtypes.Float32DType(), dtypes.Float32DType())
2457+
assert comparator(dtypes.Complex64DType(), dtypes.Complex64DType())
2458+
assert comparator(dtypes.Complex128DType(), dtypes.Complex128DType())
2459+
assert comparator(dtypes.BoolDType(), dtypes.BoolDType())
2460+
2461+
# Test cross-type comparisons should be False
2462+
assert not comparator(dtypes.Int32DType(), dtypes.Int64DType())
2463+
assert not comparator(dtypes.Float32DType(), dtypes.Float64DType())
2464+
assert not comparator(dtypes.UInt32DType(), dtypes.Int32DType())
2465+
2466+
# Test regular np.dtype instances
2467+
e = np.dtype('float64')
2468+
f = np.dtype('float64')
2469+
assert comparator(e, f)
2470+
2471+
g = np.dtype('int64')
2472+
h = np.dtype('int64')
2473+
assert comparator(g, h)
2474+
2475+
assert not comparator(e, g) # float64 vs int64
2476+
2477+
# Test DType class instances vs regular np.dtype (they should be equal if same underlying type)
2478+
assert comparator(dtypes.Float64DType(), np.dtype('float64'))
2479+
assert comparator(dtypes.Int64DType(), np.dtype('int64'))
2480+
assert comparator(dtypes.Int32DType(), np.dtype('int32'))
2481+
assert comparator(dtypes.BoolDType(), np.dtype('bool'))
2482+
2483+
# Test that DType and np.dtype of different types are not equal
2484+
assert not comparator(dtypes.Float64DType(), np.dtype('int64'))
2485+
assert not comparator(dtypes.Int32DType(), np.dtype('float32'))

0 commit comments

Comments
 (0)