Skip to content

Commit e1eaf26

Browse files
Handle schema inference in Dataset with empty list col (#319)
* Add test for schema inference with empty list
* Only use `list_val_dtype` if col is a list
1 parent 5acb1b7 commit e1eaf26

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

merlin/core/dispatch.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import enum
1717
import functools
1818
import itertools
19-
from typing import Callable, Union
19+
from typing import Callable, Optional, Union
2020

2121
import dask.dataframe as dd
2222
import numpy as np
@@ -311,7 +311,7 @@ def series_has_nulls(s):
311311
return s.has_nulls
312312

313313

314-
def list_val_dtype(ser: SeriesLike) -> np.dtype:
314+
def list_val_dtype(ser: SeriesLike) -> Optional[np.dtype]:
315315
"""
316316
Return the dtype of the leaves from a list or nested list
317317
@@ -322,16 +322,21 @@ def list_val_dtype(ser: SeriesLike) -> np.dtype:
322322
323323
Returns
324324
-------
325-
np.dtype
326-
The dtype of the innermost elements
325+
Optional[np.dtype]
326+
The dtype of the innermost elements if we find one
327327
"""
328328
if is_list_dtype(ser):
329329
if cudf is not None and isinstance(ser, cudf.Series):
330330
if is_list_dtype(ser):
331331
ser = ser.list.leaves
332332
return ser.dtype
333333
elif isinstance(ser, pd.Series):
334-
return pd.core.dtypes.cast.infer_dtype_from(next(iter(pd.core.common.flatten(ser))))[0]
334+
try:
335+
return pd.core.dtypes.cast.infer_dtype_from(
336+
next(iter(pd.core.common.flatten(ser)))
337+
)[0]
338+
except StopIteration:
339+
return None
335340
if isinstance(ser, np.ndarray):
336341
return ser.dtype
337342
# adds detection when in merlin column

merlin/io/dataset.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,13 +1212,12 @@ def sample_dtypes(self, n=1, annotate_lists=False):
12121212

12131213
if annotate_lists:
12141214
_real_meta = self._real_meta[n]
1215-
annotated = {
1216-
col: {
1217-
"dtype": list_val_dtype(_real_meta[col]) or _real_meta[col].dtype,
1218-
"is_list": is_list_dtype(_real_meta[col]),
1219-
}
1220-
for col in _real_meta.columns
1221-
}
1215+
annotated = {}
1216+
for col in _real_meta.columns:
1217+
is_list = is_list_dtype(_real_meta[col])
1218+
dtype = list_val_dtype(_real_meta[col]) if is_list else _real_meta[col].dtype
1219+
annotated[col] = {"dtype": dtype, "is_list": is_list}
1220+
12221221
return annotated
12231222

12241223
return self._real_meta[n].dtypes

tests/unit/io/test_dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,9 @@ def test_false_with_cudf_and_gpu(self):
4949
def test_false_missing_cudf_or_gpu(self):
5050
with pytest.raises(RuntimeError):
5151
Dataset(make_df({"a": [1, 2, 3]}), cpu=False)
52+
53+
54+
def test_infer_list_dtype_unknown():
55+
df = pd.DataFrame({"col": [[], []]})
56+
dataset = Dataset(df, cpu=True)
57+
assert dataset.schema["col"].dtype.element_type.value == "unknown"

0 commit comments

Comments
 (0)