Skip to content

Commit 6322a5c

Browse files
authored
fix(diff): edge case when there no columns except provided in on (#1421)
1 parent fc7181c commit 6322a5c

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

src/datachain/diff/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def _to_list(obj: str | Sequence[str] | None) -> list[str] | None:
103103
left = left.mutate(**{ldiff_col: 1})
104104
right = right.mutate(**{rdiff_col: 1})
105105

106-
if not compare:
106+
if compare is None:
107107
modified_cond = True
108+
elif len(compare) == 0:
109+
modified_cond = False
108110
else:
109111
modified_cond = or_( # type: ignore[assignment]
110112
*[

tests/unit/lib/test_checkpoints.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_checkpoints(
3535
catalog = test_session.catalog
3636
metastore = catalog.metastore
3737

38-
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", reset_checkpoints)
38+
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(reset_checkpoints))
3939

4040
if with_delta:
4141
chain = dc.read_dataset(
@@ -75,8 +75,9 @@ def test_checkpoints(
7575
chain.save("nums3")
7676
second_job_id = test_session.get_or_create_job().id
7777

78-
assert len(catalog.get_dataset("nums1").versions) == 2 if reset_checkpoints else 1
79-
assert len(catalog.get_dataset("nums2").versions) == 2 if reset_checkpoints else 1
78+
expected_versions = 1 if with_delta or not reset_checkpoints else 2
79+
assert len(catalog.get_dataset("nums1").versions) == expected_versions
80+
assert len(catalog.get_dataset("nums2").versions) == expected_versions
8081
assert len(catalog.get_dataset("nums3").versions) == 1
8182

8283
assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 2
@@ -88,7 +89,7 @@ def test_checkpoints_modified_chains(
8889
test_session, monkeypatch, nums_dataset, reset_checkpoints
8990
):
9091
catalog = test_session.catalog
91-
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", reset_checkpoints)
92+
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(reset_checkpoints))
9293

9394
chain = dc.read_dataset("nums", session=test_session)
9495

@@ -120,7 +121,7 @@ def test_checkpoints_multiple_runs(
120121
):
121122
catalog = test_session.catalog
122123

123-
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", reset_checkpoints)
124+
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(reset_checkpoints))
124125

125126
chain = dc.read_dataset("nums", session=test_session)
126127

@@ -184,7 +185,7 @@ def test_checkpoints_check_valid_chain_is_returned(
184185
monkeypatch,
185186
nums_dataset,
186187
):
187-
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", False)
188+
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False))
188189
chain = dc.read_dataset("nums", session=test_session)
189190

190191
# -------------- FIRST RUN -------------------
@@ -197,6 +198,7 @@ def test_checkpoints_check_valid_chain_is_returned(
197198

198199
# checking that we return expected DataChain even though we skipped chain creation
199200
# because of the checkpoints
201+
assert ds.dataset is not None
200202
assert ds.dataset.name == "nums1"
201203
assert len(ds.dataset.versions) == 1
202204
assert ds.order_by("num").to_list("num") == [(1,), (2,), (3,)]

tests/unit/lib/test_diff.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,24 @@ def test_diff_on_equal_datasets(test_session, on_self):
256256
assert diff.order_by("id").to_list(*collect_fields) == expected
257257

258258

259+
def test_diff_only_on_columns_treated_as_same(test_session):
260+
ds1 = dc.read_values(
261+
id=[1, 2],
262+
session=test_session,
263+
)
264+
ds2 = dc.read_values(
265+
id=[1, 2],
266+
session=test_session,
267+
)
268+
269+
diff = ds1.diff(ds2, on=["id"], same=True, status_col="diff")
270+
271+
assert diff.order_by("id").to_list("diff", "id") == [
272+
(CompareStatus.SAME, 1),
273+
(CompareStatus.SAME, 2),
274+
]
275+
276+
259277
def test_diff_multiple_columns(test_session, str_default):
260278
ds1 = dc.read_values(
261279
id=[1, 2, 4],
@@ -382,7 +400,7 @@ def test_diff_missing_on(test_session):
382400
ds2 = dc.read_values(id=[1, 2, 4], session=test_session)
383401

384402
with pytest.raises(ValueError) as exc_info:
385-
ds1.diff(ds2, on=None)
403+
ds1.diff(ds2, on=None) # type: ignore[arg-type]
386404

387405
assert str(exc_info.value) == "'on' must be specified"
388406

tests/unit/test_datachain_hash.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,4 @@ def test_diff(test_session):
218218
status_col="diff",
219219
)
220220
.hash()
221-
) == "4135f2deffa91702259de50b48076dd2f8cdf3be32c167332840209c137977f9"
221+
) == "8ffac19b12cf96e2916968914d357c4a9c1b81038c43ab5cf97ba1127fb86567"

0 commit comments

Comments
 (0)