Skip to content

Commit 7258b2e

Browse files
authored
always drop sys id in merge (#1424)
1 parent 6322a5c commit 7258b2e

File tree

4 files changed

+47
-8
lines changed

4 files changed

+47
-8
lines changed

src/datachain/delta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def _get_retry_chain(
150150
error_records = result_dataset.filter(C(delta_retry) != "")
151151
error_source_records = source_dc.merge(
152152
error_records, on=on, right_on=right_on, inner=True
153-
).select(*list(source_dc.signals_schema.values))
153+
).select(
154+
*list(source_dc.signals_schema.clone_without_sys_signals().values.keys())
155+
)
154156
retry_chain = error_source_records
155157

156158
# Handle missing records if delta_retry is True

src/datachain/lib/dc/datachain.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,14 +1697,13 @@ def _resolve(
16971697
query.feature_schema = None
16981698
ds = self._evolve(query=query)
16991699

1700+
# Note: merge drops sys signals from both sides, make sure to not include it
1701+
# in the resulting schema
17001702
signals_schema = self.signals_schema.clone_without_sys_signals()
17011703
right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
17021704

17031705
ds.signals_schema = signals_schema.merge(right_signals_schema, rname)
17041706

1705-
if not full:
1706-
ds.signals_schema = SignalSchema({"sys": Sys}) | ds.signals_schema
1707-
17081707
return ds
17091708

17101709
@delta_disabled

src/datachain/query/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,7 @@ def apply(
10651065
q1 = self.get_query(self.query1, temp_tables)
10661066
q2 = self.get_query(self.query2, temp_tables)
10671067

1068-
q1_columns = _drop_system_columns(q1.c) if self.full else list(q1.c)
1068+
q1_columns = _drop_system_columns(q1.c)
10691069
q1_column_names = {c.name for c in q1_columns}
10701070

10711071
q2_columns = []

tests/unit/lib/test_datachain_merge.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_merge_similar_objects(test_session):
140140
rname = "qq"
141141
ch = ch1.merge(ch2, "emp.person.name", rname=rname)
142142

143-
assert list(ch.signals_schema.values.keys()) == ["sys", "emp", rname + "emp"]
143+
assert list(ch.signals_schema.values.keys()) == ["emp", rname + "emp"]
144144

145145
empl = list(ch.to_list())
146146
assert len(empl) == 4
@@ -175,7 +175,7 @@ def test_merge_similar_objects_in_memory():
175175
assert ch.session.catalog.metastore.db.db_file == ":memory:"
176176
assert ch.session.catalog.warehouse.db.db_file == ":memory:"
177177

178-
assert list(ch.signals_schema.values.keys()) == ["sys", "emp", rname + "emp"]
178+
assert list(ch.signals_schema.values.keys()) == ["emp", rname + "emp"]
179179

180180
empl = list(ch.to_list())
181181
assert len(empl) == 4
@@ -198,7 +198,6 @@ def test_merge_values(test_session):
198198
ch = ch1.merge(ch2, "id")
199199

200200
assert list(ch.signals_schema.values.keys()) == [
201-
"sys",
202201
"id",
203202
"descr",
204203
"right_id",
@@ -339,3 +338,42 @@ def _get_expr(chain):
339338
count += 1
340339

341340
assert count == len(team) * len(team)
341+
342+
343+
def test_merge_with_drops_sys_columns(test_session):
344+
left = dc.read_values(id=[1, 1], lval=[10, 20], session=test_session)
345+
right = dc.read_values(id=[1, 1], rval=["a", "b"], session=test_session)
346+
347+
merged = left.merge(right, on="id")
348+
349+
assert "sys" not in merged.signals_schema.values
350+
351+
cols = merged.settings(sys=True).to_pandas(flatten=True).columns
352+
assert all(not str(col).startswith("sys") for col in cols)
353+
354+
ds_name = "merge_left_dups_sys_check_sys"
355+
merged.save(ds_name)
356+
357+
df_with_sys = (
358+
dc.read_dataset(ds_name, session=test_session)
359+
.settings(sys=True)
360+
.to_pandas(flatten=True)
361+
)
362+
363+
sys_cols = [c for c in df_with_sys.columns if str(c).startswith("sys")]
364+
assert sys_cols
365+
366+
def _col(name: str) -> str:
367+
for col in df_with_sys.columns:
368+
if str(col) == f"sys.{name}":
369+
return str(col)
370+
raise AssertionError(f"Missing sys column for {name}")
371+
372+
sys_id_col = _col("id")
373+
sys_rand_col = _col("rand")
374+
375+
sys_ids = list(df_with_sys[sys_id_col])
376+
assert len(sys_ids) == len(set(sys_ids))
377+
378+
sys_rand = list(df_with_sys[sys_rand_col])
379+
assert len(sys_rand) == len(set(sys_rand))

0 commit comments

Comments
 (0)