22import os
33import sys
44from dataclasses import dataclass
5+ from typing import Literal
56from unittest import mock
67from unittest .mock import MagicMock , call , create_autospec
78
@@ -137,17 +138,7 @@ def test_statement_execution_backend_save_table_overwrite_empty_table():
137138 ),
138139 mock .call (
139140 warehouse_id = "abc" ,
140- statement = "TRUNCATE TABLE a.b.c" ,
141- catalog = None ,
142- schema = None ,
143- disposition = None ,
144- format = Format .JSON_ARRAY ,
145- byte_limit = None ,
146- wait_timeout = None ,
147- ),
148- mock .call (
149- warehouse_id = "abc" ,
150- statement = "INSERT INTO a.b.c (first, second) VALUES ('1', NULL)" ,
141+ statement = "INSERT OVERWRITE a.b.c (first, second) VALUES ('1', NULL)" ,
151142 catalog = None ,
152143 schema = None ,
153144 disposition = None ,
@@ -170,7 +161,7 @@ def test_statement_execution_backend_save_table_empty_records():
170161
171162 seb .save_table ("a.b.c" , [], Bar )
172163
173- ws .statement_execution .execute_statement .assert_called_with (
164+ ws .statement_execution .execute_statement .assert_called_once_with (
174165 warehouse_id = "abc" ,
175166 statement = "CREATE TABLE IF NOT EXISTS a.b.c "
176167 "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA" ,
@@ -183,6 +174,44 @@ def test_statement_execution_backend_save_table_empty_records():
183174 )
184175
185176
177+ def test_statement_execution_backend_save_table_overwrite_empty_records () -> None :
178+ ws = create_autospec (WorkspaceClient )
179+
180+ ws .statement_execution .execute_statement .return_value = StatementResponse (
181+ status = StatementStatus (state = StatementState .SUCCEEDED )
182+ )
183+
184+ seb = StatementExecutionBackend (ws , "abc" )
185+
186+ seb .save_table ("a.b.c" , [], Bar , mode = "overwrite" )
187+
188+ ws .statement_execution .execute_statement .assert_has_calls (
189+ [
190+ call (
191+ warehouse_id = "abc" ,
192+ statement = "CREATE TABLE IF NOT EXISTS a.b.c "
193+ "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA" ,
194+ catalog = None ,
195+ schema = None ,
196+ disposition = None ,
197+ format = Format .JSON_ARRAY ,
198+ byte_limit = None ,
199+ wait_timeout = None ,
200+ ),
201+ call (
202+ warehouse_id = "abc" ,
203+ statement = "TRUNCATE TABLE a.b.c" ,
204+ catalog = None ,
205+ schema = None ,
206+ disposition = None ,
207+ format = Format .JSON_ARRAY ,
208+ byte_limit = None ,
209+ wait_timeout = None ,
210+ ),
211+ ]
212+ )
213+
214+
186215def test_statement_execution_backend_save_table_two_records ():
187216 ws = create_autospec (WorkspaceClient )
188217
@@ -220,7 +249,7 @@ def test_statement_execution_backend_save_table_two_records():
220249 )
221250
222251
223- def test_statement_execution_backend_save_table_in_batches_of_two () :
252+ def test_statement_execution_backend_save_table_append_in_batches_of_two () -> None :
224253 ws = create_autospec (WorkspaceClient )
225254
226255 ws .statement_execution .execute_statement .return_value = StatementResponse (
@@ -229,7 +258,7 @@ def test_statement_execution_backend_save_table_in_batches_of_two():
229258
230259 seb = StatementExecutionBackend (ws , "abc" , max_records_per_batch = 2 )
231260
232- seb .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False ), Foo ("ccc" , True )], Foo )
261+ seb .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False ), Foo ("ccc" , True )], Foo , mode = "append" )
233262
234263 ws .statement_execution .execute_statement .assert_has_calls (
235264 [
@@ -267,6 +296,53 @@ def test_statement_execution_backend_save_table_in_batches_of_two():
267296 )
268297
269298
299+ def test_statement_execution_backend_save_table_overwrite_in_batches_of_two () -> None :
300+ ws = create_autospec (WorkspaceClient )
301+
302+ ws .statement_execution .execute_statement .return_value = StatementResponse (
303+ status = StatementStatus (state = StatementState .SUCCEEDED )
304+ )
305+
306+ seb = StatementExecutionBackend (ws , "abc" , max_records_per_batch = 2 )
307+
308+ seb .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False ), Foo ("ccc" , True )], Foo , mode = "overwrite" )
309+
310+ ws .statement_execution .execute_statement .assert_has_calls (
311+ [
312+ mock .call (
313+ warehouse_id = "abc" ,
314+ statement = "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA" ,
315+ catalog = None ,
316+ schema = None ,
317+ disposition = None ,
318+ format = Format .JSON_ARRAY ,
319+ byte_limit = None ,
320+ wait_timeout = None ,
321+ ),
322+ mock .call (
323+ warehouse_id = "abc" ,
324+ statement = "INSERT OVERWRITE a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)" ,
325+ catalog = None ,
326+ schema = None ,
327+ disposition = None ,
328+ format = Format .JSON_ARRAY ,
329+ byte_limit = None ,
330+ wait_timeout = None ,
331+ ),
332+ mock .call (
333+ warehouse_id = "abc" ,
334+ statement = "INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)" ,
335+ catalog = None ,
336+ schema = None ,
337+ disposition = None ,
338+ format = Format .JSON_ARRAY ,
339+ byte_limit = None ,
340+ wait_timeout = None ,
341+ ),
342+ ]
343+ )
344+
345+
270346def test_runtime_backend_execute ():
271347 with mock .patch .dict (os .environ , {"DATABRICKS_RUNTIME_VERSION" : "14.0" }):
272348 pyspark_sql_session = MagicMock ()
@@ -298,21 +374,53 @@ def test_runtime_backend_fetch():
298374 spark .sql .assert_has_calls (calls )
299375
300376
301- def test_runtime_backend_save_table ():
377+ @pytest .mark .parametrize ("mode" , ["append" , "overwrite" ])
378+ def test_runtime_backend_save_table (mode : Literal ["append" , "overwrite" ]) -> None :
302379 with mock .patch .dict (os .environ , {"DATABRICKS_RUNTIME_VERSION" : "14.0" }):
303380 pyspark_sql_session = MagicMock ()
304381 sys .modules ["pyspark.sql.session" ] = pyspark_sql_session
305382 spark = pyspark_sql_session .SparkSession .builder .getOrCreate ()
306383
307384 runtime_backend = RuntimeBackend ()
308385
309- runtime_backend .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False )], Foo )
386+ runtime_backend .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False )], Foo , mode = mode )
310387
311- spark .createDataFrame .assert_called_with (
388+ spark .createDataFrame .assert_called_once_with (
312389 [Foo (first = "aaa" , second = True ), Foo (first = "bbb" , second = False )],
313390 "first STRING NOT NULL, second BOOLEAN NOT NULL" ,
314391 )
315- spark .createDataFrame ().write .saveAsTable .assert_called_with ("a.b.c" , mode = "append" )
392+ spark .createDataFrame ().write .saveAsTable .assert_called_once_with ("a.b.c" , mode = mode )
393+
394+
395+ def test_runtime_backend_save_table_append_empty_records () -> None :
396+ with mock .patch .dict (os .environ , {"DATABRICKS_RUNTIME_VERSION" : "14.0" }):
397+ pyspark_sql_session = MagicMock ()
398+ sys .modules ["pyspark.sql.session" ] = pyspark_sql_session
399+ spark = pyspark_sql_session .SparkSession .builder .getOrCreate ()
400+
401+ runtime_backend = RuntimeBackend ()
402+
403+ runtime_backend .save_table ("a.b.c" , [], Foo , mode = "append" )
404+
405+ spark .createDataFrame .assert_not_called ()
406+ spark .createDataFrame ().write .saveAsTable .assert_not_called ()
407+ spark .sql .assert_called_once_with (
408+ "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA"
409+ )
410+
411+
412+ def test_runtime_backend_save_table_overwrite_empty_records () -> None :
413+ with mock .patch .dict (os .environ , {"DATABRICKS_RUNTIME_VERSION" : "14.0" }):
414+ pyspark_sql_session = MagicMock ()
415+ sys .modules ["pyspark.sql.session" ] = pyspark_sql_session
416+ spark = pyspark_sql_session .SparkSession .builder .getOrCreate ()
417+
418+ runtime_backend = RuntimeBackend ()
419+
420+ runtime_backend .save_table ("a.b.c" , [], Foo , mode = "overwrite" )
421+
422+ spark .createDataFrame .assert_called_once_with ([], "first STRING NOT NULL, second BOOLEAN NOT NULL" )
423+ spark .createDataFrame ().write .saveAsTable .assert_called_once_with ("a.b.c" , mode = "overwrite" )
316424
317425
318426def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class (mocker ):
@@ -427,6 +535,27 @@ def test_mock_backend_save_table_overwrite() -> None:
427535 ]
428536
429537
538+ def test_mock_backend_save_table_no_rows () -> None :
539+ mock_backend = MockBackend ()
540+
541+ mock_backend .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False )], Foo )
542+ mock_backend .save_table ("a.b.c" , [], Foo )
543+
544+ assert mock_backend .rows_written_for ("a.b.c" , mode = "append" ) == [
545+ Row (first = "aaa" , second = True ),
546+ Row (first = "bbb" , second = False ),
547+ ]
548+
549+
550+ def test_mock_backend_save_table_overwrite_no_rows () -> None :
551+ mock_backend = MockBackend ()
552+
553+ mock_backend .save_table ("a.b.c" , [Foo ("aaa" , True ), Foo ("bbb" , False )], Foo )
554+ mock_backend .save_table ("a.b.c" , [], Foo )
555+
556+ assert mock_backend .rows_written_for ("a.b.c" , mode = "overwrite" ) == []
557+
558+
430559def test_mock_backend_rows_dsl ():
431560 rows = MockBackend .rows ("foo" , "bar" )[
432561 [1 , 2 ],
0 commit comments