Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions src/datachain/delta.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
import hashlib
from collections.abc import Sequence
from copy import copy
from functools import wraps
from typing import TYPE_CHECKING, TypeVar

from attrs import frozen

import datachain
from datachain.dataset import DatasetDependency, DatasetRecord
from datachain.error import DatasetNotFoundError
from datachain.project import Project
from datachain.query.dataset import Step, step_result

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Concatenate

from typing_extensions import ParamSpec

from datachain.catalog import Catalog
from datachain.lib.dc import DataChain
from datachain.query.dataset import QueryGenerator

P = ParamSpec("P")

Expand Down Expand Up @@ -49,38 +43,11 @@ def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
return _inner


@frozen
class _RegenerateSystemColumnsStep(Step):
catalog: "Catalog"

def hash_inputs(self) -> str:
return hashlib.sha256(b"regenerate_sys_columns").hexdigest()

def apply(self, query_generator: "QueryGenerator", temp_tables: list[str]):
selectable = query_generator.select()
regenerated = self.catalog.warehouse._regenerate_system_columns(
selectable,
keep_existing_columns=True,
regenerate_columns=None,
)

def q(*columns):
return regenerated.with_only_columns(*columns)

return step_result(q, regenerated.selected_columns)


def _append_steps(dc: "DataChain", other: "DataChain"):
"""Returns cloned chain with appended steps from other chain.
Steps are all those modification methods applied like filters, mappers etc.
"""
dc = dc.clone()
dc._query.steps.append(
_RegenerateSystemColumnsStep(
catalog=dc.session.catalog,
)
)

dc._query.steps += other._query.steps.copy()
dc.signals_schema = other.signals_schema
return dc
Expand Down
4 changes: 4 additions & 0 deletions src/datachain/lib/dc/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def read_dataset(
signals_schema |= SignalSchema.deserialize(query.feature_schema)
else:
signals_schema |= SignalSchema.from_column_types(query.column_types or {})

if delta:
signals_schema = signals_schema.clone_without_sys_signals()

chain = DataChain(query, _settings, signals_schema)

if delta:
Expand Down
15 changes: 6 additions & 9 deletions src/datachain/lib/dc/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ def read_storage(
project=listing_project_name,
session=session,
settings=settings,
delta=delta,
delta_on=delta_on,
delta_result_on=delta_result_on,
delta_compare=delta_compare,
delta_retry=delta_retry,
delta_unsafe=delta_unsafe,
)
dc._query.update = update
dc.signals_schema = dc.signals_schema.mutate({f"{column}": file_type})
Expand Down Expand Up @@ -252,13 +258,4 @@ def lst_fn(ds_name, lst_uri):

assert storage_chain is not None

if delta:
storage_chain = storage_chain._as_delta(
on=delta_on,
right_on=delta_result_on,
compare=delta_compare,
delta_retry=delta_retry,
delta_unsafe=delta_unsafe,
)

return storage_chain
59 changes: 54 additions & 5 deletions tests/func/test_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,17 +314,66 @@ def build_chain(delta: bool):

build_chain(delta=False).save(result_name)

build_chain(delta=True).save(
result_name,
delta=True,
delta_on="measurement_id",
)
build_chain(delta=True).save(result_name)

assert set(
dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
) == {1, 2}


def test_storage_delta_replay_regenerates_system_columns(test_session, tmp_dir):
data_dir = tmp_dir / f"regen_storage_{uuid.uuid4().hex[:8]}"
data_dir.mkdir()
storage_uri = data_dir.as_uri()
result_name = f"regen_storage_result_{uuid.uuid4().hex[:8]}"

def write_payload(index: int) -> None:
(data_dir / f"item{index}.txt").write_text(f"payload-{index}")

write_payload(1)
write_payload(2)

def build_chain(delta: bool):
read_kwargs = {"session": test_session, "update": True}
if delta:
read_kwargs |= {
"delta": True,
"delta_on": ["file.path"],
"delta_result_on": ["file.path"],
}

def get_measurement_id(file: File) -> int:
match = re.search(r"item(\d+)\.txt$", file.path)
assert match
return int(match.group(1))

def get_num(file: File) -> int:
return get_measurement_id(file)

chain = dc.read_storage(storage_uri, **read_kwargs)
return (
chain.mutate(num=1)
.select_except("num")
.map(measurement_id=get_measurement_id)
.map(err=lambda file: "")
.map(num=get_num)
.filter(C.err == "")
.select_except("err")
.map(double=lambda num: num * 2, output=int)
.select_except("num")
)

build_chain(delta=False).save(result_name)

write_payload(3)

build_chain(delta=True).save(result_name)

assert set(
dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
) == {1, 2, 3}


def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
ds_name = "delta_ds"
path = tmp_dir.as_uri()
Expand Down