Skip to content

Commit cf81584

Browse files
committed
added tests
1 parent 626f508 commit cf81584

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

sklearn/compose/tests/test_column_transformer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1647,12 +1647,33 @@ def test_sk_visual_block_remainder_col_names_pandas():
16471647
assert visual_block.name_details == (["col1"], ["col2"])
16481648

16491649

1650-
def test_sk_visual_block_without_remainder():
1650+
def test_sk_visual_block_full_transform():
1651+
"""Check that visual_block doesn't return remainder when it has no columns
1652+
Non-regression test - https://github.com/scikit-learn/scikit-learn/issues/33513
1653+
"""
16511654
ct = ColumnTransformer([("norm1", Normalizer(), [0, 1])], remainder="passthrough")
16521655
X = np.array([[0, 4], [3, 3]])
16531656
ct.fit(X)
16541657
visual_block = ct._sk_visual_block_()
16551658
assert visual_block.names == ("norm1",)
1659+
assert visual_block.name_details == ([0, 1],)
1660+
assert isinstance(visual_block.estimators[0], Normalizer)
1661+
1662+
1663+
def test_sk_visual_block_remainder_with_preprocessor():
1664+
"""Check that visual_block doesn't cut the remainder if it is a transformer
1665+
Non-regression test - https://github.com/scikit-learn/scikit-learn/issues/33513
1666+
"""
1667+
ct = ColumnTransformer(
1668+
[("norm1", Normalizer(), [0, 1])], remainder=StandardScaler()
1669+
)
1670+
X = np.array([[0, 4, 3], [3, 3, 3]])
1671+
ct.fit(X)
1672+
visual_block = ct._sk_visual_block_()
1673+
assert visual_block.names == ("norm1", "remainder")
1674+
assert visual_block.name_details == ([0, 1], [2])
1675+
assert isinstance(visual_block.estimators[0], Normalizer)
1676+
assert isinstance(visual_block.estimators[1], StandardScaler)
16561677

16571678

16581679
@pytest.mark.parametrize("explicit_colname", ["first", "second", 0, 1])

0 commit comments

Comments
 (0)