@@ -1647,12 +1647,33 @@ def test_sk_visual_block_remainder_col_names_pandas():
16471647 assert visual_block .name_details == (["col1" ], ["col2" ])
16481648
16491649
1650- def test_sk_visual_block_without_remainder ():
1650+ def test_sk_visual_block_full_transform ():
1651+ """Check that visual_block doesn't return remainder when it has no columns
1652+ Non-regression test - https://github.com/scikit-learn/scikit-learn/issues/33513
1653+ """
16511654 ct = ColumnTransformer ([("norm1" , Normalizer (), [0 , 1 ])], remainder = "passthrough" )
16521655 X = np .array ([[0 , 4 ], [3 , 3 ]])
16531656 ct .fit (X )
16541657 visual_block = ct ._sk_visual_block_ ()
16551658 assert visual_block .names == ("norm1" ,)
1659+ assert visual_block .name_details == ([0 , 1 ],)
1660+ assert isinstance (visual_block .estimators [0 ], Normalizer )
1661+
1662+
1663+ def test_sk_visual_block_remainder_with_preprocessor ():
1664+ """Check that visual_block doesn't cut the remainder if it is a transformer
1665+ Non-regression test - https://github.com/scikit-learn/scikit-learn/issues/33513
1666+ """
1667+ ct = ColumnTransformer (
1668+ [("norm1" , Normalizer (), [0 , 1 ])], remainder = StandardScaler ()
1669+ )
1670+ X = np .array ([[0 , 4 , 3 ], [3 , 3 , 3 ]])
1671+ ct .fit (X )
1672+ visual_block = ct ._sk_visual_block_ ()
1673+ assert visual_block .names == ("norm1" , "remainder" )
1674+ assert visual_block .name_details == ([0 , 1 ], [2 ])
1675+ assert isinstance (visual_block .estimators [0 ], Normalizer )
1676+ assert isinstance (visual_block .estimators [1 ], StandardScaler )
16561677
16571678
16581679@pytest .mark .parametrize ("explicit_colname" , ["first" , "second" , 0 , 1 ])
0 commit comments