Skip to content

Commit 128a0fa

Browse files
authored
Minor improvements for .save_table(mode="overwrite") (#298)
This PR includes some updates for `.save_table()` in overwrite mode:

- If no rows are supplied, instead of a no-op the table is truncated. The mock backend already assumed this but diverged from the concrete implementations, which treated this as a no-op (as when appending). Unit tests now cover this situation for all backends, and the existing integration test for the SQL-statement backend has been updated to cover this.
- The SQL-based backends have a slight optimisation: instead of first truncating before inserting, the truncate is now performed as part of the insert for the first batch.
- Type hints on the abstract method now match the concrete implementations.
1 parent dbfc823 commit 128a0fa

File tree

4 files changed

+244
-35
lines changed

4 files changed

+244
-35
lines changed

src/databricks/labs/lsql/backends.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,20 +135,29 @@ def execute(self, sql: str, *, catalog: str | None = None, schema: str | None =
135135
def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Any]:
136136
raise NotImplementedError
137137

138-
def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
138+
def save_table(
139+
self,
140+
full_name: str,
141+
rows: Sequence[DataclassInstance],
142+
klass: Dataclass,
143+
mode: Literal["append", "overwrite"] = "append",
144+
):
139145
rows = self._filter_none_rows(rows, klass)
140146
self.create_table(full_name, klass)
141-
if len(rows) == 0:
147+
if not rows:
148+
if mode == "overwrite":
149+
self.execute(f"TRUNCATE TABLE {full_name}")
142150
return
143151
fields = dataclasses.fields(klass)
144152
field_names = [f.name for f in fields]
145-
if mode == "overwrite":
146-
self.execute(f"TRUNCATE TABLE {full_name}")
153+
insert_modifier = "OVERWRITE" if mode == "overwrite" else "INTO"
147154
for i in range(0, len(rows), self._max_records_per_batch):
148155
batch = rows[i : i + self._max_records_per_batch]
149156
vals = "), (".join(self._row_to_sql(r, fields) for r in batch)
150-
sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})'
157+
sql = f'INSERT {insert_modifier} {full_name} ({", ".join(field_names)}) VALUES ({vals})'
151158
self.execute(sql)
159+
# Only the first batch can truncate; subsequent batches append.
160+
insert_modifier = "INTO"
152161

153162
@classmethod
154163
def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
@@ -277,8 +286,7 @@ def save_table(
277286
mode: Literal["append", "overwrite"] = "append",
278287
) -> None:
279288
rows = self._filter_none_rows(rows, klass)
280-
281-
if len(rows) == 0:
289+
if not rows and mode == "append":
282290
self.create_table(full_name, klass)
283291
return
284292
# pyspark deals well with lists of dataclass instances, as long as schema is provided

