@@ -772,3 +772,31 @@ def test_group_by_agg_last(
772772 df = df .sort (aggs , ** pre_sort )
773773 result = df .group_by (keys ).agg (nw .col (aggs ).last ()).sort (keys )
774774 assert_equal_data (result , expected )
775+
776+
777+ def test_multi_column_expansion (constructor : Constructor ) -> None :
778+ if "polars" in str (constructor ) and POLARS_VERSION < (1 , 32 ):
779+ pytest .skip (reason = "https://github.com/pola-rs/polars/issues/21773" )
780+ if "modin" in str (constructor ):
781+ pytest .skip (reason = "Internal error" )
782+ df = nw .from_native (constructor ({"a" : [1 , 1 , 2 ], "b" : [4 , 5 , 6 ]}))
783+ result = (
784+ df .group_by ("a" )
785+ .agg (nw .all ().sum ().name .suffix ("_aggregated" ))
786+ .sort ("a" , descending = True )
787+ )
788+ expected = {"a" : [2 , 1 ], "b_aggregated" : [6 , 9 ]}
789+ assert_equal_data (result , expected )
790+ result = (
791+ df .group_by ("a" )
792+ .agg (nw .col ("a" , "b" ).sum ().name .suffix ("_aggregated" ))
793+ .sort ("a" , descending = True )
794+ )
795+ expected = {"a" : [2 , 1 ], "a_aggregated" : [2 , 2 ], "b_aggregated" : [6 , 9 ]}
796+ assert_equal_data (result , expected )
797+ result = (
798+ df .group_by ("a" )
799+ .agg (nw .nth (0 , 1 ).sum ().name .suffix ("_aggregated" ))
800+ .sort ("a" , descending = True )
801+ )
802+ assert_equal_data (result , expected )
0 commit comments