@@ -878,3 +878,146 @@ def name_len_interrupt(_name):
878878 chain .show ()
879879 captured = capfd .readouterr ()
880880 assert "semaphore" not in captured .err
881+
882+
def test_gen_works_after_union(test_session_tmpfile, monkeypatch):
    """
    Union drops sys columns; verify a generator UDF recreates them correctly
    afterwards when running with parallel workers.
    """
    # Small batch size forces several dispatch batches so the sys-column
    # regeneration path is exercised more than once per worker.
    monkeypatch.setattr(
        "datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False
    )
    n = 30

    x_ids = list(range(n))
    y_ids = list(range(n, 2 * n))

    x = dc.read_values(idx=x_ids, session=test_session_tmpfile)
    y = dc.read_values(idx=y_ids, session=test_session_tmpfile)

    xy = x.union(y)

    def expand(idx):
        # One generated row per input row, tagged with the source idx.
        yield f"val-{idx}"

    generated = xy.settings(parallel=2).gen(
        gen=expand,
        params=("idx",),
        output={"val": str},
    )

    values = generated.to_values("val")

    # Every row from both sides of the union must have been processed.
    assert len(values) == 2 * n
    assert set(values) == {f"val-{i}" for i in range(2 * n)}
911+
912+
@pytest.mark.parametrize("full", [False, True])
def test_gen_works_after_merge(test_session_tmpfile, monkeypatch, full):
    """Merge drops sys columns as well; ensure UDF generation still works."""
    # Small batch size forces several dispatch batches per parallel worker.
    monkeypatch.setattr(
        "datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False
    )
    n = 30

    idxs = list(range(n))

    left = dc.read_values(
        idx=idxs,
        left_value=[f"left-{i}" for i in idxs],
        session=test_session_tmpfile,
    )
    right = dc.read_values(
        idx=idxs,
        right_value=[f"right-{i}" for i in idxs],
        session=test_session_tmpfile,
    )

    # Both sides share the same idx values, so inner and full merges
    # produce the same n rows here; `full` only switches the join type.
    merged = left.merge(right, on="idx", full=full)

    def expand(idx, left_value, right_value):
        # Combine matched columns into a single generated value per row.
        yield f"val-{idx}-{left_value}-{right_value}"

    generated = merged.settings(parallel=2).gen(
        gen=expand,
        params=("idx", "left_value", "right_value"),
        output={"val": str},
    )

    values = generated.to_values("val")

    assert len(values) == n
    expected = {f"val-{i}-left-{i}-right-{i}" for i in idxs}
    assert set(values) == expected
948+
949+
def test_agg_works_after_union(test_session_tmpfile, monkeypatch):
    """Union must preserve sys columns for aggregations with functional partitions."""
    from datachain import func

    # Small batch size forces several dispatch batches per parallel worker.
    monkeypatch.setattr(
        "datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False
    )

    groups = 5
    n = 30

    # Paths are laid out as "<group>/<item>" so func.parent() yields the group.
    x_paths = [f"group-{i % groups}/item-{i}" for i in range(n)]
    y_paths = [f"group-{i % groups}/item-{n + i}" for i in range(n)]

    x = dc.read_values(path=x_paths, session=test_session_tmpfile)
    y = dc.read_values(path=y_paths, session=test_session_tmpfile)

    xy = x.union(y)

    def summarize(paths):
        # All paths in one partition share the same parent (the group name).
        group = paths[0].split("/")[0]
        yield group, len(paths)

    aggregated = xy.settings(parallel=2).agg(
        summarize,
        params=("path",),
        output={"partition": str, "count": int},
        partition_by=func.parent("path"),
    )

    records = aggregated.to_records()
    # Each group receives an equal share of rows from both union sides.
    expected_counts = {f"group-{g}": 2 * n // groups for g in range(groups)}
    assert {row["partition"]: row["count"] for row in records} == expected_counts
981+
982+
@pytest.mark.parametrize("full", [False, True])
def test_agg_works_after_merge(test_session_tmpfile, monkeypatch, full):
    """Ensure merge keeps sys columns for aggregations with functional partitions."""
    from datachain import func

    # Small batch size forces several dispatch batches per parallel worker.
    monkeypatch.setattr(
        "datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False
    )

    groups = 5
    n = 30
    idxs = list(range(n))

    left = dc.read_values(
        idx=idxs,
        # "<group>/<item>" layout so func.parent() yields the group name.
        left_path=[f"group-{i % groups}/left-{i}" for i in idxs],
        session=test_session_tmpfile,
    )
    right = dc.read_values(
        idx=idxs,
        right_value=idxs,
        session=test_session_tmpfile,
    )

    # Both sides share the same idx values, so inner and full merges
    # produce the same rows here; `full` only switches the join type.
    merged = left.merge(right, on="idx", full=full)

    def summarize(left_path, right_value):
        # All paths in one partition share the same parent (the group name).
        group = left_path[0].split("/")[0]
        yield group, sum(right_value)

    aggregated = merged.settings(parallel=2).agg(
        summarize,
        params=("left_path", "right_value"),
        output={"partition": str, "total": int},
        partition_by=func.parent("left_path"),
    )

    records = aggregated.to_records()
    # Each group's total is the sum of the idx values assigned to it.
    expected_totals = {
        f"group-{g}": sum(val for val in idxs if val % groups == g)
        for g in range(groups)
    }
    assert {row["partition"]: row["total"] for row in records} == expected_totals
0 commit comments