8 changes: 7 additions & 1 deletion nvtabular/ops/categorify.py

@@ -39,6 +39,7 @@
 from dask.highlevelgraph import HighLevelGraph
 from dask.utils import parse_bytes
 from fsspec.core import get_fs_token_paths
+from packaging.version import Version

 from merlin.core import dispatch
 from merlin.core.dispatch import DataFrameType, annotate, is_cpu_object, nullable_series

@@ -53,6 +54,7 @@
 PAD_OFFSET = 0
 NULL_OFFSET = 1
 OOV_OFFSET = 2
+PA_GE_14 = Version(pa.__version__) >= Version("14.0")


 class Categorify(StatOperator):

@@ -907,7 +909,11 @@ def _general_concat(
 ):
     # Concatenate DataFrame or pa.Table objects
     if isinstance(frames[0], pa.Table):
-        df = pa.concat_tables(frames, promote=True)
+        if PA_GE_14:
+            df = pa.concat_tables(frames, promote_options="default")
+        else:
+            df = pa.concat_tables(frames, promote=True)
+
         if (
             cardinality_memory_limit
             and col_selector is not None
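
Note on the categorify.py change: pyarrow 14.0 deprecated the boolean promote argument of pa.concat_tables in favor of promote_options, where promote_options="default" reproduces the old promote=True behavior of unifying schemas and null-filling columns that are missing from some tables. A minimal, self-contained sketch of the same version gate (the toy tables here are illustrative, not from this PR):

    import pyarrow as pa
    from packaging.version import Version

    PA_GE_14 = Version(pa.__version__) >= Version("14.0")

    t1 = pa.table({"a": [1, 2]})
    t2 = pa.table({"a": [3], "b": ["x"]})  # extra column, so schemas must be promoted

    if PA_GE_14:
        # pyarrow >= 14: promote_options="default" unifies schemas, like the old promote=True
        combined = pa.concat_tables([t1, t2], promote_options="default")
    else:
        # pyarrow < 14: the legacy boolean flag
        combined = pa.concat_tables([t1, t2], promote=True)

    print(combined.schema)  # a: int64, b: string ("b" is null for rows from t1)
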
2 changes: 1 addition & 1 deletion tests/unit/ops/test_lambda.py

@@ -227,7 +227,7 @@ def test_lambdaop_dtype_multi_op_propagation(cpu):
         {
             "a": np.arange(size),
             "b": np.random.choice(["apple", "banana", "orange"], size),
-            "c": np.random.choice([0, 1], size).astype(np.float16),
+            "c": np.random.choice([0, 1], size),
         }
     )
     ddf0 = dd.from_pandas(df0, npartitions=4)
10 changes: 8 additions & 2 deletions tests/unit/test_dask_nvt.py

@@ -142,6 +142,8 @@ def test_dask_workflow_api_dlrm(

 @pytest.mark.parametrize("part_mem_fraction", [0.01])
 def test_dask_groupby_stats(client, tmpdir, datasets, part_mem_fraction):
+    from nvtabular.ops.join_groupby import AGG_DTYPES
+
     set_dask_client(client=client)

     engine = "parquet"

@@ -175,10 +177,14 @@ def test_dask_groupby_stats(client, tmpdir, datasets, part_mem_fraction):
     gb_e = expect.groupby("name-cat").aggregate({"name-cat": "count", "x": ["sum", "min", "std"]})
     gb_e.columns = ["count", "sum", "min", "std"]
     df_check = got.merge(gb_e, left_on="name-cat", right_index=True, how="left")
-    assert_eq(df_check["name-cat_count"], df_check["count"], check_names=False)
+    assert_eq(
+        df_check["name-cat_count"], df_check["count"].astype(AGG_DTYPES["count"]), check_names=False
+    )
     assert_eq(df_check["name-cat_x_sum"], df_check["sum"], check_names=False)
     assert_eq(df_check["name-cat_x_min"], df_check["min"], check_names=False)
-    assert_eq(df_check["name-cat_x_std"], df_check["std"].astype("float32"), check_names=False)
+    assert_eq(
+        df_check["name-cat_x_std"], df_check["std"].astype(AGG_DTYPES["std"]), check_names=False
+    )


 @pytest.mark.parametrize("part_mem_fraction", [0.01])
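
Note on the test_dask_nvt.py change: the rewritten assertions exist because the groupby op in nvtabular.ops.join_groupby coerces each aggregation's output to a fixed dtype, so the pandas-computed reference column must be cast through the same AGG_DTYPES mapping before comparison. A sketch of that pattern with assumed mapping values (the real ones live in nvtabular.ops.join_groupby.AGG_DTYPES and may differ):

    import pandas as pd

    # Stand-in for nvtabular.ops.join_groupby.AGG_DTYPES -- the values below are
    # assumptions for illustration; the point is each aggregation has a fixed dtype.
    AGG_DTYPES = {"count": "int32", "sum": "float32", "min": "float32", "std": "float32"}

    reference = pd.Series([2.0, 1.0], name="count")             # pandas default float64
    op_output = pd.Series([2, 1], dtype="int32", name="count")  # dtype-coerced op result

    # Comparing float64 against int32 directly can trip strict dtype checks;
    # casting the reference through the mapping makes the comparison exact.
    pd.testing.assert_series_equal(
        reference.astype(AGG_DTYPES["count"]), op_output, check_names=False
    )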