Skip to content

Commit 9e7340b

Browse files
Commit message: better fix for sys_id in delta, add tests
1 parent: 7258b2e · commit: 9e7340b

File tree

4 files changed: 64 additions and 47 deletions.

src/datachain/delta.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
1-
import hashlib
21
from collections.abc import Sequence
32
from copy import copy
43
from functools import wraps
54
from typing import TYPE_CHECKING, TypeVar
65

7-
from attrs import frozen
8-
96
import datachain
107
from datachain.dataset import DatasetDependency, DatasetRecord
118
from datachain.error import DatasetNotFoundError
129
from datachain.project import Project
13-
from datachain.query.dataset import Step, step_result
1410

1511
if TYPE_CHECKING:
1612
from collections.abc import Callable
1713
from typing import Concatenate
1814

1915
from typing_extensions import ParamSpec
2016

21-
from datachain.catalog import Catalog
2217
from datachain.lib.dc import DataChain
23-
from datachain.query.dataset import QueryGenerator
2418

2519
P = ParamSpec("P")
2620

@@ -49,38 +43,11 @@ def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
4943
return _inner
5044

5145

52-
@frozen
53-
class _RegenerateSystemColumnsStep(Step):
54-
catalog: "Catalog"
55-
56-
def hash_inputs(self) -> str:
57-
return hashlib.sha256(b"regenerate_sys_columns").hexdigest()
58-
59-
def apply(self, query_generator: "QueryGenerator", temp_tables: list[str]):
60-
selectable = query_generator.select()
61-
regenerated = self.catalog.warehouse._regenerate_system_columns(
62-
selectable,
63-
keep_existing_columns=True,
64-
regenerate_columns=None,
65-
)
66-
67-
def q(*columns):
68-
return regenerated.with_only_columns(*columns)
69-
70-
return step_result(q, regenerated.selected_columns)
71-
72-
7346
def _append_steps(dc: "DataChain", other: "DataChain"):
7447
"""Returns cloned chain with appended steps from other chain.
7548
Steps are all those modification methods applied like filters, mappers etc.
7649
"""
7750
dc = dc.clone()
78-
dc._query.steps.append(
79-
_RegenerateSystemColumnsStep(
80-
catalog=dc.session.catalog,
81-
)
82-
)
83-
8451
dc._query.steps += other._query.steps.copy()
8552
dc.signals_schema = other.signals_schema
8653
return dc

src/datachain/lib/dc/datasets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def read_dataset(
200200
signals_schema |= SignalSchema.deserialize(query.feature_schema)
201201
else:
202202
signals_schema |= SignalSchema.from_column_types(query.column_types or {})
203+
204+
if delta:
205+
signals_schema = signals_schema.clone_without_sys_signals()
206+
203207
chain = DataChain(query, _settings, signals_schema)
204208

205209
if delta:

src/datachain/lib/dc/storage.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,12 @@ def read_storage(
187187
project=listing_project_name,
188188
session=session,
189189
settings=settings,
190+
delta=delta,
191+
delta_on=delta_on,
192+
delta_result_on=delta_result_on,
193+
delta_compare=delta_compare,
194+
delta_retry=delta_retry,
195+
delta_unsafe=delta_unsafe,
190196
)
191197
dc._query.update = update
192198
dc.signals_schema = dc.signals_schema.mutate({f"{column}": file_type})
@@ -252,13 +258,4 @@ def lst_fn(ds_name, lst_uri):
252258

253259
assert storage_chain is not None
254260

255-
if delta:
256-
storage_chain = storage_chain._as_delta(
257-
on=delta_on,
258-
right_on=delta_result_on,
259-
compare=delta_compare,
260-
delta_retry=delta_retry,
261-
delta_unsafe=delta_unsafe,
262-
)
263-
264261
return storage_chain

tests/func/test_delta.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,66 @@ def build_chain(delta: bool):
314314

315315
build_chain(delta=False).save(result_name)
316316

317-
build_chain(delta=True).save(
318-
result_name,
319-
delta=True,
320-
delta_on="measurement_id",
321-
)
317+
build_chain(delta=True).save(result_name)
322318

323319
assert set(
324320
dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
325321
) == {1, 2}
326322

327323

324+
def test_storage_delta_replay_regenerates_system_columns(test_session, tmp_dir):
325+
data_dir = tmp_dir / f"regen_storage_{uuid.uuid4().hex[:8]}"
326+
data_dir.mkdir()
327+
storage_uri = data_dir.as_uri()
328+
result_name = f"regen_storage_result_{uuid.uuid4().hex[:8]}"
329+
330+
def write_payload(index: int) -> None:
331+
(data_dir / f"item{index}.txt").write_text(f"payload-{index}")
332+
333+
write_payload(1)
334+
write_payload(2)
335+
336+
def build_chain(delta: bool):
337+
read_kwargs = {"session": test_session, "update": True}
338+
if delta:
339+
read_kwargs |= {
340+
"delta": True,
341+
"delta_on": ["file.path"],
342+
"delta_result_on": ["file.path"],
343+
}
344+
345+
def get_measurement_id(file: File) -> int:
346+
match = re.search(r"item(\d+)\.txt$", file.path)
347+
assert match
348+
return int(match.group(1))
349+
350+
def get_num(file: File) -> int:
351+
return get_measurement_id(file)
352+
353+
chain = dc.read_storage(storage_uri, **read_kwargs)
354+
return (
355+
chain.mutate(num=1)
356+
.select_except("num")
357+
.map(measurement_id=get_measurement_id)
358+
.map(err=lambda file: "")
359+
.map(num=get_num)
360+
.filter(C.err == "")
361+
.select_except("err")
362+
.map(double=lambda num: num * 2, output=int)
363+
.select_except("num")
364+
)
365+
366+
build_chain(delta=False).save(result_name)
367+
368+
write_payload(3)
369+
370+
build_chain(delta=True).save(result_name)
371+
372+
assert set(
373+
dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
374+
) == {1, 2, 3}
375+
376+
328377
def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
329378
ds_name = "delta_ds"
330379
path = tmp_dir.as_uri()

0 commit comments

Comments
 (0)