|
5 | 5 | import string |
6 | 6 | from abc import ABC, abstractmethod |
7 | 7 | from collections.abc import Callable, Generator, Iterable, Iterator, Sequence |
8 | | -from typing import TYPE_CHECKING, Any, Union |
| 8 | +from typing import TYPE_CHECKING, Any, Union, cast |
9 | 9 | from urllib.parse import urlparse |
10 | 10 |
|
11 | 11 | import attrs |
|
23 | 23 | from datachain.query.batch import RowsOutput |
24 | 24 | from datachain.query.schema import ColumnMeta |
25 | 25 | from datachain.sql.functions import path as pathfunc |
26 | | -from datachain.sql.types import Int, SQLType |
| 26 | +from datachain.sql.types import SQLType |
27 | 27 | from datachain.utils import sql_escape_like |
28 | 28 |
|
29 | 29 | if TYPE_CHECKING: |
|
32 | 32 | _FromClauseArgument, |
33 | 33 | _OnClauseArgument, |
34 | 34 | ) |
| 35 | + from sqlalchemy.sql.selectable import FromClause |
35 | 36 | from sqlalchemy.types import TypeEngine |
36 | 37 |
|
37 | 38 | from datachain.data_storage import schema |
@@ -248,45 +249,56 @@ def dataset_select_paginated( |
248 | 249 |
|
def _regenerate_system_columns(
    self,
    selectable: sa.Select,
    keep_existing_columns: bool = False,
    regenerate_columns: Iterable[str] | None = None,
) -> sa.Select:
    """
    Return a SELECT that regenerates system columns deterministically.

    If keep_existing_columns is True, existing system columns will be kept as-is
    even when they are listed in ``regenerate_columns``. System columns that are
    listed but absent from ``selectable`` are always generated and appended.

    Args:
        selectable: Base SELECT
        keep_existing_columns: When True, reuse existing system columns even if
            they are part of the regeneration set.
        regenerate_columns: Names of system columns to regenerate. ``None``
            (the default) means every column reported by
            ``dataset_row_cls.sys_columns()``. An explicitly empty iterable
            means "regenerate nothing". Columns not listed are left untouched.

    Returns:
        ``selectable`` itself when no column had to be (re)generated, otherwise
        a new SELECT over a subquery holding the original columns in their
        original order, with regenerated columns replaced in place and missing
        ones appended in sorted-name order.

    Raises:
        KeyError: If a column that must be built has no generator or no
            declared system type.
    """
    system_columns = {
        sys_col.name: sys_col.type
        for sys_col in self.schema.dataset_row_cls.sys_columns()
    }
    # Test against None rather than truthiness: an empty iterable is an explicit
    # request to regenerate nothing, not a request for the default set.
    regenerate = set(
        system_columns if regenerate_columns is None else regenerate_columns
    )
    generators = {
        "sys__id": self._system_row_number_expr,
        "sys__rand": self._system_random_expr,
    }

    base = cast("FromClause", selectable.subquery())

    def build(name: str) -> sa.ColumnElement:
        # Cast to the declared system column type so the generated expression
        # matches the schema, and label it with the system column name.
        expr = generators[name]()
        return sa.cast(expr, system_columns[name]).label(name)

    columns: list[sa.ColumnElement] = []
    present: set[str] = set()
    changed = False

    for col in base.c:
        present.add(col.name)
        regen = col.name in regenerate and not keep_existing_columns
        columns.append(build(col.name) if regen else col)
        changed |= regen

    # Iterate in sorted order: raw set iteration order for strings depends on
    # hash randomisation, which would make the position of appended columns
    # differ between processes.
    for name in sorted(regenerate - present):
        columns.append(build(name))
        changed = True

    if not changed:
        return selectable

    # Wrap in a subquery to materialize window functions, then wrap again in a
    # SELECT so they are computed before any INSERT ... FROM SELECT.
    inner = sa.select(*columns).select_from(base).subquery()
    return sa.select(*inner.c).select_from(inner)
292 | 304 |
|
@@ -950,10 +962,15 @@ def create_udf_table( |
950 | 962 | SQLite TEMPORARY tables cannot be directly used as they are process-specific, |
951 | 963 | and UDFs are run in other processes when run in parallel. |
952 | 964 | """ |
| 965 | + columns = [ |
| 966 | + c |
| 967 | + for c in columns |
| 968 | + if c.name not in [col.name for col in self.dataset_row_cls.sys_columns()] |
| 969 | + ] |
953 | 970 | tbl = sa.Table( |
954 | 971 | name or self.udf_table_name(), |
955 | 972 | sa.MetaData(), |
956 | | - sa.Column("sys__id", Int, primary_key=True), |
| 973 | + *self.dataset_row_cls.sys_columns(), |
957 | 974 | *columns, |
958 | 975 | ) |
959 | 976 | self.db.create_table(tbl, if_not_exists=True) |
|
0 commit comments