Skip to content

Commit b3e683f

Browse files

Authored commit b3e683f (1 parent: 97abf4e)
fix column concatenation to support list with normal column (#1685)

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

nvtabular/ops/categorify.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,8 @@ def _top_level_groupby(df, options: FitOptions):
758758
df_gb = type(df)()
759759
ignore_index = True
760760
df_gb[cat_col_selector_str] = _concat(
761-
[df[col] for col in cat_col_selector.names], ignore_index
761+
[_maybe_flatten_list_column(col, df)[col] for col in cat_col_selector.names],
762+
ignore_index,
762763
)
763764
cat_col_selector = ColumnSelector([cat_col_selector_str])
764765
else:
@@ -795,9 +796,7 @@ def _top_level_groupby(df, options: FitOptions):
795796

796797
# Perform groupby and flatten column index
797798
# (flattening provides better cudf/pd support)
798-
if is_list_col(cat_col_selector, df_gb):
799-
# handle list columns by encoding the list values
800-
df_gb = dispatch.flatten_list_column(df_gb[cat_col_selector.names[0]])
799+
df_gb = _maybe_flatten_list_column(cat_col_selector.names[0], df_gb)
801800
# NOTE: groupby(..., dropna=False) requires pandas>=1.1.0
802801
gb = df_gb.groupby(cat_col_selector.names, dropna=False).agg(agg_dict)
803802
gb.columns = [
@@ -1414,6 +1413,15 @@ def is_list_col(col_selector, df):
14141413
return has_lists
14151414

14161415

def _maybe_flatten_list_column(col: str, df):
    """Return a flattened frame for ``col`` when it is a list column.

    If ``col`` holds list values, encode them by flattening
    ``df[col]`` into a single element-wise column; otherwise the
    original ``df`` is passed back untouched.
    """
    selector = ColumnSelector([col])
    if not is_list_col(selector, df):
        # Not a list dtype — nothing to flatten.
        return df
    return dispatch.flatten_list_column(df[selector.names[0]])
1424+
14171425
def _hash_bucket(df, num_buckets, col, encode_type="joint"):
14181426
if encode_type == "joint":
14191427
nb = num_buckets[col[0]]

tests/unit/ops/test_categorify.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,35 @@ def test_categorify_max_size_null_iloc_check():
654654
unique_C2 = pd.read_parquet("./categories/unique.C2.parquet")
655655
assert str(unique_C2["C2"].iloc[0]) in ["<NA>", "nan"]
656656
assert unique_C2["C2_size"].iloc[0] == 0
657+
658+
@pytest.mark.parametrize("cpu", _CPU)
def test_categorify_joint_list(cpu):
    # Joint encoding must support mixing a list column ("Engaging User")
    # with plain scalar columns ("Author", "Post") in one Categorify op.
    source = pd.DataFrame(
        {
            "Author": ["User_A", "User_E", "User_B", "User_C"],
            "Engaging User": [
                ["User_B", "User_C"],
                [],
                ["User_A", "User_D"],
                ["User_A"],
            ],
            "Post": [1, 2, 3, 4],
        }
    )
    cat_names = ["Post", ["Author", "Engaging User"]]
    graph = cat_names >> nvt.ops.Categorify(encode_type="joint")
    workflow = nvt.Workflow(graph)
    transformed = (
        workflow.fit_transform(nvt.Dataset(source, cpu=cpu))
        .to_ddf()
        .compute(scheduler="synchronous")
    )

    if cpu:
        compare_a = transformed["Author"].to_list()
        compare_e = transformed["Engaging User"].explode().dropna().to_list()
    else:
        compare_a = transformed["Author"].to_arrow().to_pylist()
        compare_e = (
            transformed["Engaging User"].explode().dropna().to_arrow().to_pylist()
        )

    assert compare_a == [1, 5, 2, 3]
    assert compare_e == [2, 3, 1, 4, 1]

0 commit comments

Comments (0)