Skip to content

Commit 7059e23

Browse files
authored
Rework executor transform methods to accept a Graph (#158)
* Rework executor `transform` methods to accept a `Graph`
* Allow `Node` as a potential value of `transform`'s `graph` arg (for now)
1 parent 563be4b commit 7059e23

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

merlin/dag/executors.py

Lines changed: 32 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
global_dask_client,
2626
set_client_deprecated,
2727
)
28-
from merlin.dag import ColumnSelector, Node
28+
from merlin.dag import ColumnSelector, Graph, Node
2929
from merlin.io.worker import clean_worker_cache
3030

3131
LOG = logging.getLogger("merlin")
@@ -39,7 +39,7 @@ class LocalExecutor:
3939
def transform(
4040
self,
4141
transformable,
42-
nodes,
42+
graph,
4343
output_dtypes=None,
4444
additional_columns=None,
4545
capture_dtypes=False,
@@ -48,6 +48,21 @@ def transform(
4848
Transforms a single dataframe (possibly a partition of a Dask Dataframe)
4949
by applying the operators from a collection of Nodes
5050
"""
51+
nodes = []
52+
if isinstance(graph, Graph):
53+
nodes.append(graph.output_node)
54+
elif isinstance(graph, Node):
55+
nodes.append(graph)
56+
elif isinstance(graph, list):
57+
nodes = graph
58+
else:
59+
raise TypeError(
60+
f"LocalExecutor detected unsupported type of input for graph: {type(graph)}."
61+
" `graph` argument must be either a `Graph` object (preferred)"
62+
" or a list of `Node` objects (deprecated, but supported for backward "
63+
" compatibility.)"
64+
)
65+
5166
output_data = None
5267

5368
for node in nodes:
@@ -220,12 +235,26 @@ def __getstate__(self):
220235
return {k: v for k, v in self.__dict__.items() if k != "client"}
221236

222237
def transform(
223-
self, ddf, nodes, output_dtypes=None, additional_columns=None, capture_dtypes=False
238+
self, ddf, graph, output_dtypes=None, additional_columns=None, capture_dtypes=False
224239
):
225240
"""
226241
Transforms all partitions of a Dask Dataframe by applying the operators
227242
from a collection of Nodes
228243
"""
244+
nodes = []
245+
if isinstance(graph, Graph):
246+
nodes.append(graph.output_node)
247+
elif isinstance(graph, Node):
248+
nodes.append(graph)
249+
elif isinstance(graph, list):
250+
nodes = graph
251+
else:
252+
raise TypeError(
253+
f"DaskExecutor detected unsupported type of input for graph: {type(graph)}."
254+
" `graph` argument must be either a `Graph` object (preferred)"
255+
" or a list of `Node` objects (deprecated, but supported for backward"
256+
" compatibility.)"
257+
)
229258

230259
self._clear_worker_cache()
231260

0 commit comments

Comments
 (0)