Skip to content

Commit 5601766

Browse files
authored
Convert data formats before executing each op in LocalExecutor (#280)
* Refactor `LocalExecutor` into more discrete steps that can be overridden * Adjust how empty transformables are detected * Fix bug in determination of additional root columns to merge * Make it possible to pass a merge function in `LocalExecutor` * Avoid creating nodes when adding empty lists of column names * Minor cleanup of how additional root columns are merged * Move data format conversion between DAG nodes into `LocalExecutor` * Migrate executor test with `DataFrameLike` to `TensorTable` * Remove `DictArray` * Add `column_type` method to `TensorTable` * Adjust CPU-only masking for new enum values
1 parent 8d06650 commit 5601766

File tree

6 files changed

+186
-172
lines changed

6 files changed

+186
-172
lines changed

merlin/dag/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
#
1616

1717
# flake8: noqa
18-
from merlin.dag.base_operator import BaseOperator, Supports
19-
from merlin.dag.dictarray import DictArray
18+
from merlin.dag.base_operator import BaseOperator, DataFormats, Supports
2019
from merlin.dag.graph import Graph
2120
from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
2221
from merlin.dag.selector import ColumnSelector

merlin/dag/base_operator.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525

2626

2727
class Supports(Flag):
28-
"""Indicates what type of data representation this operator supports for transformations"""
28+
"""
29+
Indicates what type of data representation this operator supports for transformations
30+
31+
(Deprecated)
32+
"""
2933

3034
# cudf dataframe
3135
CPU_DATAFRAME = auto()
@@ -37,6 +41,19 @@ class Supports(Flag):
3741
GPU_DICT_ARRAY = auto()
3842

3943

44+
class DataFormats(Flag):
45+
CUDF_DATAFRAME = auto()
46+
PANDAS_DATAFRAME = auto()
47+
48+
NUMPY_TENSOR_TABLE = auto()
49+
CUPY_TENSOR_TABLE = auto()
50+
TF_TENSOR_TABLE = auto()
51+
TORCH_TENSOR_TABLE = auto()
52+
53+
NUMPY_DICT_ARRAY = auto()
54+
CUPY_DICT_ARRAY = auto()
55+
56+
4057
class BaseOperator:
4158
"""
4259
Base class for all operator classes.
@@ -355,6 +372,15 @@ def supports(self) -> Supports:
355372
"""Returns what kind of data representation this operator supports"""
356373
return Supports.CPU_DATAFRAME | Supports.GPU_DATAFRAME
357374

375+
@property
376+
def supported_formats(self) -> DataFormats:
377+
return (
378+
DataFormats.PANDAS_DATAFRAME
379+
| DataFormats.CUDF_DATAFRAME
380+
| DataFormats.NUMPY_TENSOR_TABLE
381+
| DataFormats.CUPY_TENSOR_TABLE
382+
)
383+
358384
def _get_columns(self, df, selector):
359385
if isinstance(df, dict):
360386
return {col_name: df[col_name] for col_name in selector.names}

merlin/dag/dictarray.py

Lines changed: 0 additions & 125 deletions
This file was deleted.

0 commit comments

Comments
 (0)