Skip to content

Commit b7a2d6e

Browse files
authored
Resolve wildcard selectors in BaseOperator.compute_selector() (#146)
* Resolve wildcard selectors in `BaseOperator.compute_selector()`

In order to make sure wildcard selectors get resolved in all operators, we also:

* Made parent and dependency selectors optional in `compute_selector`
* Refactored operators that override `compute_selector` to use `super()`
* Made `compute_selector` signatures match across ops
* Adjusted type hints to flag `Optional` arguments
1 parent 1fd18f9 commit b7a2d6e

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

merlin/dag/base_operator.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import annotations
1717

1818
from enum import Flag, auto
19-
from typing import Any, List, Union
19+
from typing import Any, List, Optional, Union
2020

2121
import merlin.dag
2222
from merlin.core.protocols import Transformable
@@ -46,8 +46,8 @@ def compute_selector(
4646
self,
4747
input_schema: Schema,
4848
selector: ColumnSelector,
49-
parents_selector: ColumnSelector,
50-
dependencies_selector: ColumnSelector,
49+
parents_selector: Optional[ColumnSelector] = None,
50+
dependencies_selector: Optional[ColumnSelector] = None,
5151
) -> ColumnSelector:
5252
"""
5353
Provides a hook method for sub-classes to override to implement
@@ -69,9 +69,11 @@ def compute_selector(
6969
ColumnSelector
7070
Revised column selector to apply to the input schema
7171
"""
72+
selector = selector or ColumnSelector("*")
73+
7274
self._validate_matching_cols(input_schema, selector, self.compute_selector.__name__)
7375

74-
return selector
76+
return selector.resolve(input_schema)
7577

7678
def compute_input_schema(
7779
self,
@@ -109,7 +111,7 @@ def compute_output_schema(
109111
self,
110112
input_schema: Schema,
111113
col_selector: ColumnSelector,
112-
prev_output_schema: Schema = None,
114+
prev_output_schema: Optional[Schema] = None,
113115
) -> Schema:
114116
"""
115117
Given a set of schemas and a column selector for the input columns,
@@ -281,7 +283,9 @@ def _compute_properties(self, col_schema, input_schema):
281283

282284
def _validate_matching_cols(self, schema, selector, method_name):
283285
selector = selector or ColumnSelector()
284-
missing_cols = [name for name in selector.names if name not in schema.column_names]
286+
resolved_selector = selector.resolve(schema)
287+
288+
missing_cols = [name for name in selector.names if name not in resolved_selector.names]
285289
if missing_cols:
286290
raise ValueError(
287291
f"Missing columns {missing_cols} found in operator"

merlin/dag/ops/concat_columns.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def compute_selector(
3333
self,
3434
input_schema: Schema,
3535
selector: ColumnSelector,
36-
parents_selector: ColumnSelector,
37-
dependencies_selector: ColumnSelector,
36+
parents_selector: ColumnSelector = None,
37+
dependencies_selector: ColumnSelector = None,
3838
) -> ColumnSelector:
3939
"""
4040
Combine selectors from the nodes being added
@@ -55,14 +55,11 @@ def compute_selector(
5555
ColumnSelector
5656
Combined column selectors of parent and dependency nodes
5757
"""
58-
self._validate_matching_cols(
58+
return super().compute_selector(
5959
input_schema,
6060
parents_selector + dependencies_selector,
61-
self.compute_selector.__name__,
6261
)
6362

64-
return parents_selector + dependencies_selector
65-
6663
def compute_input_schema(
6764
self,
6865
root_schema: Schema,

merlin/dag/ops/subtraction.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def compute_selector(
3434
self,
3535
input_schema: Schema,
3636
selector: ColumnSelector,
37-
parents_selector: ColumnSelector,
38-
dependencies_selector: ColumnSelector,
37+
parents_selector: ColumnSelector = None,
38+
dependencies_selector: ColumnSelector = None,
3939
) -> ColumnSelector:
4040
"""
4141
Creates selector of all columns from the input schema
@@ -56,7 +56,10 @@ def compute_selector(
5656
ColumnSelector
5757
Selector of all columns from the input schema
5858
"""
59-
return ColumnSelector(input_schema.column_names)
59+
return super().compute_selector(
60+
input_schema,
61+
ColumnSelector("*"),
62+
)
6063

6164
def compute_input_schema(
6265
self,

0 commit comments

Comments
 (0)