Skip to content

Commit 969ae40

Browse files
authored
Cast enum values to enum dtype (#1854)
* For empty enumerations, if we do not explicitly cast, the values are automatically cast to a dtype of float64
1 parent 23818a7 commit 969ae40

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

tiledb/enumeration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,11 @@ def values(self) -> NDArray:
9090
:rtype: NDArray
9191
"""
9292
if self.dtype.kind == "U":
93-
return np.array(super().str_values())
93+
return np.array(super().str_values(), dtype=np.str_)
9494
elif self.dtype.kind == "S":
9595
return np.array(super().str_values(), dtype=np.bytes_)
9696
else:
97-
return super().values()
97+
return np.array(super().values(), dtype=self.dtype)
9898

9999
def extend(self, values: Sequence[Any]) -> Enumeration:
100100
"""Add additional values to the enumeration.

tiledb/tests/test_enumeration.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,43 @@ def test_array_schema_enumeration_nullable(self, sparse, pass_df):
116116
expected_validity = [False, False, True, False, False]
117117
assert_array_equal(A[:]["a"].mask, expected_validity)
118118
assert_array_equal(A.df[:]["a"].isna(), expected_validity)
119+
120+
@pytest.mark.parametrize(
121+
"dtype, values",
122+
[
123+
(np.int8, np.array([1, 2, 3], np.int8)),
124+
(np.uint8, np.array([1, 2, 3], np.uint8)),
125+
(np.int16, np.array([1, 2, 3], np.int16)),
126+
(np.uint16, np.array([1, 2, 3], np.uint16)),
127+
(np.int32, np.array([1, 2, 3], np.int32)),
128+
(np.uint32, np.array([1, 2, 3], np.uint32)),
129+
(np.int64, np.array([1, 2, 3], np.int64)),
130+
(np.uint64, np.array([1, 2, 3], np.uint64)),
131+
(np.dtype("S"), np.array(["a", "b", "c"], np.dtype("S"))),
132+
(np.dtype("U"), np.array(["a", "b", "c"], np.dtype("U"))),
133+
],
134+
)
135+
def test_enum_dtypes(self, dtype, values):
136+
# create empty
137+
enmr = tiledb.Enumeration("e", False, dtype=dtype)
138+
if dtype in (np.dtype("S"), np.dtype("U")):
139+
assert enmr.dtype.kind == enmr.values().dtype.kind == dtype.kind
140+
else:
141+
assert enmr.dtype == enmr.values().dtype == dtype
142+
assert_array_equal(enmr.values(), [])
143+
144+
# then extend with values
145+
enmr = enmr.extend(values)
146+
if dtype in (np.dtype("S"), np.dtype("U")):
147+
assert enmr.dtype.kind == enmr.values().dtype.kind == dtype.kind
148+
else:
149+
assert enmr.dtype == enmr.values().dtype == dtype
150+
assert_array_equal(enmr.values(), values)
151+
152+
# create with values
153+
enmr = tiledb.Enumeration("e", False, values=values)
154+
if dtype in (np.dtype("S"), np.dtype("U")):
155+
assert enmr.dtype.kind == enmr.values().dtype.kind == dtype.kind
156+
else:
157+
assert enmr.dtype == enmr.values().dtype == dtype
158+
assert_array_equal(enmr.values(), values)

tiledb/tests/test_schema_evolution.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ def test_schema_evolution_with_enmr(tmp_path):
174174

175175
@pytest.mark.parametrize(
176176
"type,data",
177-
(("int", [0]), ("bool", [True, False]), ("str", ["abc", "defghi", "jk"])),
177+
(
178+
("int", [0]),
179+
("bool", [True, False]),
180+
("str", ["abc", "defghi", "jk"]),
181+
("bytes", [b"abc", b"defghi", b"jk"]),
182+
),
178183
)
179184
def test_schema_evolution_extend_enmr(tmp_path, type, data):
180185
uri = str(tmp_path)

0 commit comments

Comments
 (0)