tests/integration/test_backends.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def test_statement_execution_backend_overwrites_table(ws, env_or_skip, make_rand
162162
rows = list(sql_backend.fetch(f"SELECT * FROM {catalog}.{schema}.foo"))
163163
assert rows == [Row(first="xyz", second=True)]
164164

165+
sql_backend.save_table(f"{catalog}.{schema}.foo", [], views.Foo, "overwrite")
166+
167+
rows = list(sql_backend.fetch(f"SELECT * FROM {catalog}.{schema}.foo"))
168+
assert rows == []
169+
165170

166171
def test_runtime_backend_use_statements(ws):
167172
product_info = ProductInfo.for_testing(SqlBackend)

tests/unit/test_backends.py

Lines changed: 147 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
from dataclasses import dataclass
5+
from typing import Literal
56
from unittest import mock
67
from unittest.mock import MagicMock, call, create_autospec
78

@@ -137,17 +138,7 @@ def test_statement_execution_backend_save_table_overwrite_empty_table():
137138
),
138139
mock.call(
139140
warehouse_id="abc",
140-
statement="TRUNCATE TABLE a.b.c",
141-
catalog=None,
142-
schema=None,
143-
disposition=None,
144-
format=Format.JSON_ARRAY,
145-
byte_limit=None,
146-
wait_timeout=None,
147-
),
148-
mock.call(
149-
warehouse_id="abc",
150-
statement="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
141+
statement="INSERT OVERWRITE a.b.c (first, second) VALUES ('1', NULL)",
151142
catalog=None,
152143
schema=None,
153144
disposition=None,
@@ -170,7 +161,7 @@ def test_statement_execution_backend_save_table_empty_records():
170161

171162
seb.save_table("a.b.c", [], Bar)
172163

173-
ws.statement_execution.execute_statement.assert_called_with(
164+
ws.statement_execution.execute_statement.assert_called_once_with(
174165
warehouse_id="abc",
175166
statement="CREATE TABLE IF NOT EXISTS a.b.c "
176167
"(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA",
@@ -183,6 +174,44 @@ def test_statement_execution_backend_save_table_empty_records():
183174
)
184175

185176

177+
def test_statement_execution_backend_save_table_overwrite_empty_records() -> None:
178+
ws = create_autospec(WorkspaceClient)
179+
180+
ws.statement_execution.execute_statement.return_value = StatementResponse(
181+
status=StatementStatus(state=StatementState.SUCCEEDED)
182+
)
183+
184+
seb = StatementExecutionBackend(ws, "abc")
185+
186+
seb.save_table("a.b.c", [], Bar, mode="overwrite")
187+
188+
ws.statement_execution.execute_statement.assert_has_calls(
189+
[
190+
call(
191+
warehouse_id="abc",
192+
statement="CREATE TABLE IF NOT EXISTS a.b.c "
193+
"(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA",
194+
catalog=None,
195+
schema=None,
196+
disposition=None,
197+
format=Format.JSON_ARRAY,
198+
byte_limit=None,
199+
wait_timeout=None,
200+
),
201+
call(
202+
warehouse_id="abc",
203+
statement="TRUNCATE TABLE a.b.c",
204+
catalog=None,
205+
schema=None,
206+
disposition=None,
207+
format=Format.JSON_ARRAY,
208+
byte_limit=None,
209+
wait_timeout=None,
210+
),
211+
]
212+
)
213+
214+
186215
def test_statement_execution_backend_save_table_two_records():
187216
ws = create_autospec(WorkspaceClient)
188217

@@ -220,7 +249,7 @@ def test_statement_execution_backend_save_table_two_records():
220249
)
221250

222251

223-
def test_statement_execution_backend_save_table_in_batches_of_two():
252+
def test_statement_execution_backend_save_table_append_in_batches_of_two() -> None:
224253
ws = create_autospec(WorkspaceClient)
225254

226255
ws.statement_execution.execute_statement.return_value = StatementResponse(
@@ -229,7 +258,7 @@ def test_statement_execution_backend_save_table_in_batches_of_two():
229258

230259
seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
231260

232-
seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo)
261+
seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="append")
233262

234263
ws.statement_execution.execute_statement.assert_has_calls(
235264
[
@@ -267,6 +296,53 @@ def test_statement_execution_backend_save_table_in_batches_of_two():
267296
)
268297

269298

299+
def test_statement_execution_backend_save_table_overwrite_in_batches_of_two() -> None:
300+
ws = create_autospec(WorkspaceClient)
301+
302+
ws.statement_execution.execute_statement.return_value = StatementResponse(
303+
status=StatementStatus(state=StatementState.SUCCEEDED)
304+
)
305+
306+
seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
307+
308+
seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="overwrite")
309+
310+
ws.statement_execution.execute_statement.assert_has_calls(
311+
[
312+
mock.call(
313+
warehouse_id="abc",
314+
statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
315+
catalog=None,
316+
schema=None,
317+
disposition=None,
318+
format=Format.JSON_ARRAY,
319+
byte_limit=None,
320+
wait_timeout=None,
321+
),
322+
mock.call(
323+
warehouse_id="abc",
324+
statement="INSERT OVERWRITE a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
325+
catalog=None,
326+
schema=None,
327+
disposition=None,
328+
format=Format.JSON_ARRAY,
329+
byte_limit=None,
330+
wait_timeout=None,
331+
),
332+
mock.call(
333+
warehouse_id="abc",
334+
statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)",
335+
catalog=None,
336+
schema=None,
337+
disposition=None,
338+
format=Format.JSON_ARRAY,
339+
byte_limit=None,
340+
wait_timeout=None,
341+
),
342+
]
343+
)
344+
345+
270346
def test_runtime_backend_execute():
271347
with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
272348
pyspark_sql_session = MagicMock()
@@ -298,21 +374,53 @@ def test_runtime_backend_fetch():
298374
spark.sql.assert_has_calls(calls)
299375

