Skip to content

Commit 563be4b

Browse files
necessary changes to allow graph execution in dataloader (#152)
* necessary changes to allow graph execution in dataloader
* added test for wildcard, removed comments, refactored ternary

Co-authored-by: Karl Higley <kmhigley@gmail.com>
1 parent eda153c commit 563be4b

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

merlin/dag/executors.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,8 @@ def _build_input_data(self, node, transformable, capture_dtypes=False):
120120
else:
121121
# If there are no parents, this is an input node,
122122
# so pull columns directly from root data
123-
input_data = transformable[node_input_cols + list(addl_input_cols)]
123+
addl_input_cols = list(addl_input_cols) if addl_input_cols else []
124+
input_data = transformable[node_input_cols + addl_input_cols]
124125

125126
return input_data
126127

@@ -161,6 +162,10 @@ def _transform_data(self, node, input_data, capture_dtypes=False):
161162

162163
if is_list:
163164
col_dtype = list_val_dtype(col_series)
165+
if hasattr(col_dtype, "as_numpy_dtype"):
166+
col_dtype = col_dtype.as_numpy_dtype()
167+
elif hasattr(col_series, "numpy"):
168+
col_dtype = col_series[0].cpu().numpy().dtype
164169

165170
output_data_schema = output_col_schema.with_dtype(
166171
col_dtype, is_list=is_list, is_ragged=is_list

merlin/dag/ops/selection.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,4 +108,6 @@ def compute_output_schema(
108108
The schemas of the columns produced by this operator
109109
"""
110110
selector = col_selector or self.selector
111+
if selector.all:
112+
selector = ColumnSelector(input_schema.column_names)
111113
return super().compute_output_schema(input_schema, selector, prev_output_schema)

tests/unit/dag/ops/test_selection.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,3 +49,14 @@ def test_selection_output_schema(df):
4949
result_schema = op.compute_output_schema(schema, ColumnSelector())
5050

5151
assert result_schema.column_names == ["x", "y"]
52+
53+
54+
@pytest.mark.parametrize("engine", ["parquet"])
55+
def test_selection_wildcard_output_schema(df):
56+
selector = ColumnSelector("*")
57+
schema = Schema([ColumnSchema(col) for col in df.columns])
58+
op = SelectionOp(selector)
59+
60+
result_schema = op.compute_output_schema(schema, ColumnSelector())
61+
62+
assert result_schema.column_names == schema.column_names

0 commit comments

Comments
 (0)