Skip to content

Commit 5acb1b7

Browse files
authored
Add Rename to the set of core DAG ops that work in all DAGs (#312)
* Add a utility for testing equality between `TensorTables`
* Add `Rename` to the set of core DAG ops that work in all DAGs. This should work with either dataframes or `TensorTables`, which requires minor changes to `TensorTable` (TT) to add a definition of equality and make it possible to change column names. Adjust param names and types to fix linter issue
* Add docstrings to `TensorTable` methods
* Adjust the implementation of `TensorTable.columns`
* Mark assignment in `Rename` to be ignored by `mypy` linter
* Appease the linter's line length checks
* Update `Rename` op tests to use `assert_transformable_equal`
1 parent 5601766 commit 5acb1b7

File tree

5 files changed

+237
-0
lines changed

5 files changed

+237
-0
lines changed

merlin/dag/ops/rename.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
from merlin.core.protocols import Transformable
17+
from merlin.dag import ColumnSelector
18+
from merlin.dag.base_operator import BaseOperator
19+
20+
21+
class Rename(BaseOperator):
    """This operation renames columns by one of several methods:

        - using a user defined lambda function to transform column names
        - appending a postfix string to every column name
        - renaming a single column to a single fixed string

    If more than one option is supplied, ``f`` takes precedence over
    ``postfix``, which takes precedence over ``name``.

    Example usage::

        # Rename columns after LogOp
        cont_features = cont_names >> nvt.ops.LogOp() >> Rename(postfix='_log')
        processor = nvt.Workflow(cont_features)

    Parameters
    ----------
    f : callable, optional
        Function that takes a column name and returns a new column name
    postfix : str, optional
        If set each column name in the output will have this string appended to it
    name : str, optional
        If set, a single input column will be renamed to this string

    Raises
    ------
    ValueError
        If none of ``f``, ``postfix`` or ``name`` is provided
    """

    def __init__(self, f=None, postfix=None, name=None):
        if not f and postfix is None and name is None:
            raise ValueError("must specify name, f, or postfix, for Rename op")

        self.f = f
        self.postfix = postfix
        self.name = name
        super().__init__()

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        # Select only the requested columns, then overwrite their names with
        # the keys of the mapping (which are produced in selector order).
        transformable = transformable[col_selector.names]
        transformable.columns = list(  # type: ignore[assignment]
            self.column_mapping(col_selector).keys()
        )
        return transformable

    transform.__doc__ = BaseOperator.transform.__doc__

    def column_mapping(self, col_selector):
        """Compute the new name of every selected column

        Parameters
        ----------
        col_selector : ColumnSelector
            Selector whose ``names`` are the columns to rename

        Returns
        -------
        dict
            Maps each new column name to a one-element list holding the
            name of the column it was derived from

        Raises
        ------
        RuntimeError
            If ``name`` is used with more than one selected column, if none of
            ``f``, ``postfix`` or ``name`` is set, or if two different columns
            would be renamed to the same new name
        """
        column_mapping = {}
        for col_name in col_selector.names:
            if self.f:
                new_col_name = self.f(col_name)
            # Compare against None (as __init__ does) so that an empty string
            # accepted by the constructor is honoured instead of falling
            # through to the "nothing provided" error below.
            elif self.postfix is not None:
                new_col_name = col_name + self.postfix
            elif self.name is not None:
                if len(col_selector.names) == 1:
                    new_col_name = self.name
                else:
                    raise RuntimeError("Single column name provided for renaming multiple columns")
            else:
                raise RuntimeError(
                    "The Rename op requires one of f, postfix, or name to be provided"
                )

            # Two distinct inputs collapsing onto one output name would
            # silently drop a column from the mapping, so fail loudly.
            previous = column_mapping.get(new_col_name)
            if previous is not None and previous != [col_name]:
                raise RuntimeError(
                    f"Rename would map both '{previous[0]}' and '{col_name}' "
                    f"to the same column name '{new_col_name}'"
                )

            column_mapping[new_col_name] = [col_name]

        return column_mapping

merlin/table/tensor_table.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ def columns(self) -> List[str]:
209209
"""
210210
return list(self.keys())
211211

212+
@columns.setter
def columns(self, col_names):
    """
    Rename the columns of this TensorTable

    The new names are matched to the existing columns by position, in the
    order returned by ``self.columns``.

    Parameters
    ----------
    col_names : iterable of str
        The new column names, one per existing column

    Raises
    ------
    ValueError
        If the number of names differs from the number of columns, or if
        the names are not unique
    """
    new_names = list(col_names)
    old_names = self.columns

    # zip() would silently truncate on a length mismatch and drop columns.
    if len(new_names) != len(old_names):
        raise ValueError(
            f"Expected {len(old_names)} column names, but received {len(new_names)}"
        )
    # Duplicate names would overwrite each other in the dict below.
    if len(set(new_names)) != len(new_names):
        raise ValueError(f"Column names must be unique, but received {new_names}")

    self._columns = {
        new_name: self._columns[old_name] for old_name, new_name in zip(old_names, new_names)
    }
218+
212219
@property
213220
def column_type(self):
214221
return type(list(self.values())[0])
@@ -241,10 +248,26 @@ def to_dict(self):
241248
return result
242249

243250
def cpu(self):
    """
    Build a CPU copy of this TensorTable

    Calls ``.cpu()`` on every column and collects the results.

    Returns
    -------
    TensorTable
        A new TensorTable with the same column names, holding the CPU columns
    """
    return TensorTable({name: values.cpu() for name, values in self.items()})
246261

247262
def gpu(self):
    """
    Build a GPU copy of this TensorTable

    Calls ``.gpu()`` on every column and collects the results.

    Returns
    -------
    TensorTable
        A new TensorTable with the same column names, holding the GPU columns
    """
    return TensorTable({name: values.gpu() for name, values in self.items()})
250273

merlin/testing/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
# flake8: noqa
18+
from merlin.testing.assert_equals import assert_transformable_equal

merlin/testing/assert_equals.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from merlin.core.compat import pandas as pd
18+
from merlin.dispatch.lazy import lazy_singledispatch
19+
from merlin.table import TensorTable
20+
21+
22+
def assert_table_equal(left: TensorTable, right: TensorTable):
    """Assert that two TensorTables are equal.

    Each table is moved to CPU and converted to a dataframe; the two
    dataframes are then compared with ``pd.testing.assert_frame_equal``,
    which raises ``AssertionError`` when they differ.
    """
    left_df = left.cpu().to_df()
    right_df = right.cpu().to_df()
    pd.testing.assert_frame_equal(left_df, right_df)
24+
25+
26+
@lazy_singledispatch
def assert_transformable_equal(left, right):
    """Assert that two transformables are equal.

    This is the fallback used when no implementation has been registered
    for the argument type; type-specific implementations are attached via
    ``assert_transformable_equal.register`` / ``register_lazy``.

    Raises
    ------
    NotImplementedError
        Always, naming the types that have no registered comparison
    """
    raise NotImplementedError(
        "assert_transformable_equal is not implemented for arguments of type "
        f"'{type(left).__name__}' and '{type(right).__name__}'"
    )
29+
30+
31+
@assert_transformable_equal.register(TensorTable)
def _assert_equal_table(left, right):
    """TensorTable implementation: delegate to ``assert_table_equal``."""
    assert_table_equal(left=left, right=right)
34+
35+
36+
@assert_transformable_equal.register_lazy("cudf")
def _register_assert_equal_df_cudf():
    """Register the cudf dataframe comparison.

    cudf is imported inside this function rather than at module import time.
    """
    import cudf

    frame_type = cudf.DataFrame

    @assert_transformable_equal.register(frame_type)
    def _assert_equal_df_cudf(left, right):
        """cudf implementation: compare with ``cudf.testing.assert_frame_equal``."""
        cudf.testing.assert_frame_equal(left, right)
43+
44+
45+
@assert_transformable_equal.register_lazy("pandas")
def _register_assert_equal_pandas():
    """Register the pandas dataframe comparison.

    pandas is imported inside this function rather than at module import time.
    """
    import pandas

    frame_type = pandas.DataFrame

    @assert_transformable_equal.register(frame_type)
    def _assert_equal_pandas(left, right):
        """pandas implementation: compare with ``pandas.testing.assert_frame_equal``."""
        pandas.testing.assert_frame_equal(left, right)

tests/unit/dag/ops/test_rename.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
import numpy as np
17+
import pandas as pd
18+
import pytest
19+
20+
from merlin.core.compat import cudf
21+
from merlin.dag import ColumnSelector
22+
from merlin.dag.ops.rename import Rename
23+
from merlin.table import TensorTable
24+
from merlin.testing import assert_transformable_equal
25+
26+
# Container types the Rename op is exercised against; the cudf dataframe is
# only included when cudf could be imported.
transformables = [pd.DataFrame, TensorTable] + ([cudf.DataFrame] if cudf else [])
29+
30+
31+
@pytest.mark.parametrize("transformable", transformables)
def test_rename(transformable):
    """Rename via ``f``, ``postfix`` and ``name`` for each transformable type."""
    x_values = [1, 2, 3, 4, 5]
    y_values = [6, 7, 8, 9, 10]

    def build(columns):
        # Fresh arrays on every call so inputs and expectations never share data
        return transformable({col: np.array(values) for col, values in columns.items()})

    source = build({"x": x_values, "y": y_values})

    # Two selected columns: rename with a function, then with a postfix
    both_columns = ColumnSelector(["x", "y"])

    result = Rename(f=lambda name: name.upper()).transform(both_columns, source)
    assert_transformable_equal(result, build({"X": x_values, "Y": y_values}))

    result = Rename(postfix="_lower").transform(both_columns, source)
    assert_transformable_equal(result, build({"x_lower": x_values, "y_lower": y_values}))

    # One selected column: rename to a fixed name, then with a function
    single_column = ColumnSelector(["x"])

    result = Rename(name="z").transform(single_column, source)
    assert_transformable_equal(result, build({"z": x_values}))

    result = Rename(f=lambda name: name.upper()).transform(single_column, source)
    assert_transformable_equal(result, build({"X": x_values}))

0 commit comments

Comments
 (0)