300376

301-
def test_runtime_backend_save_table():
377+
@pytest.mark.parametrize("mode", ["append", "overwrite"])
378+
def test_runtime_backend_save_table(mode: Literal["append", "overwrite"]) -> None:
302379
with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
303380
pyspark_sql_session = MagicMock()
304381
sys.modules["pyspark.sql.session"] = pyspark_sql_session
305382
spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
306383

307384
runtime_backend = RuntimeBackend()
308385

309-
runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
386+
runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, mode=mode)
310387

311-
spark.createDataFrame.assert_called_with(
388+
spark.createDataFrame.assert_called_once_with(
312389
[Foo(first="aaa", second=True), Foo(first="bbb", second=False)],
313390
"first STRING NOT NULL, second BOOLEAN NOT NULL",
314391
)
315-
spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append")
392+
spark.createDataFrame().write.saveAsTable.assert_called_once_with("a.b.c", mode=mode)
393+
394+
395+
def test_runtime_backend_save_table_append_empty_records() -> None:
396+
with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
397+
pyspark_sql_session = MagicMock()
398+
sys.modules["pyspark.sql.session"] = pyspark_sql_session
399+
spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
400+
401+
runtime_backend = RuntimeBackend()
402+
403+
runtime_backend.save_table("a.b.c", [], Foo, mode="append")
404+
405+
spark.createDataFrame.assert_not_called()
406+
spark.createDataFrame().write.saveAsTable.assert_not_called()
407+
spark.sql.assert_called_once_with(
408+
"CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA"
409+
)
410+
411+
412+
def test_runtime_backend_save_table_overwrite_empty_records() -> None:
413+
with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
414+
pyspark_sql_session = MagicMock()
415+
sys.modules["pyspark.sql.session"] = pyspark_sql_session
416+
spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
417+
418+
runtime_backend = RuntimeBackend()
419+
420+
runtime_backend.save_table("a.b.c", [], Foo, mode="overwrite")
421+
422+
spark.createDataFrame.assert_called_once_with([], "first STRING NOT NULL, second BOOLEAN NOT NULL")
423+
spark.createDataFrame().write.saveAsTable.assert_called_once_with("a.b.c", mode="overwrite")
316424

317425

318426
def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker):
@@ -427,6 +535,27 @@ def test_mock_backend_save_table_overwrite() -> None:
427535
]
428536

429537

538+
def test_mock_backend_save_table_no_rows() -> None:
539+
mock_backend = MockBackend()
540+
541+
mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
542+
mock_backend.save_table("a.b.c", [], Foo)
543+
544+
assert mock_backend.rows_written_for("a.b.c", mode="append") == [
545+
Row(first="aaa", second=True),
546+
Row(first="bbb", second=False),
547+
]
548+
549+
550+
def test_mock_backend_save_table_overwrite_no_rows() -> None:
551+
mock_backend = MockBackend()
552+
553+
mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
554+
mock_backend.save_table("a.b.c", [], Foo)
555+
556+
assert mock_backend.rows_written_for("a.b.c", mode="overwrite") == []
557+
558+
430559
def test_mock_backend_rows_dsl():
431560
rows = MockBackend.rows("foo", "bar")[
432561
[1, 2],

0 commit comments

Comments (0)