Skip to content

Commit a93be0c

Browse files
authored
Use enum in ArraySchema for to-be-written Pandas category (#1881)
* Use categories in `ArraySchema` for to-be-written Pandas columns
1 parent 0adadc1 commit a93be0c

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

tiledb/dataframe_.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,15 @@ def _write_array(
693693
row_start_idx=None,
694694
timestamp=None,
695695
):
696+
696697
with tiledb.open(uri, "w", timestamp=timestamp) as A:
698+
for j in range(A.schema.nattr):
699+
attr = A.schema.attr(j)
700+
if attr.enum_label is not None:
701+
enmr = A.enum(attr.enum_label).values()
702+
df[attr.name] = df[attr.name].cat.set_categories(enmr)
703+
write_dict[attr.name] = df[attr.name].cat.codes
704+
697705
if A.schema.sparse:
698706
coords = []
699707
for k in range(A.schema.ndim):

tiledb/tests/test_enumeration.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,40 @@ def test_enum_dtypes(self, dtype, values):
156156
else:
157157
assert enmr.dtype == enmr.values().dtype == dtype
158158
assert_array_equal(enmr.values(), values)
159+
160+
@pytest.mark.skipif(not has_pandas(), reason="pandas not installed")
161+
def test_from_pandas_dtype_mismatch(self):
162+
import pandas as pd
163+
164+
schema = tiledb.ArraySchema(
165+
enums=[
166+
tiledb.Enumeration(name="enum1", values=["a", "b", "c"], ordered=False)
167+
],
168+
domain=tiledb.Domain(
169+
tiledb.Dim(name="dim1", dtype=np.int32, domain=(0, 1))
170+
),
171+
attrs=[tiledb.Attr(name="attr1", dtype=np.int32, enum_label="enum1")],
172+
sparse=True,
173+
)
174+
175+
# Pandas category's categories match the TileDB enumeration's values
176+
df1 = pd.DataFrame(data={"dim1": [0, 1], "attr1": ["b", "c"]})
177+
df1["attr1"] = pd.Categorical(values=df1.attr1, categories=["a", "b", "c"])
178+
179+
array_path = self.path("arr1")
180+
tiledb.Array.create(array_path, schema)
181+
tiledb.from_pandas(array_path, df1, schema=schema, mode="append")
182+
183+
actual_values = tiledb.open(array_path).df[:]["attr1"].values.tolist()
184+
assert actual_values == ["b", "c"]
185+
186+
# Pandas category's categories do not match the TileDB enumeration's values
187+
df2 = pd.DataFrame(data={"dim1": [0, 1], "attr1": ["b", "c"]})
188+
df2["attr1"] = df2["attr1"].astype("category")
189+
190+
array_path = self.path("arr2")
191+
tiledb.Array.create(array_path, schema)
192+
tiledb.from_pandas(array_path, df2, schema=schema, mode="append")
193+
194+
actual_values = tiledb.open(array_path).df[:]["attr1"].values.tolist()
195+
assert actual_values == ["b", "c"]

0 commit comments

Comments
 (0)