
Commit a72c865

Simplify working with query columns (#1382)

* Get rid of custom columns utils and use sqlalchemy methods.
* Fill sys__id on partitioning instead of creating temp table.
* Tests refactoring + more tests added.

Co-authored-by: Vladimir Rudnyh <[email protected]>

1 parent: d5d89c0
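The central idea of the change: SQLAlchemy's `Select.selected_columns` is a `ColumnCollection` that already supports membership tests, `.get()` lookups, and attribute access, so the custom helpers in `datachain/query/utils.py` are unnecessary. A minimal sketch of those three idioms (table and column names are illustrative, not from the diff):

```python
import sqlalchemy as sa

# Illustrative table; only sys__id matters for the pattern below.
tbl = sa.table("rows", sa.column("sys__id", sa.Integer), sa.column("name", sa.String))
query = sa.select(tbl.c.sys__id, tbl.c.name)

# Membership test, .get() lookup, and attribute access -- the sqlalchemy
# idioms this commit uses in place of the removed helpers.
assert "sys__id" in query.selected_columns
id_col = query.selected_columns.get("sys__id")  # None when the column is absent
assert id_col.name == "sys__id"
assert query.selected_columns.sys__id.name == "sys__id"
```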

File tree: 10 files changed, +485 −367 lines changed

src/datachain/data_storage/warehouse.py

Lines changed: 34 additions & 18 deletions
```diff
@@ -22,7 +22,6 @@
 from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
 from datachain.query.batch import RowsOutput
 from datachain.query.schema import ColumnMeta
-from datachain.query.utils import get_query_id_column
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType
 from datachain.utils import sql_escape_like
```
```diff
@@ -228,7 +227,8 @@ def dataset_select_paginated(
         while True:
             if limit is not None:
                 limit -= num_yielded
-                if limit == 0:
+                num_yielded = 0
+                if limit <= 0:
                     break
                 if limit < page_size:
                     paginated_query = paginated_query.limit(None).limit(limit)
```
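Why this two-line change matters: the old loop never reset `num_yielded`, so `limit` was decremented by a stale count on every pass, and `limit == 0` could be skipped entirely when a page overshot the remaining limit. A self-contained sketch of the corrected bookkeeping (`fetch_page` is a stand-in for executing the paginated query):

```python
def paginate(fetch_page, page_size: int, limit: int | None = None):
    """Yield rows page by page, honoring an optional overall limit."""
    offset = 0
    num_yielded = 0
    while True:
        if limit is not None:
            limit -= num_yielded
            num_yielded = 0          # reset so the next pass doesn't double-count
            if limit <= 0:           # <= rather than ==: a page may overshoot
                break
            page_size = min(page_size, limit)  # mirrors .limit(None).limit(limit)
        rows = fetch_page(offset, page_size)
        if not rows:
            break                    # no more results
        yield from rows
        num_yielded = len(rows)
        offset += page_size

# 7 available rows, page size 3, overall limit 5 -> exactly 5 rows.
data = list(range(7))

def pages(off, n):
    return data[off : off + n]

assert list(paginate(pages, page_size=3, limit=5)) == [0, 1, 2, 3, 4]
```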
```diff
@@ -246,32 +246,48 @@ def dataset_select_paginated(
                 break  # no more results
             offset += page_size
 
-    def _regenerate_system_columns(self, selectable):
-        """Return a SELECT that regenerates sys__id and sys__rand deterministically."""
+    def _regenerate_system_columns(
+        self,
+        selectable: sa.Select | sa.CTE,
+        keep_existing_columns: bool = False,
+    ) -> sa.Select:
+        """
+        Return a SELECT that regenerates sys__id and sys__rand deterministically.
 
+        If keep_existing_columns is True, existing sys__id and sys__rand columns
+        will be kept as-is if they exist in the input selectable.
+        """
         base = selectable.subquery() if hasattr(selectable, "subquery") else selectable
 
+        result_columns: dict[str, sa.ColumnElement] = {}
+        for col in base.c:
+            if col.name in result_columns:
+                raise ValueError(f"Duplicate column name {col.name} in SELECT")
+            if col.name in ("sys__id", "sys__rand"):
+                if keep_existing_columns:
+                    result_columns[col.name] = col
+            else:
+                result_columns[col.name] = col
+
         system_types: dict[str, sa.types.TypeEngine] = {
             sys_col.name: sys_col.type
             for sys_col in self.schema.dataset_row_cls.sys_columns()
         }
 
-        result_columns = []
-        for col in base.c:
-            if col.name == "sys__id":
-                expr = self._system_row_number_expr()
-                expr = sa.cast(expr, system_types["sys__id"])
-                result_columns.append(expr.label("sys__id"))
-            elif col.name == "sys__rand":
-                expr = self._system_random_expr()
-                expr = sa.cast(expr, system_types["sys__rand"])
-                result_columns.append(expr.label("sys__rand"))
-            else:
-                result_columns.append(col)
+        # Add missing system columns if needed
+        if "sys__id" not in result_columns:
+            expr = self._system_row_number_expr()
+            expr = sa.cast(expr, system_types["sys__id"])
+            result_columns["sys__id"] = expr.label("sys__id")
+        if "sys__rand" not in result_columns:
+            expr = self._system_random_expr()
+            expr = sa.cast(expr, system_types["sys__rand"])
+            result_columns["sys__rand"] = expr.label("sys__rand")
 
         # Wrap in subquery to materialize window functions, then wrap again in SELECT
         # This ensures window functions are computed before INSERT...FROM SELECT
-        inner = sa.select(*result_columns).select_from(base).subquery()
+        columns = list(result_columns.values())
+        inner = sa.select(*columns).select_from(base).subquery()
         return sa.select(*inner.c).select_from(inner)
 
     def _system_row_number_expr(self):
```
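For reference, the kind of expressions the two helpers above return, sketched with generic SQLAlchemy constructs (the real implementations are backend-specific; the exact functions and types here are assumptions):

```python
import sqlalchemy as sa

# sys__id: a deterministic row number from a window function.
row_number_expr = sa.func.row_number().over()

# sys__rand: a non-negative value from the backend's RNG (SQLite-style).
random_expr = sa.func.abs(sa.func.random())

stmt = sa.select(
    sa.cast(row_number_expr, sa.BigInteger).label("sys__id"),
    sa.cast(random_expr, sa.BigInteger).label("sys__rand"),
)
```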
```diff
@@ -380,7 +396,7 @@ def dataset_rows_select_from_ids(
         """
         Fetch dataset rows from database using a list of IDs.
         """
-        if (id_col := get_query_id_column(query)) is None:
+        if (id_col := query.selected_columns.get("sys__id")) is None:
             raise RuntimeError("sys__id column not found in query")
 
         query = query._clone().offset(None).limit(None).order_by(None)
```

src/datachain/query/batch.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -6,7 +6,6 @@
 import sqlalchemy as sa
 
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.query.utils import get_query_column
 
 RowsOutputBatch = Sequence[Sequence]
 RowsOutput = Sequence | RowsOutputBatch
@@ -106,7 +105,7 @@ def __call__(
         query: sa.Select,
         id_col: sa.ColumnElement | None = None,
     ) -> Generator[RowsOutput, None, None]:
-        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+        if (partition_col := query.selected_columns.get(PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")
 
         ids_only = False
```
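For context, partitioned batching groups consecutive rows that share a partition id into a single batch, which is why the partition column must be present in the query. A rough sketch of the grouping step (not the class's verbatim logic), assuming rows arrive sorted by partition id:

```python
from itertools import groupby
from operator import itemgetter

def batch_by_partition(rows, partition_idx: int):
    """Group rows sharing a partition id; rows must be pre-sorted by it."""
    for _, group in groupby(rows, key=itemgetter(partition_idx)):
        yield list(group)

rows = [(1, "a"), (1, "b"), (2, "c")]  # (partition_id, payload) tuples
assert list(batch_by_partition(rows, 0)) == [[(1, "a"), (1, "b")], [(2, "c")]]
```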

src/datachain/query/dataset.py

Lines changed: 12 additions & 22 deletions
```diff
@@ -438,6 +438,9 @@ def create_result_query(
         """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+        if "sys__id" not in query.selected_columns:
+            raise RuntimeError("Query must have sys__id column to run UDF")
+
         if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
             return
 
@@ -580,13 +583,10 @@ def create_partitions_table(self, query: Select) -> "Table":
         """
         Create temporary table with group by partitions.
         """
-        # Check if partition_by is set, we need it to create partitions.
-        assert self.partition_by is not None
-        # Check if sys__id is in the query, we need it to be able to join
-        # the partition table with the udf table later.
-        assert any(c.name == "sys__id" for c in query.selected_columns), (
-            "Query must have sys__id column to use partitioning."
-        )
+        if self.partition_by is None:
+            raise RuntimeError("Query must have partition_by set to use partitioning")
+        if (id_col := query.selected_columns.get("sys__id")) is None:
+            raise RuntimeError("Query must have sys__id column to use partitioning")
 
         if isinstance(self.partition_by, (list, tuple, GeneratorType)):
             list_partition_by = list(self.partition_by)
@@ -602,7 +602,7 @@ def create_partitions_table(self, query: Select) -> "Table":
 
         # fill table with partitions
         cols = [
-            query.selected_columns.sys__id,
+            id_col,
             f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID),
         ]
         self.catalog.warehouse.db.execute(
```
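The `dense_rank()` window call above is what assigns each distinct partition a compact number; a standalone sketch of the same construction (table and column names are illustrative):

```python
import sqlalchemy as sa
from sqlalchemy import func as f

rows = sa.table("rows", sa.column("sys__id"), sa.column("category"))
partition_by = [rows.c.category]

# One row per source row: its sys__id plus a dense partition number,
# ready to feed an INSERT into the temporary partitions table, which is
# later joined back to the UDF input on sys__id.
select_partitions = sa.select(
    rows.c.sys__id,
    f.dense_rank().over(order_by=partition_by).label("partition_id"),
)
```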
```diff
@@ -634,21 +634,11 @@ def apply(
 
         # Apply partitioning if needed.
         if self.partition_by is not None:
-            if not any(c.name == "sys__id" for c in query.selected_columns):
-                # If sys__id is not in the query, we need to create a temp table
-                # to hold the query results, so we can join it with the
-                # partition table later.
-                columns = [
-                    c if isinstance(c, Column) else Column(c.name, c.type)
-                    for c in query.subquery().columns
-                ]
-                temp_table = self.catalog.warehouse.create_dataset_rows_table(
-                    self.catalog.warehouse.temp_table_name(),
-                    columns=columns,
+            if "sys__id" not in query.selected_columns:
+                _query = query = self.catalog.warehouse._regenerate_system_columns(
+                    query,
+                    keep_existing_columns=True,
                 )
-                temp_tables.append(temp_table.name)
-                self.catalog.warehouse.copy_table(temp_table, query)
-                _query = query = temp_table.select()
 
             partition_tbl = self.create_partitions_table(query)
             temp_tables.append(partition_tbl.name)
```

src/datachain/query/dispatch.py

Lines changed: 25 additions & 35 deletions
```diff
@@ -22,7 +22,6 @@
 )
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
-from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing
 
 if TYPE_CHECKING:
@@ -55,6 +54,9 @@ def udf_entrypoint() -> int:
     udf_info: UdfInfo = load(stdin.buffer)
 
     query = udf_info["query"]
+    if "sys__id" not in query.selected_columns:
+        raise RuntimeError("sys__id column is required in UDF query")
+
     batching = udf_info["batching"]
     is_generator = udf_info["is_generator"]
 
@@ -65,15 +67,16 @@ def udf_entrypoint() -> int:
     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
 
-    id_col = get_query_id_column(query)
-
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query, id_col=id_col)
+        batching(
+            warehouse.dataset_select_paginated,
+            query,
+            id_col=query.selected_columns.sys__id,
+        )
     ) as udf_inputs:
         try:
             UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-                ids_only=id_col is not None,
                 download_cb=download_cb,
                 processed_cb=processed_cb,
                 generated_cb=generated_cb,
@@ -147,10 +150,10 @@ def _create_worker(self) -> "UDFWorker":
             self.udf_fields,
         )
 
-    def _run_worker(self, ids_only: bool) -> None:
+    def _run_worker(self) -> None:
         try:
             worker = self._create_worker()
-            worker.run(ids_only)
+            worker.run()
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
                 put_into_queue(
@@ -164,7 +167,6 @@ def _run_worker(self, ids_only: bool) -> None:
     def run_udf(
         self,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -178,9 +180,7 @@ def run_udf(
 
         if n_workers == 1:
             # no need to spawn worker processes if we are running in a single process
-            self.run_udf_single(
-                input_rows, ids_only, download_cb, processed_cb, generated_cb
-            )
+            self.run_udf_single(input_rows, download_cb, processed_cb, generated_cb)
         else:
             if self.buffer_size < n_workers:
                 raise RuntimeError(
@@ -189,13 +189,12 @@ def run_udf(
             )
 
         self.run_udf_parallel(
-            n_workers, input_rows, ids_only, download_cb, processed_cb, generated_cb
+            n_workers, input_rows, download_cb, processed_cb, generated_cb
         )
 
     def run_udf_single(
         self,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -204,18 +203,15 @@ def run_udf_single(
         # Rebuild schemas in single process too for consistency (cheap, idempotent).
         ModelStore.rebuild_all()
 
-        if ids_only and not self.is_batching:
+        if not self.is_batching:
             input_rows = flatten(input_rows)
 
         def get_inputs() -> Iterable["RowsOutput"]:
             warehouse = self.catalog.warehouse.clone()
-            if ids_only:
-                for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
-                    yield from warehouse.dataset_rows_select_from_ids(
-                        self.query, ids, self.is_batching
-                    )
-            else:
-                yield from input_rows
+            for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )
 
         prefetch = udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
@@ -249,7 +245,6 @@ def run_udf_parallel(  # noqa: C901, PLR0912
         self,
         n_workers: int,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -258,9 +253,7 @@ def run_udf_parallel(  # noqa: C901, PLR0912
         self.done_queue = self.ctx.Queue()
 
         pool = [
-            self.ctx.Process(
-                name=f"Worker-UDF-{i}", target=self._run_worker, args=[ids_only]
-            )
+            self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
             for i in range(n_workers)
         ]
         for p in pool:
@@ -406,13 +399,13 @@ def __init__(
         self.processed_cb = ProcessedCallback("processed", self.done_queue)
         self.generated_cb = ProcessedCallback("generated", self.done_queue)
 
-    def run(self, ids_only: bool) -> None:
+    def run(self) -> None:
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
             udf_results = self.udf.run(
                 self.udf_fields,
-                self.get_inputs(ids_only),
+                self.get_inputs(),
                 catalog,
                 self.cache,
                 download_cb=self.download_cb,
@@ -434,13 +427,10 @@ def notify_and_process(self, udf_results):
             put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row
 
-    def get_inputs(self, ids_only: bool) -> Iterable["RowsOutput"]:
+    def get_inputs(self) -> Iterable["RowsOutput"]:
         warehouse = self.catalog.warehouse.clone()
         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            if ids_only:
-                for ids in batched(batch, DEFAULT_BATCH_SIZE):
-                    yield from warehouse.dataset_rows_select_from_ids(
-                        self.query, ids, self.is_batching
-                    )
-            else:
-                yield from batch
+            for ids in batched(batch, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )
```
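With `ids_only` gone, every worker batch is unconditionally a batch of `sys__id` values that gets re-fetched in fixed-size chunks. A sketch of that flow, with `batched` re-implemented as a stand-in for `datachain.utils.batched` and an assumed chunk size:

```python
from collections.abc import Iterable, Iterator
from itertools import islice

DEFAULT_BATCH_SIZE = 2000  # assumed value, for illustration only

def batched(it: Iterable, n: int) -> Iterator[tuple]:
    """Stand-in for datachain.utils.batched: fixed-size chunks."""
    it = iter(it)
    while chunk := tuple(islice(it, n)):
        yield chunk

def fetch_rows(select_from_ids, query, ids: Iterable[int], is_batching: bool):
    """Re-fetch full rows for each chunk of ids, as the workers now do."""
    for chunk in batched(ids, DEFAULT_BATCH_SIZE):
        yield from select_from_ids(query, chunk, is_batching)
```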

src/datachain/query/utils.py

Lines changed: 0 additions & 38 deletions
This file was deleted.
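The call sites above imply what the deleted helpers did; a rough reconstruction (not the verbatim deleted code) makes clear why `ColumnCollection` subsumes them:

```python
import sqlalchemy as sa

def get_query_column(query: sa.Select, name: str) -> sa.ColumnElement | None:
    # A linear scan over selected columns -- exactly what
    # query.selected_columns.get(name) already provides.
    return next((c for c in query.selected_columns if c.name == name), None)

def get_query_id_column(query: sa.Select) -> sa.ColumnElement | None:
    return get_query_column(query, "sys__id")
```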
