Skip to content

Commit 654d9af

Browse files
authored
fix(steps): drop sys columns when it is not safe to keep them (#1400)
1 parent 406a0a1 commit 654d9af

File tree

11 files changed

+113
-124
lines changed

11 files changed

+113
-124
lines changed

src/datachain/data_storage/schema.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
JSON,
1212
Boolean,
1313
DateTime,
14-
Int,
1514
Int64,
1615
SQLType,
1716
String,
@@ -269,7 +268,7 @@ def delete(self):
269268
@classmethod
270269
def sys_columns(cls):
271270
return [
272-
sa.Column("sys__id", Int, primary_key=True),
271+
sa.Column("sys__id", UInt64, primary_key=True),
273272
sa.Column(
274273
"sys__rand", UInt64, nullable=False, server_default=f.abs(f.random())
275274
),

src/datachain/data_storage/sqlite.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -868,11 +868,8 @@ def add_left_rows_filter(exp: BinaryExpression):
868868
if isinstance(c, BinaryExpression):
869869
right_left_join = add_left_rows_filter(c)
870870

871-
# Use CTE instead of subquery to force SQLite to materialize the result
872-
# This breaks deep nesting and prevents parser stack overflow.
873871
union_cte = sqlalchemy.union(left_right_join, right_left_join).cte()
874-
875-
return self._regenerate_system_columns(union_cte)
872+
return sqlalchemy.select(*union_cte.c).select_from(union_cte)
876873

877874
def _system_row_number_expr(self):
878875
return func.row_number().over()
@@ -884,11 +881,7 @@ def create_pre_udf_table(self, query: "Select") -> "Table":
884881
"""
885882
Create a temporary table from a query for use in a UDF.
886883
"""
887-
columns = [
888-
sqlalchemy.Column(c.name, c.type)
889-
for c in query.selected_columns
890-
if c.name != "sys__id"
891-
]
884+
columns = [sqlalchemy.Column(c.name, c.type) for c in query.selected_columns]
892885
table = self.create_udf_table(columns)
893886

894887
with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar:

src/datachain/data_storage/warehouse.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import string
66
from abc import ABC, abstractmethod
77
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
8-
from typing import TYPE_CHECKING, Any, Union
8+
from typing import TYPE_CHECKING, Any, Union, cast
99
from urllib.parse import urlparse
1010

1111
import attrs
@@ -23,7 +23,7 @@
2323
from datachain.query.batch import RowsOutput
2424
from datachain.query.schema import ColumnMeta
2525
from datachain.sql.functions import path as pathfunc
26-
from datachain.sql.types import Int, SQLType
26+
from datachain.sql.types import SQLType
2727
from datachain.utils import sql_escape_like
2828

2929
if TYPE_CHECKING:
@@ -32,6 +32,7 @@
3232
_FromClauseArgument,
3333
_OnClauseArgument,
3434
)
35+
from sqlalchemy.sql.selectable import FromClause
3536
from sqlalchemy.types import TypeEngine
3637

3738
from datachain.data_storage import schema
@@ -248,45 +249,56 @@ def dataset_select_paginated(
248249

249250
def _regenerate_system_columns(
250251
self,
251-
selectable: sa.Select | sa.CTE,
252+
selectable: sa.Select,
252253
keep_existing_columns: bool = False,
254+
regenerate_columns: Iterable[str] | None = None,
253255
) -> sa.Select:
254256
"""
255-
Return a SELECT that regenerates sys__id and sys__rand deterministically.
257+
Return a SELECT that regenerates system columns deterministically.
256258
257-
If keep_existing_columns is True, existing sys__id and sys__rand columns
258-
will be kept as-is if they exist in the input selectable.
259-
"""
260-
base = selectable.subquery() if hasattr(selectable, "subquery") else selectable
261-
262-
result_columns: dict[str, sa.ColumnElement] = {}
263-
for col in base.c:
264-
if col.name in result_columns:
265-
raise ValueError(f"Duplicate column name {col.name} in SELECT")
266-
if col.name in ("sys__id", "sys__rand"):
267-
if keep_existing_columns:
268-
result_columns[col.name] = col
269-
else:
270-
result_columns[col.name] = col
259+
If keep_existing_columns is True, existing system columns will be kept as-is
260+
even when they are listed in ``regenerate_columns``.
271261
272-
system_types: dict[str, sa.types.TypeEngine] = {
262+
Args:
263+
selectable: Base SELECT
264+
keep_existing_columns: When True, reuse existing system columns even if
265+
they are part of the regeneration set.
266+
regenerate_columns: Names of system columns to regenerate. Defaults to
267+
{"sys__id", "sys__rand"}. Columns not listed are left untouched.
268+
"""
269+
system_columns = {
273270
sys_col.name: sys_col.type
274271
for sys_col in self.schema.dataset_row_cls.sys_columns()
275272
}
273+
regenerate = set(regenerate_columns or system_columns)
274+
generators = {
275+
"sys__id": self._system_row_number_expr,
276+
"sys__rand": self._system_random_expr,
277+
}
278+
279+
base = cast("FromClause", selectable.subquery())
280+
281+
def build(name: str) -> sa.ColumnElement:
282+
expr = generators[name]()
283+
return sa.cast(expr, system_columns[name]).label(name)
284+
285+
columns: list[sa.ColumnElement] = []
286+
present: set[str] = set()
287+
changed = False
288+
289+
for col in base.c:
290+
present.add(col.name)
291+
regen = col.name in regenerate and not keep_existing_columns
292+
columns.append(build(col.name) if regen else col)
293+
changed |= regen
294+
295+
for name in regenerate - present:
296+
columns.append(build(name))
297+
changed = True
298+
299+
if not changed:
300+
return selectable
276301

277-
# Add missing system columns if needed
278-
if "sys__id" not in result_columns:
279-
expr = self._system_row_number_expr()
280-
expr = sa.cast(expr, system_types["sys__id"])
281-
result_columns["sys__id"] = expr.label("sys__id")
282-
if "sys__rand" not in result_columns:
283-
expr = self._system_random_expr()
284-
expr = sa.cast(expr, system_types["sys__rand"])
285-
result_columns["sys__rand"] = expr.label("sys__rand")
286-
287-
# Wrap in subquery to materialize window functions, then wrap again in SELECT
288-
# This ensures window functions are computed before INSERT...FROM SELECT
289-
columns = list(result_columns.values())
290302
inner = sa.select(*columns).select_from(base).subquery()
291303
return sa.select(*inner.c).select_from(inner)
292304

@@ -950,10 +962,15 @@ def create_udf_table(
950962
SQLite TEMPORARY tables cannot be directly used as they are process-specific,
951963
and UDFs are run in other processes when run in parallel.
952964
"""
965+
columns = [
966+
c
967+
for c in columns
968+
if c.name not in [col.name for col in self.dataset_row_cls.sys_columns()]
969+
]
953970
tbl = sa.Table(
954971
name or self.udf_table_name(),
955972
sa.MetaData(),
956-
sa.Column("sys__id", Int, primary_key=True),
973+
*self.dataset_row_cls.sys_columns(),
957974
*columns,
958975
)
959976
self.db.create_table(tbl, if_not_exists=True)

src/datachain/diff/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class CompareStatus(str, Enum):
2424
SAME = "S"
2525

2626

27-
def _compare( # noqa: C901, PLR0912
27+
def _compare( # noqa: C901
2828
left: "DataChain",
2929
right: "DataChain",
3030
on: str | Sequence[str],
@@ -151,11 +151,7 @@ def _to_list(obj: str | Sequence[str] | None) -> list[str] | None:
151151
if status_col:
152152
cols_select.append(diff_col)
153153

154-
if not dc_diff._sys:
155-
# TODO workaround when sys signal is not available in diff
156-
dc_diff = dc_diff.settings(sys=True).select(*cols_select).settings(sys=False)
157-
else:
158-
dc_diff = dc_diff.select(*cols_select)
154+
dc_diff = dc_diff.select(*cols_select)
159155

160156
# final schema is schema from the left chain with status column added if needed
161157
dc_diff.signals_schema = (

src/datachain/lib/dc/datachain.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,9 @@ def map(
856856
udf_obj.to_udf_wrapper(self._settings.batch_size),
857857
**self._settings.to_dict(),
858858
),
859-
signal_schema=self.signals_schema | udf_obj.output,
859+
signal_schema=SignalSchema({"sys": Sys})
860+
| self.signals_schema
861+
| udf_obj.output,
860862
)
861863

862864
def gen(
@@ -894,7 +896,7 @@ def gen(
894896
udf_obj.to_udf_wrapper(self._settings.batch_size),
895897
**self._settings.to_dict(),
896898
),
897-
signal_schema=udf_obj.output,
899+
signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
898900
)
899901

900902
@delta_disabled
@@ -1031,7 +1033,7 @@ def my_agg(files: list[File]) -> Iterator[tuple[File, int]]:
10311033
partition_by=processed_partition_by,
10321034
**self._settings.to_dict(),
10331035
),
1034-
signal_schema=udf_obj.output,
1036+
signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
10351037
)
10361038

10371039
def batch_map(
@@ -1097,11 +1099,7 @@ def _udf_to_obj(
10971099
sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
10981100
DataModel.register(list(sign.output_schema.values.values()))
10991101

1100-
signals_schema = self.signals_schema
1101-
if self._sys:
1102-
signals_schema = SignalSchema({"sys": Sys}) | signals_schema
1103-
1104-
params_schema = signals_schema.slice(
1102+
params_schema = self.signals_schema.slice(
11051103
sign.params, self._setup, is_batch=is_batch
11061104
)
11071105

@@ -1156,11 +1154,9 @@ def distinct(self, arg: str, *args: str) -> "Self": # type: ignore[override]
11561154
)
11571155
)
11581156

1159-
def select(self, *args: str, _sys: bool = True) -> "Self":
1157+
def select(self, *args: str) -> "Self":
11601158
"""Select only a specified set of signals."""
11611159
new_schema = self.signals_schema.resolve(*args)
1162-
if self._sys and _sys:
1163-
new_schema = SignalSchema({"sys": Sys}) | new_schema
11641160
columns = new_schema.db_signals()
11651161
return self._evolve(
11661162
query=self._query.select(*columns), signal_schema=new_schema
@@ -1710,9 +1706,11 @@ def _resolve(
17101706

17111707
signals_schema = self.signals_schema.clone_without_sys_signals()
17121708
right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
1713-
ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
1714-
right_signals_schema, rname
1715-
)
1709+
1710+
ds.signals_schema = signals_schema.merge(right_signals_schema, rname)
1711+
1712+
if not full:
1713+
ds.signals_schema = SignalSchema({"sys": Sys}) | ds.signals_schema
17161714

17171715
return ds
17181716

@@ -1723,6 +1721,7 @@ def union(self, other: "Self") -> "Self":
17231721
Parameters:
17241722
other: chain whose rows will be added to `self`.
17251723
"""
1724+
self.signals_schema = self.signals_schema.clone_without_sys_signals()
17261725
return self._evolve(query=self._query.union(other._query))
17271726

17281727
def subtract( # type: ignore[override]

src/datachain/query/dataset.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,6 @@ def create_result_query(
438438
"""
439439

440440
def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
441-
if "sys__id" not in query.selected_columns:
442-
raise RuntimeError("Query must have sys__id column to run UDF")
443-
444441
if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
445442
return
446443

@@ -634,12 +631,11 @@ def apply(
634631

635632
# Apply partitioning if needed.
636633
if self.partition_by is not None:
637-
if "sys__id" not in query.selected_columns:
638-
_query = query = self.catalog.warehouse._regenerate_system_columns(
639-
query,
640-
keep_existing_columns=True,
641-
)
642-
634+
_query = query = self.catalog.warehouse._regenerate_system_columns(
635+
query_generator.select(),
636+
keep_existing_columns=True,
637+
regenerate_columns=["sys__id"],
638+
)
643639
partition_tbl = self.create_partitions_table(query)
644640
temp_tables.append(partition_tbl.name)
645641
query = query.outerjoin(
@@ -960,28 +956,23 @@ def apply(
960956
q2 = self.query2.apply_steps().select().subquery()
961957
temp_tables.extend(self.query2.temp_table_names)
962958

963-
columns1, columns2 = _order_columns(q1.columns, q2.columns)
964-
965-
union_select = sqlalchemy.select(*columns1).union_all(
966-
sqlalchemy.select(*columns2)
967-
)
968-
union_cte = union_select.cte()
969-
regenerated = self.query1.catalog.warehouse._regenerate_system_columns(
970-
union_cte
971-
)
972-
result_columns = tuple(regenerated.selected_columns)
959+
columns1 = _drop_system_columns(q1.columns)
960+
columns2 = _drop_system_columns(q2.columns)
961+
columns1, columns2 = _order_columns(columns1, columns2)
973962

974963
def q(*columns):
975-
if not columns:
976-
return regenerated
964+
selected_names = [c.name for c in columns]
965+
col1 = [c for c in columns1 if c.name in selected_names]
966+
col2 = [c for c in columns2 if c.name in selected_names]
967+
union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))
977968

978-
names = {c.name for c in columns}
979-
selected = [c for c in result_columns if c.name in names]
980-
return regenerated.with_only_columns(*selected)
969+
union_cte = union_query.cte()
970+
select_cols = [union_cte.c[name] for name in selected_names]
971+
return sqlalchemy.select(*select_cols)
981972

982973
return step_result(
983974
q,
984-
result_columns,
975+
columns1,
985976
dependencies=self.query1.dependencies | self.query2.dependencies,
986977
)
987978

@@ -1070,7 +1061,7 @@ def apply(
10701061
q1 = self.get_query(self.query1, temp_tables)
10711062
q2 = self.get_query(self.query2, temp_tables)
10721063

1073-
q1_columns = list(q1.c)
1064+
q1_columns = _drop_system_columns(q1.c) if self.full else list(q1.c)
10741065
q1_column_names = {c.name for c in q1_columns}
10751066

10761067
q2_columns = []
@@ -1211,6 +1202,10 @@ def _order_columns(
12111202
return [[d[n] for n in column_order] for d in column_dicts]
12121203

12131204

1205+
def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
1206+
return [c for c in columns if not c.name.startswith("sys__")]
1207+
1208+
12141209
@attrs.define
12151210
class ResultIter:
12161211
_row_iter: Iterable[Any]

tests/func/test_datachain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1629,7 +1629,7 @@ def test_read_pandas_multiindex(test_session):
16291629

16301630
# Check the resulting column names and data
16311631
expected_columns = ["a_cat", "b_dog", "b_cat", "a_dog"]
1632-
assert set(chain.signals_schema.db_signals()) == set(expected_columns)
1632+
assert set(chain.schema.keys()) == set(expected_columns)
16331633

16341634
expected_data = [
16351635
{"a_cat": 1, "b_dog": 2, "b_cat": 3, "a_dog": 4},

0 commit comments

Comments (0)