diff --git a/src/datachain/delta.py b/src/datachain/delta.py
index 125abd4fc..9ab20dd18 100644
--- a/src/datachain/delta.py
+++ b/src/datachain/delta.py
@@ -1,16 +1,12 @@
-import hashlib
 from collections.abc import Sequence
 from copy import copy
 from functools import wraps
 from typing import TYPE_CHECKING, TypeVar
 
-from attrs import frozen
-
 import datachain
 from datachain.dataset import DatasetDependency, DatasetRecord
 from datachain.error import DatasetNotFoundError
 from datachain.project import Project
-from datachain.query.dataset import Step, step_result
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -18,9 +14,7 @@
 
     from typing_extensions import ParamSpec
 
-    from datachain.catalog import Catalog
     from datachain.lib.dc import DataChain
-    from datachain.query.dataset import QueryGenerator
 
     P = ParamSpec("P")
 
@@ -49,38 +43,11 @@ def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
     return _inner
 
 
-@frozen
-class _RegenerateSystemColumnsStep(Step):
-    catalog: "Catalog"
-
-    def hash_inputs(self) -> str:
-        return hashlib.sha256(b"regenerate_sys_columns").hexdigest()
-
-    def apply(self, query_generator: "QueryGenerator", temp_tables: list[str]):
-        selectable = query_generator.select()
-        regenerated = self.catalog.warehouse._regenerate_system_columns(
-            selectable,
-            keep_existing_columns=True,
-            regenerate_columns=None,
-        )
-
-        def q(*columns):
-            return regenerated.with_only_columns(*columns)
-
-        return step_result(q, regenerated.selected_columns)
-
-
 def _append_steps(dc: "DataChain", other: "DataChain"):
     """Returns cloned chain with appended steps from other chain.
     Steps are all those modification methods applied like filters, mappers etc.
     """
     dc = dc.clone()
-    dc._query.steps.append(
-        _RegenerateSystemColumnsStep(
-            catalog=dc.session.catalog,
-        )
-    )
-
     dc._query.steps += other._query.steps.copy()
     dc.signals_schema = other.signals_schema
     return dc
diff --git a/src/datachain/lib/dc/datasets.py b/src/datachain/lib/dc/datasets.py
index 3094882c7..d138d8eb6 100644
--- a/src/datachain/lib/dc/datasets.py
+++ b/src/datachain/lib/dc/datasets.py
@@ -200,6 +200,10 @@ def read_dataset(
         signals_schema |= SignalSchema.deserialize(query.feature_schema)
     else:
         signals_schema |= SignalSchema.from_column_types(query.column_types or {})
+
+    if delta:
+        signals_schema = signals_schema.clone_without_sys_signals()
+
     chain = DataChain(query, _settings, signals_schema)
 
     if delta:
diff --git a/src/datachain/lib/dc/storage.py b/src/datachain/lib/dc/storage.py
index 1fa3f4e95..09af40e51 100644
--- a/src/datachain/lib/dc/storage.py
+++ b/src/datachain/lib/dc/storage.py
@@ -187,6 +187,12 @@ def read_storage(
             project=listing_project_name,
             session=session,
             settings=settings,
+            delta=delta,
+            delta_on=delta_on,
+            delta_result_on=delta_result_on,
+            delta_compare=delta_compare,
+            delta_retry=delta_retry,
+            delta_unsafe=delta_unsafe,
         )
         dc._query.update = update
         dc.signals_schema = dc.signals_schema.mutate({f"{column}": file_type})
@@ -252,13 +258,4 @@ def lst_fn(ds_name, lst_uri):
 
     assert storage_chain is not None
 
-    if delta:
-        storage_chain = storage_chain._as_delta(
-            on=delta_on,
-            right_on=delta_result_on,
-            compare=delta_compare,
-            delta_retry=delta_retry,
-            delta_unsafe=delta_unsafe,
-        )
-
     return storage_chain
diff --git a/tests/func/test_delta.py b/tests/func/test_delta.py
index 8a08745ce..5e1e341a8 100644
--- a/tests/func/test_delta.py
+++ b/tests/func/test_delta.py
@@ -314,17 +314,66 @@ def build_chain(delta: bool):
 
     build_chain(delta=False).save(result_name)
 
-    build_chain(delta=True).save(
-        result_name,
-        delta=True,
-        delta_on="measurement_id",
-    )
+    build_chain(delta=True).save(result_name)
 
     assert set(
         dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
     ) == {1, 2}
 
 
+def test_storage_delta_replay_regenerates_system_columns(test_session, tmp_dir):
+    data_dir = tmp_dir / f"regen_storage_{uuid.uuid4().hex[:8]}"
+    data_dir.mkdir()
+    storage_uri = data_dir.as_uri()
+    result_name = f"regen_storage_result_{uuid.uuid4().hex[:8]}"
+
+    def write_payload(index: int) -> None:
+        (data_dir / f"item{index}.txt").write_text(f"payload-{index}")
+
+    write_payload(1)
+    write_payload(2)
+
+    def build_chain(delta: bool):
+        read_kwargs = {"session": test_session, "update": True}
+        if delta:
+            read_kwargs |= {
+                "delta": True,
+                "delta_on": ["file.path"],
+                "delta_result_on": ["file.path"],
+            }
+
+        def get_measurement_id(file: File) -> int:
+            match = re.search(r"item(\d+)\.txt$", file.path)
+            assert match
+            return int(match.group(1))
+
+        def get_num(file: File) -> int:
+            return get_measurement_id(file)
+
+        chain = dc.read_storage(storage_uri, **read_kwargs)
+        return (
+            chain.mutate(num=1)
+            .select_except("num")
+            .map(measurement_id=get_measurement_id)
+            .map(err=lambda file: "")
+            .map(num=get_num)
+            .filter(C.err == "")
+            .select_except("err")
+            .map(double=lambda num: num * 2, output=int)
+            .select_except("num")
+        )
+
+    build_chain(delta=False).save(result_name)
+
+    write_payload(3)
+
+    build_chain(delta=True).save(result_name)
+
+    assert set(
+        dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
+    ) == {1, 2, 3}
+
+
 def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
     ds_name = "delta_ds"
     path = tmp_dir.as_uri()
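
For context, a minimal usage sketch of the behavior exercised above: after this change the delta keyword arguments are forwarded through read_storage itself rather than being applied afterwards via _as_delta, so transformations chained on top of the listing are replayed only for new or changed files. This sketch is illustrative only; the bucket URI, column name, and dataset name are placeholders not taken from this diff, and only the read_storage keyword arguments mirror the code above.

# Hypothetical example; "s3://example-bucket/data/" and "stats" are placeholders.
import datachain as dc

chain = dc.read_storage(
    "s3://example-bucket/data/",     # placeholder storage URI
    update=True,                     # refresh the listing before processing
    delta=True,                      # process only new or changed files on re-run
    delta_on=["file.path"],          # key matching rows against the source listing
    delta_result_on=["file.path"],   # key matching rows in the saved result dataset
)

# Chained steps are replayed only for the delta; re-saving produces a new
# dataset version that also carries over the previously computed rows.
chain.map(name=lambda file: file.path.rsplit("/", 1)[-1], output=str).save("stats")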