2525 global_dask_client ,
2626 set_client_deprecated ,
2727)
28- from merlin .dag import ColumnSelector , Node
28+ from merlin .dag import ColumnSelector , Graph , Node
2929from merlin .io .worker import clean_worker_cache
3030
3131LOG = logging .getLogger ("merlin" )
@@ -39,7 +39,7 @@ class LocalExecutor:
3939 def transform (
4040 self ,
4141 transformable ,
42- nodes ,
42+ graph ,
4343 output_dtypes = None ,
4444 additional_columns = None ,
4545 capture_dtypes = False ,
@@ -48,6 +48,21 @@ def transform(
4848 Transforms a single dataframe (possibly a partition of a Dask Dataframe)
4949 by applying the operators from a collection of Nodes
5050 """
51+ nodes = []
52+ if isinstance (graph , Graph ):
53+ nodes .append (graph .output_node )
54+ elif isinstance (graph , Node ):
55+ nodes .append (graph )
56+ elif isinstance (graph , list ):
57+ nodes = graph
58+ else :
59+ raise TypeError (
60+ f"LocalExecutor detected unsupported type of input for graph: { type (graph )} ."
61+ " `graph` argument must be either a `Graph` object (preferred)"
62+ " or a list of `Node` objects (deprecated, but supported for backward "
63+ " compatibility.)"
64+ )
65+
5166 output_data = None
5267
5368 for node in nodes :
@@ -220,12 +235,26 @@ def __getstate__(self):
220235 return {k : v for k , v in self .__dict__ .items () if k != "client" }
221236
222237 def transform (
223- self , ddf , nodes , output_dtypes = None , additional_columns = None , capture_dtypes = False
238+ self , ddf , graph , output_dtypes = None , additional_columns = None , capture_dtypes = False
224239 ):
225240 """
226241 Transforms all partitions of a Dask Dataframe by applying the operators
227242 from a collection of Nodes
228243 """
244+ nodes = []
245+ if isinstance (graph , Graph ):
246+ nodes .append (graph .output_node )
247+ elif isinstance (graph , Node ):
248+ nodes .append (graph )
249+ elif isinstance (graph , list ):
250+ nodes = graph
251+ else :
252+ raise TypeError (
253+ f"DaskExecutor detected unsupported type of input for graph: { type (graph )} ."
254+ " `graph` argument must be either a `Graph` object (preferred)"
255+ " or a list of `Node` objects (deprecated, but supported for backward"
256+ " compatibility.)"
257+ )
229258
230259 self ._clear_worker_cache ()
231260
0 commit comments