Skip to content

Commit b822509

Browse files
jp-darkihnorton
authored andcommitted
Add full check of attribute properties in __eq__ method
1 parent 22ff29a commit b822509

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

tiledb/attribute.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,44 @@ def __init__(
108108
def __eq__(self, other):
109109
if not isinstance(other, Attr):
110110
return False
111-
if self.name != other.name or self.dtype != other.dtype:
111+
if self.isnullable != other.isnullable or self.dtype != other.dtype:
112112
return False
113-
return True
113+
if not self.isnullable:
114+
# Check the fill values are equal.
115+
def equal_or_nan(x, y):
116+
return x == y or (np.isnan(x) and np.isnan(y))
117+
118+
if self.ncells == 1:
119+
if not equal_or_nan(self.fill, other.fill):
120+
return False
121+
elif np.issubdtype(self.dtype, np.bytes_) or np.issubdtype(
122+
self.dtype, np.str_
123+
):
124+
if self.fill != other.fill:
125+
return False
126+
elif self.dtype in {np.dtype("complex64"), np.dtype("complex128")}:
127+
if not (
128+
equal_or_nan(np.real(self.fill), np.real(other.fill))
129+
and equal_or_nan(np.imag(self.fill), np.imag(other.fill))
130+
):
131+
return False
132+
else:
133+
if not all(
134+
equal_or_nan(x, y)
135+
or (
136+
isinstance(x, str)
137+
and x.lower() == "nat"
138+
and isinstance(y, str)
139+
and y.lower() == "nat"
140+
)
141+
for x, y in zip(self.fill[0], other.fill[0])
142+
):
143+
return False
144+
return (
145+
self._internal_name == other._internal_name
146+
and self.isvar == other.isvar
147+
and self.filters == other.filters
148+
)
114149

115150
def dump(self):
116151
"""Dumps a string representation of the Attr object to standard output (stdout)"""

tiledb/tests/test_attribute.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
class AttributeTest(DiskTestCase):
1414
def test_minimal_attribute(self):
1515
attr = tiledb.Attr()
16+
self.assertEqual(attr, attr)
1617
self.assertTrue(attr.isanon)
1718
self.assertEqual(attr.name, "")
1819
self.assertEqual(attr.dtype, np.float_)
@@ -30,6 +31,7 @@ def test_attribute(self, capfd):
3031
attr.dump()
3132
assert_captured(capfd, "Name: foo")
3233

34+
assert attr == attr
3335
assert attr.name == "foo"
3436
assert attr.dtype == np.float64, "default attribute type is float64"
3537

@@ -46,6 +48,7 @@ def test_attribute(self, capfd):
4648
)
4749
def test_attribute_fill(self, dtype, fill):
4850
attr = tiledb.Attr("", dtype=dtype, fill=fill)
51+
assert attr == attr
4952
assert np.array(attr.fill, dtype=dtype) == np.array(fill, dtype=dtype)
5053

5154
path = self.path()
@@ -68,6 +71,7 @@ def test_full_attribute(self, capfd):
6871
attr.dump()
6972
assert_captured(capfd, "Name: foo")
7073

74+
self.assertEqual(attr, attr)
7175
self.assertEqual(attr.name, "foo")
7276
self.assertEqual(attr.dtype, np.int64)
7377
self.assertIsInstance(attr.filters[0], tiledb.ZstdFilter)
@@ -77,6 +81,7 @@ def test_ncell_attribute(self):
7781
dtype = np.dtype([("", np.int32), ("", np.int32), ("", np.int32)])
7882
attr = tiledb.Attr("foo", dtype=dtype)
7983

84+
self.assertEqual(attr, attr)
8085
self.assertEqual(attr.dtype, dtype)
8186
self.assertEqual(attr.ncells, 3)
8287

@@ -125,9 +130,27 @@ def test_two_cell_double_attribute(self, fill):
125130
assert attr.fill == attr.fill
126131
assert attr.ncells == 2
127132

133+
def test_ncell_double_attribute(self):
134+
dtype = np.dtype([("", np.double), ("", np.double), ("", np.double)])
135+
fill = np.array((0, np.nan, np.inf), dtype=dtype)
136+
attr = tiledb.Attr("foo", dtype=dtype, fill=fill)
137+
138+
self.assertEqual(attr, attr)
139+
self.assertEqual(attr.dtype, dtype)
140+
self.assertEqual(attr.ncells, 3)
141+
142+
def test_ncell_not_equal_fill_attribute(self):
143+
dtype = np.dtype([("", np.double), ("", np.double), ("", np.double)])
144+
fill1 = np.array((0, np.nan, np.inf), dtype=dtype)
145+
fill2 = np.array((np.nan, -1, np.inf), dtype=dtype)
146+
attr1 = tiledb.Attr("foo", dtype=dtype, fill=fill1)
147+
attr2 = tiledb.Attr("foo", dtype=dtype, fill=fill2)
148+
assert attr1 != attr2
149+
128150
def test_ncell_bytes_attribute(self):
129151
dtype = np.dtype((np.bytes_, 10))
130152
attr = tiledb.Attr("foo", dtype=dtype)
153+
self.assertEqual(attr, attr)
131154
self.assertEqual(attr.dtype, dtype)
132155
self.assertEqual(attr.ncells, 10)
133156

@@ -143,28 +166,34 @@ def test_bytes_var_attribute(self):
143166
self.assertTrue(attr.isvar)
144167

145168
attr = tiledb.Attr("foo", var=True, dtype="S")
169+
self.assertEqual(attr, attr)
146170
self.assertEqual(attr.dtype, np.dtype("S"))
147171
self.assertTrue(attr.isvar)
148172

149173
attr = tiledb.Attr("foo", var=False, dtype="S1")
174+
self.assertEqual(attr, attr)
150175
self.assertEqual(attr.dtype, np.dtype("S1"))
151176
self.assertFalse(attr.isvar)
152177

153178
attr = tiledb.Attr("foo", dtype="S1")
179+
self.assertEqual(attr, attr)
154180
self.assertEqual(attr.dtype, np.dtype("S1"))
155181
self.assertFalse(attr.isvar)
156182

157183
attr = tiledb.Attr("foo", dtype="S")
184+
self.assertEqual(attr, attr)
158185
self.assertEqual(attr.dtype, np.dtype("S"))
159186
self.assertTrue(attr.isvar)
160187

161188
def test_nullable_attribute(self):
162189
attr = tiledb.Attr("nullable", nullable=True, dtype=np.int32)
190+
self.assertEqual(attr, attr)
163191
self.assertEqual(attr.dtype, np.dtype(np.int32))
164192
self.assertTrue(attr.isnullable)
165193

166194
def test_datetime_attribute(self):
167195
attr = tiledb.Attr("foo", dtype=np.datetime64("", "D"))
196+
self.assertEqual(attr, attr)
168197
assert attr.dtype == np.dtype(np.datetime64("", "D"))
169198
assert attr.dtype != np.dtype(np.datetime64("", "Y"))
170199
assert attr.dtype != np.dtype(np.datetime64)

0 commit comments

Comments
 (0)