Skip to content

Commit 620ba9c

Browse files
committed
add array.array to comparator
1 parent 3dd4459 commit 620ba9c

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import array
12
import ast
23
import datetime
34
import decimal
@@ -170,6 +171,15 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
170171
if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
171172
return True
172173

174+
if isinstance(orig, array.array):
175+
if not isinstance(new, array.array):
176+
return False
177+
if orig.typecode != new.typecode:
178+
return False
179+
if len(orig) != len(new):
180+
return False
181+
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
182+
173183
# This should be at the end of all numpy checking
174184
try:
175185
if HAS_NUMPY and np.isnan(orig):

tests/test_comparator.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
from enum import Enum, Flag, IntFlag, auto
99
from pathlib import Path
10+
import array # Add import for array
1011

1112
import pydantic
1213
import pytest
@@ -203,6 +204,32 @@ class Color4(IntFlag):
203204
assert not comparator(a, c)
204205
assert not comparator(a, d)
205206

207+
arr1 = array.array('i', [1, 2, 3])
208+
arr2 = array.array('i', [1, 2, 3])
209+
arr3 = array.array('i', [4, 5, 6])
210+
arr4 = array.array('f', [1.0, 2.0, 3.0])
211+
212+
assert comparator(arr1, arr2)
213+
assert not comparator(arr1, arr3)
214+
assert not comparator(arr1, arr4)
215+
assert not comparator(arr1, [1, 2, 3])
216+
217+
empty_arr_i1 = array.array('i')
218+
empty_arr_i2 = array.array('i')
219+
empty_arr_f = array.array('f')
220+
assert comparator(empty_arr_i1, empty_arr_i2)
221+
assert not comparator(empty_arr_i1, empty_arr_f)
222+
assert not comparator(empty_arr_i1, arr1)
223+
224+
arr_b = array.array('b', [1, 2, 3])
225+
arr_h = array.array('h', [1, 2, 3])
226+
arr_l = array.array('l', [1, 2, 3])
227+
assert comparator(arr_b, arr_h)
228+
assert comparator(arr_h, arr_l)
229+
assert not comparator(arr_b, arr_h)
230+
assert not comparator(arr_h, arr_l)
231+
232+
206233

207234
def test_numpy():
208235
try:

0 commit comments

Comments
 (0)