Skip to content

Commit afecf1a

Browse files
authored
fix(gen): always prepare table for now to have sys columns (#1428)
1 parent 1ba2f89 commit afecf1a

File tree

3 files changed

+37
-9
lines changed

3 files changed

+37
-9
lines changed

src/datachain/query/dataset.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,10 @@ def create_udf_table(self, query: Select) -> "Table":
426426
"""Method that creates a table where temp udf results will be saved"""
427427

428428
def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
429-
"""Apply any necessary processing to the input query"""
430-
return query, []
429+
"""Materialize inputs, ensure sys columns are available, needed for checkpoints,
430+
needed for map to work (merge results)"""
431+
table = self.catalog.warehouse.create_pre_udf_table(query)
432+
return sqlalchemy.select(*table.c), [table]
431433

432434
@abstractmethod
433435
def create_result_query(
@@ -675,13 +677,6 @@ def create_udf_table(self, query: Select) -> "Table":
675677

676678
return self.catalog.warehouse.create_udf_table(udf_output_columns)
677679

678-
def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
679-
if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
680-
return query, []
681-
table = self.catalog.warehouse.create_pre_udf_table(query)
682-
q: Select = sqlalchemy.select(*table.c)
683-
return q, [table]
684-
685680
def create_result_query(
686681
self, udf_table, query
687682
) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:

tests/func/test_datachain.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,10 +1699,16 @@ def process(filename: list[str]) -> Iterator[tuple[str, int]]:
16991699
value=list(range(100)),
17001700
session=catalog_tmpfile.session,
17011701
)
1702+
# Read values in general doesn't guarantee order, so we need to order first
1703+
ds = ds.order_by("filename")
17021704
if offset is not None:
17031705
ds = ds.offset(offset)
17041706
if limit is not None:
17051707
ds = ds.limit(limit)
1708+
1709+
limited_filenames = ds.to_values("filename")
1710+
assert set(limited_filenames) == set(files)
1711+
17061712
ds = (
17071713
ds.settings(parallel=parallel)
17081714
.agg(

tests/func/test_union.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,30 @@ def test_union_parallel_udf_ids_only_no_dup(test_session_tmpfile, monkeypatch):
5757
assert total == 2 * n
5858
assert len(distinct_idx) == 2 * n
5959
assert total == len(distinct_idx)
60+
61+
62+
def test_union_parallel_gen_ids_only_no_dup(test_session_tmpfile, monkeypatch):
63+
monkeypatch.setattr("datachain.query.dispatch.DEFAULT_BATCH_SIZE", 5, raising=False)
64+
n = 30
65+
66+
x_ids = list(range(n))
67+
y_ids = list(range(n, 2 * n))
68+
69+
x = dc.read_values(idx=x_ids, session=test_session_tmpfile)
70+
y = dc.read_values(idx=y_ids, session=test_session_tmpfile)
71+
72+
xy = x.union(y)
73+
74+
def expand(idx):
75+
yield f"val-{idx}"
76+
77+
generated = xy.settings(parallel=2).gen(
78+
gen=expand,
79+
params=("idx",),
80+
output={"val": str},
81+
)
82+
83+
values = generated.to_values("val")
84+
85+
assert len(values) == 2 * n
86+
assert set(values) == {f"val-{i}" for i in range(2 * n)}

0 commit comments

Comments
 (0)