Skip to content

Commit b1c3a6e

Browse files
authored
don't duplicate temp tables on reusing queries (#1410)
1 parent 54c03fe commit b1c3a6e

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

src/datachain/query/dataset.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,9 @@ def query(
231231

232232
def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
233233
source_query = query_generator.exclude(("sys__id",))
234+
right_before = len(self.dq.temp_table_names)
234235
target_query = self.dq.apply_steps().select()
235-
temp_tables.extend(self.dq.temp_table_names)
236+
temp_tables.extend(self.dq.temp_table_names[right_before:])
236237

237238
# creating temp table that will hold subtract results
238239
temp_table_name = self.catalog.warehouse.temp_table_name()
@@ -951,10 +952,12 @@ def hash_inputs(self) -> str:
951952
def apply(
952953
self, query_generator: QueryGenerator, temp_tables: list[str]
953954
) -> StepResult:
955+
left_before = len(self.query1.temp_table_names)
954956
q1 = self.query1.apply_steps().select().subquery()
955-
temp_tables.extend(self.query1.temp_table_names)
957+
temp_tables.extend(self.query1.temp_table_names[left_before:])
958+
right_before = len(self.query2.temp_table_names)
956959
q2 = self.query2.apply_steps().select().subquery()
957-
temp_tables.extend(self.query2.temp_table_names)
960+
temp_tables.extend(self.query2.temp_table_names[right_before:])
958961

959962
columns1 = _drop_system_columns(q1.columns)
960963
columns2 = _drop_system_columns(q2.columns)
@@ -1004,8 +1007,9 @@ def hash_inputs(self) -> str:
10041007
return hashlib.sha256(b"".join(parts)).hexdigest()
10051008

10061009
def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
1010+
temp_tables_before = len(dq.temp_table_names)
10071011
query = dq.apply_steps().select()
1008-
temp_tables.extend(dq.temp_table_names)
1012+
temp_tables.extend(dq.temp_table_names[temp_tables_before:])
10091013

10101014
if not any(isinstance(step, (SQLJoin, SQLUnion)) for step in dq.steps):
10111015
return query.subquery(dq.table.name)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import datachain as dc
2+
from datachain.query.dataset import DatasetQuery
3+
4+
5+
def _capture_temp_tables(mocker):
    """Patch ``DatasetQuery.cleanup`` so each call records temp table names.

    Each invocation of the patched ``cleanup`` appends a copy of the
    instance's ``temp_table_names`` to the returned list and then delegates
    to the original ``cleanup``.

    Returns:
        The list that accumulates one snapshot per ``cleanup`` call.
    """
    real_cleanup = DatasetQuery.cleanup
    snapshots: list[list[str]] = []

    def recording_cleanup(self):
        # Snapshot a copy, taken before the original cleanup runs.
        snapshots.append(list(self.temp_table_names))
        return real_cleanup(self)

    mocker.patch(
        "datachain.query.dataset.DatasetQuery.cleanup",
        recording_cleanup,
    )
    return snapshots
15+
16+
17+
def _assert_no_duplicate_temp_tables(captured: list[list[str]]):
18+
for tables in captured:
19+
assert len(tables) == len(set(tables))
20+
21+
22+
def test_nested_merge_has_no_duplicate_temp_tables(test_session, mocker):
    """Run a nested merge chain twice and check temp tables are not repeated.

    Both runs must return the same ``num`` values, and no snapshot recorded
    at ``cleanup`` time may contain the same temp table name more than once.
    """
    snapshots = _capture_temp_tables(mocker)

    numbers = dc.read_values(num=[1, 2], session=test_session)
    with_generated = numbers.map(num_plus=lambda num: str(num + 10))
    inner_merge = with_generated.merge(numbers, on="num", inner=True)
    nested = numbers.merge(inner_merge, on="num", inner=True)

    first_frame = nested.select("num").to_pandas()
    first_run = first_frame["num"].tolist()
    assert first_run == [1, 2]

    # Evaluate the very same chain again; this is the query-reuse scenario.
    second_frame = nested.select("num").to_pandas()
    second_run = second_frame["num"].tolist()
    assert second_run == first_run

    _assert_no_duplicate_temp_tables(snapshots)
37+
38+
39+
def test_union_has_no_duplicate_temp_tables(test_session, mocker):
    """Run a union chain twice and check temp tables are not repeated.

    Results are sorted before comparison, so row order is not asserted.
    """
    snapshots = _capture_temp_tables(mocker)

    lhs = dc.read_values(num=[1, 2], session=test_session)
    rhs = dc.read_values(num=[3], session=test_session)
    combined = lhs.union(rhs)

    first_values = combined.select("num").to_pandas()["num"].tolist()
    first_run = sorted(first_values)
    assert first_run == [1, 2, 3]

    # Evaluate the very same chain again; this is the query-reuse scenario.
    second_values = combined.select("num").to_pandas()["num"].tolist()
    second_run = sorted(second_values)
    assert second_run == first_run

    _assert_no_duplicate_temp_tables(snapshots)
53+
54+
55+
def test_subtract_has_no_duplicate_temp_tables(test_session, mocker):
    """Run a subtract chain twice and check temp tables are not repeated.

    Both runs must return the same ``num`` values, and no snapshot recorded
    at ``cleanup`` time may contain the same temp table name more than once.
    """
    snapshots = _capture_temp_tables(mocker)

    minuend = dc.read_values(num=[1, 2], session=test_session)
    subtrahend = dc.read_values(num=[2], session=test_session)
    difference = minuend.subtract(subtrahend, on="num")

    first_frame = difference.select("num").to_pandas()
    first_run = first_frame["num"].tolist()
    assert first_run == [1]

    # Evaluate the very same chain again; this is the query-reuse scenario.
    second_frame = difference.select("num").to_pandas()
    second_run = second_frame["num"].tolist()
    assert second_run == first_run

    _assert_no_duplicate_temp_tables(snapshots)

0 commit comments

Comments
 (0)