Commit e24c2d5

fix sys ids in agg, add more tests (#1430)
* chore(tests): cleanup and add a bit more tests
* fix agg: stabilize sys_ids
1 parent afecf1a · commit e24c2d5

File tree

3 files changed: +146 −35 lines

src/datachain/query/dataset.py
tests/func/test_udf.py
tests/func/test_union.py

src/datachain/query/dataset.py

Lines changed: 3 additions & 8 deletions
@@ -630,23 +630,18 @@ def clone(self, partition_by: PartitionByType | None = None) -> "Self":
     def apply(
         self, query_generator: QueryGenerator, temp_tables: list[str]
     ) -> "StepResult":
-        _query = query = query_generator.select()
+        query, tables = self.process_input_query(query_generator.select())
+        _query = query
 
         # Apply partitioning if needed.
         if self.partition_by is not None:
-            _query = query = self.catalog.warehouse._regenerate_system_columns(
-                query_generator.select(),
-                keep_existing_columns=True,
-                regenerate_columns=["sys__id"],
-            )
             partition_tbl = self.create_partitions_table(query)
-            temp_tables.append(partition_tbl.name)
             query = query.outerjoin(
                 partition_tbl,
                 partition_tbl.c.sys__id == query.selected_columns.sys__id,
             ).add_columns(*partition_columns())
+            tables = [*tables, partition_tbl]
 
-        query, tables = self.process_input_query(query)
         temp_tables.extend(t.name for t in tables)
         udf_table = self.create_udf_table(_query)
         temp_tables.append(udf_table.name)
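
Why this ordering matters: previously, when partition_by was set, apply() regenerated sys__id from a fresh query_generator.select() and only ran process_input_query afterwards, so the partition table could be keyed on different sys__id values than the ones the UDF ultimately read. Running process_input_query first and building the partition table from that same query keeps sys__id stable, which is what the commit title means by "stabilize sys_ids". A minimal sketch of the user-facing pattern this fixes, assuming `import datachain as dc` and a default session (the new tests below pass an explicit test session instead):

import datachain as dc
from datachain import func

# Union drops the sys columns; the aggregation must regenerate them
# consistently for its functional partition key to line up.
x = dc.read_values(path=[f"group-{i % 3}/item-{i}" for i in range(15)])
y = dc.read_values(path=[f"group-{i % 3}/item-{15 + i}" for i in range(15)])
xy = x.union(y)

def summarize(paths):
    # agg passes every "path" value of one partition as a sequence.
    group = paths[0].split("/")[0]
    yield group, len(paths)

aggregated = xy.agg(
    summarize,
    params=("path",),
    output={"partition": str, "count": int},
    partition_by=func.parent("path"),  # partition rows by the path's parent dir
)
print(aggregated.to_records())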

tests/func/test_udf.py

Lines changed: 143 additions & 0 deletions
@@ -878,3 +878,146 @@ def name_len_interrupt(_name):
     chain.show()
     captured = capfd.readouterr()
     assert "semaphore" not in captured.err
+
+
+def test_gen_works_after_union(test_session_tmpfile, monkeypatch):
+    """
+    Union drops sys columns; we test that the UDF generates them correctly after that.
+    """
+    monkeypatch.setattr("datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False)
+    n = 30
+
+    x_ids = list(range(n))
+    y_ids = list(range(n, 2 * n))
+
+    x = dc.read_values(idx=x_ids, session=test_session_tmpfile)
+    y = dc.read_values(idx=y_ids, session=test_session_tmpfile)
+
+    xy = x.union(y)
+
+    def expand(idx):
+        yield f"val-{idx}"
+
+    generated = xy.settings(parallel=2).gen(
+        gen=expand,
+        params=("idx",),
+        output={"val": str},
+    )
+
+    values = generated.to_values("val")
+
+    assert len(values) == 2 * n
+    assert set(values) == {f"val-{i}" for i in range(2 * n)}
+
+
+@pytest.mark.parametrize("full", [False, True])
+def test_gen_works_after_merge(test_session_tmpfile, monkeypatch, full):
+    """Merge drops sys columns as well; ensure UDF generation still works."""
+    monkeypatch.setattr("datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False)
+    n = 30
+
+    idxs = list(range(n))
+
+    left = dc.read_values(
+        idx=idxs,
+        left_value=[f"left-{i}" for i in idxs],
+        session=test_session_tmpfile,
+    )
+    right = dc.read_values(
+        idx=idxs,
+        right_value=[f"right-{i}" for i in idxs],
+        session=test_session_tmpfile,
+    )
+
+    merged = left.merge(right, on="idx", full=full)
+
+    def expand(idx, left_value, right_value):
+        yield f"val-{idx}-{left_value}-{right_value}"
+
+    generated = merged.settings(parallel=2).gen(
+        gen=expand,
+        params=("idx", "left_value", "right_value"),
+        output={"val": str},
+    )
+
+    values = generated.to_values("val")
+
+    assert len(values) == n
+    expected = {f"val-{i}-left-{i}-right-{i}" for i in idxs}
+    assert set(values) == expected
+
+
+def test_agg_works_after_union(test_session_tmpfile, monkeypatch):
+    """Union must preserve sys columns for aggregations with functional partitions."""
+    from datachain import func
+
+    monkeypatch.setattr("datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False)
+
+    groups = 5
+    n = 30
+
+    x_paths = [f"group-{i % groups}/item-{i}" for i in range(n)]
+    y_paths = [f"group-{i % groups}/item-{n + i}" for i in range(n)]
+
+    x = dc.read_values(path=x_paths, session=test_session_tmpfile)
+    y = dc.read_values(path=y_paths, session=test_session_tmpfile)
+
+    xy = x.union(y)
+
+    def summarize(paths):
+        group = paths[0].split("/")[0]
+        yield group, len(paths)
+
+    aggregated = xy.settings(parallel=2).agg(
+        summarize,
+        params=("path",),
+        output={"partition": str, "count": int},
+        partition_by=func.parent("path"),
+    )
+
+    records = aggregated.to_records()
+    expected_counts = {f"group-{g}": 2 * n // groups for g in range(groups)}
+    assert {row["partition"]: row["count"] for row in records} == expected_counts
+
+
+@pytest.mark.parametrize("full", [False, True])
+def test_agg_works_after_merge(test_session_tmpfile, monkeypatch, full):
+    """Ensure merge keeps sys columns for aggregations with functional partitions."""
+    from datachain import func
+
+    monkeypatch.setattr("datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False)
+
+    groups = 5
+    n = 30
+    idxs = list(range(n))
+
+    left = dc.read_values(
+        idx=idxs,
+        left_path=[f"group-{i % groups}/left-{i}" for i in idxs],
+        session=test_session_tmpfile,
+    )
+    right = dc.read_values(
+        idx=idxs,
+        right_value=idxs,
+        session=test_session_tmpfile,
+    )
+
+    merged = left.merge(right, on="idx", full=full)
+
+    def summarize(left_path, right_value):
+        group = left_path[0].split("/")[0]
+        yield group, sum(right_value)
+
+    aggregated = merged.settings(parallel=2).agg(
+        summarize,
+        params=("left_path", "right_value"),
+        output={"partition": str, "total": int},
+        partition_by=func.parent("left_path"),
+    )
+
+    records = aggregated.to_records()
+    expected_totals = {
+        f"group-{g}": sum(val for val in idxs if val % groups == g)
+        for g in range(groups)
+    }
+    assert {row["partition"]: row["total"] for row in records} == expected_totals

tests/func/test_union.py

Lines changed: 0 additions & 27 deletions
@@ -57,30 +57,3 @@ def test_union_parallel_udf_ids_only_no_dup(test_session_tmpfile, monkeypatch):
     assert total == 2 * n
     assert len(distinct_idx) == 2 * n
     assert total == len(distinct_idx)
-
-
-def test_union_parallel_gen_ids_only_no_dup(test_session_tmpfile, monkeypatch):
-    monkeypatch.setattr("datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False)
-    n = 30
-
-    x_ids = list(range(n))
-    y_ids = list(range(n, 2 * n))
-
-    x = dc.read_values(idx=x_ids, session=test_session_tmpfile)
-    y = dc.read_values(idx=y_ids, session=test_session_tmpfile)
-
-    xy = x.union(y)
-
-    def expand(idx):
-        yield f"val-{idx}"
-
-    generated = xy.settings(parallel=2).gen(
-        gen=expand,
-        params=("idx",),
-        output={"val": str},
-    )
-
-    values = generated.to_values("val")
-
-    assert len(values) == 2 * n
-    assert set(values) == {f"val-{i}" for i in range(2 * n)}
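
This deletion is a relocation rather than lost coverage: the removed test body is identical (apart from the added docstring) to test_gen_works_after_union in tests/func/test_udf.py above, where it now sits alongside the matching merge and agg variants.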
