Skip to content

Commit 7b1297a

Browse files
authored
Fix for use case when generator skips input rows in checkpoints (#1609)
Fix for use case when generator skips input rows in checkpoints
1 parent afedefb commit 7b1297a

File tree

3 files changed

+54
-30
lines changed

3 files changed

+54
-30
lines changed

src/datachain/lib/udf.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -564,16 +564,23 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
564564

565565
def _process_row(row):
566566
row_id, *row = row
567+
has_output = False
567568
with safe_closing(self.process(*row)) as result_objs:
568569
for result_obj, is_last in with_last_flag(result_objs):
570+
has_output = True
569571
udf_output = self._flatten_row(result_obj)
570572
udf_output = dict(zip(self.signal_names, udf_output, strict=False))
571-
# Include sys__input_id to track which input generated this
572-
# output.
573-
udf_output["sys__input_id"] = row_id # input id
574-
# Mark as partial=True unless it's the last output
573+
udf_output["sys__input_id"] = row_id
575574
udf_output["sys__partial"] = not is_last
575+
udf_output["sys__empty"] = None
576576
yield udf_output
577+
if not has_output:
578+
# Marker: records that this input was processed but yielded nothing.
579+
yield {
580+
"sys__input_id": row_id,
581+
"sys__partial": False,
582+
"sys__empty": True,
583+
}
577584

578585
prepared_inputs = _prepare_rows(udf_inputs)
579586
prepared_inputs = _prefetch_inputs(
@@ -584,12 +591,6 @@ def _process_row(row):
584591
)
585592

586593
with closing(prepared_inputs):
587-
# TODO: Fix limitation where inputs yielding nothing are not tracked in
588-
# processed table. Currently, if process() yields nothing for an input,
589-
# that input's sys__id is never added to the processed table, causing it
590-
# to be re-processed on checkpoint recovery. Solution: yield a marker
591-
# row with sys__input_id when process() yields nothing, then filter
592-
# these marker rows before inserting to output table.
593594
for row in prepared_inputs:
594595
yield _process_row(row)
595596
processed_cb.relative_update(1)

src/datachain/query/dataset.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ def _checkpoint_tracking_columns(self) -> list["sqlalchemy.Column"]:
538538
return [
539539
sa.Column("sys__input_id", sa.Integer, nullable=True),
540540
sa.Column("sys__partial", sa.Boolean, nullable=True),
541+
sa.Column("sys__empty", sa.Boolean, nullable=True),
541542
]
542543

543544
def get_input_query(self, input_table_name: str, original_query: Select) -> Select:
@@ -1591,13 +1592,21 @@ def create_output_table(self, name: str) -> "Table":
15911592
def create_result_query(
15921593
self, udf_table, query: Select
15931594
) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
1594-
udf_table_query = udf_table.select().subquery()
1595-
# Exclude sys__input_id and sys__partial - they're only needed for tracking
1596-
# during UDF execution and checkpoint recovery
1595+
# Filter out empty-input marker rows (inputs that yielded nothing)
1596+
udf_table_query = (
1597+
udf_table.select()
1598+
.where(
1599+
sa.or_(
1600+
udf_table.c.sys__empty.is_(None),
1601+
udf_table.c.sys__empty == sa.false(),
1602+
)
1603+
)
1604+
.subquery()
1605+
)
1606+
# Exclude checkpoint tracking columns from the result
1607+
excluded = {c.name for c in self._checkpoint_tracking_columns()}
15971608
udf_table_cols: list[sqlalchemy.Label[Any]] = [
1598-
label(c.name, c)
1599-
for c in udf_table_query.columns
1600-
if c.name not in ("sys__input_id", "sys__partial")
1609+
label(c.name, c) for c in udf_table_query.columns if c.name not in excluded
16011610
]
16021611

16031612
def q(*columns):
@@ -1606,11 +1615,7 @@ def q(*columns):
16061615
cols = [c for c in udf_table_cols if c.name in names]
16071616
return sqlalchemy.select(*cols).select_from(udf_table_query)
16081617

1609-
return q, [
1610-
c
1611-
for c in udf_table_query.columns
1612-
if c.name not in ("sys__input_id", "sys__partial")
1613-
]
1618+
return q, [c for c in udf_table_query.columns if c.name not in excluded]
16141619

16151620

16161621
@frozen

tests/func/checkpoints/test_checkpoint_recovery.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -254,22 +254,29 @@ def gen_multiple(num) -> Iterator[int]:
254254
assert len(all_results) == len(set(all_results)), "Should have no duplicate results"
255255

256256

257-
@pytest.mark.xfail(
258-
reason="Known limitation: inputs that yield nothing are not tracked "
259-
"in processed table"
260-
)
261257
def test_generator_yielding_nothing(test_session, monkeypatch, nums_dataset):
262-
"""Test that generator correctly handles inputs that yield zero outputs."""
258+
"""Test that generator correctly handles inputs that yield zero outputs.
259+
260+
Uses fail_after_count=4 so that regardless of DB row ordering, at least
261+
one odd input (which yields nothing) is processed before the failure.
262+
With 6 inputs (3 odd, 3 even), any 4 must include at least 1 odd.
263+
"""
264+
fail_after_count = 4
263265
processed = []
266+
skipped = []
267+
call_count = [0]
264268

265269
def selective_generator(num) -> Iterator[int]:
266270
processed.append(num)
267-
if num == 3:
271+
call_count[0] += 1
272+
if call_count[0] > fail_after_count:
268273
raise Exception("Simulated failure")
269274
if num % 2 == 0: # Only even numbers yield outputs
270275
yield num * 10
276+
else:
277+
skipped.append(num)
271278

272-
# First run - fails on num=3
279+
# First run - fails after processing 4 inputs
273280
reset_session_job_state()
274281
chain = dc.read_dataset("nums", session=test_session).gen(
275282
value=selective_generator, output=int
@@ -278,13 +285,24 @@ def selective_generator(num) -> Iterator[int]:
278285
with pytest.raises(Exception, match="Simulated failure"):
279286
chain.save("results")
280287

288+
assert len(processed) == fail_after_count + 1 # 4 succeeded + 1 failed
289+
skipped_first_run = list(skipped)
290+
assert skipped_first_run, "Expected at least one empty-yield input"
291+
281292
# Second run - should continue from checkpoint
282293
reset_session_job_state()
283294
processed.clear()
295+
skipped.clear()
296+
call_count[0] = 0
284297
chain.save("results")
285298

286-
# Only inputs 3,4,5,6 should be processed (1,2 were already done)
287-
assert processed == [3, 4, 5, 6]
299+
# Empty-yield inputs from first run must not be re-processed
300+
assert not set(skipped_first_run) & set(processed), (
301+
f"Empty-yield inputs {set(skipped_first_run) & set(processed)} "
302+
f"were re-processed despite being checkpointed"
303+
)
304+
assert len(processed) == 2
305+
# Final result: all even numbers yield output
288306
result = sorted(dc.read_dataset("results", session=test_session).to_list("value"))
289307
assert result == [(20,), (40,), (60,)]
290308

0 commit comments

Comments
 (0)