22from unittest .mock import MagicMock , patch
33
44from conftest import await_job_completion
5+ from delta .tables import DeltaMergeBuilder
56import pytest
67
78from pydantic import ValidationError
2425
2526def test_delta_table_writer (dummy_df , spark ):
2627 table_name = "test_table"
27- writer = DeltaTableWriter (table = table_name , output_mode = BatchOutputMode .APPEND , df = dummy_df )
28+ writer = DeltaTableWriter (
29+ table = table_name , output_mode = BatchOutputMode .APPEND , df = dummy_df
30+ )
2831 writer .execute ()
2932 actual_count = spark .read .table (table_name ).count ()
3033 assert actual_count == 1
@@ -44,7 +47,10 @@ def test_delta_table_writer(dummy_df, spark):
4447def test_delta_partitioning (spark , sample_df_to_partition ):
4548 table_name = "partition_table"
4649 DeltaTableWriter (
47- table = table_name , output_mode = BatchOutputMode .OVERWRITE , df = sample_df_to_partition , partition_by = ["partition" ]
50+ table = table_name ,
51+ output_mode = BatchOutputMode .OVERWRITE ,
52+ df = sample_df_to_partition ,
53+ partition_by = ["partition" ],
4854 ).execute ()
4955 output_df = spark .read .table (table_name )
5056 assert output_df .count () == 2
@@ -55,7 +61,11 @@ def test_delta_table_merge_all(spark):
5561
5662 table_name = "test_merge_all_table"
5763 target_df = spark .createDataFrame (
58- [{"id" : 1 , "value" : "no_merge" }, {"id" : 2 , "value" : "expected_merge" }, {"id" : 5 , "value" : "xxxx" }]
64+ [
65+ {"id" : 1 , "value" : "no_merge" },
66+ {"id" : 2 , "value" : "expected_merge" },
67+ {"id" : 5 , "value" : "xxxx" },
68+ ]
5969 )
6070 source_df = spark .createDataFrame (
6171 [
@@ -73,7 +83,9 @@ def test_delta_table_merge_all(spark):
7383 # No merge as old value is greater
7484 5 : "xxxx" ,
7585 }
76- DeltaTableWriter (table = table_name , output_mode = BatchOutputMode .APPEND , df = target_df ).execute ()
86+ DeltaTableWriter (
87+ table = table_name , output_mode = BatchOutputMode .APPEND , df = target_df
88+ ).execute ()
7789 merge_writer = DeltaTableWriter (
7890 table = table_name ,
7991 output_mode = BatchOutputMode .MERGEALL ,
@@ -89,7 +101,9 @@ def test_delta_table_merge_all(spark):
89101 with pytest .raises (SparkConnectDeltaTableException ) as exc_info :
90102 merge_writer .execute ()
91103
92- assert str (exc_info .value ).startswith ("`DeltaTable.forName` is not supported due to delta calling _sc" )
104+ assert str (exc_info .value ).startswith (
105+ "`DeltaTable.forName` is not supported due to delta calling _sc"
106+ )
93107 else :
94108 merge_writer .execute ()
95109 result = {
@@ -107,12 +121,18 @@ def test_deltatablewriter_with_invalid_conditions(spark, dummy_df):
107121
108122 if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session ():
109123 with pytest .raises (SparkConnectDeltaTableException ) as exc_info :
110- builder = get_delta_table_for_name (spark_session = spark , table_name = table_name )
124+ builder = get_delta_table_for_name (
125+ spark_session = spark , table_name = table_name
126+ )
111127
112- assert str (exc_info .value ).startswith ("`DeltaTable.forName` is not supported due to delta calling _sc" )
128+ assert str (exc_info .value ).startswith (
129+ "`DeltaTable.forName` is not supported due to delta calling _sc"
130+ )
113131 else :
114132 with pytest .raises (AnalysisException ):
115- builder = get_delta_table_for_name (spark_session = spark , table_name = table_name )
133+ builder = get_delta_table_for_name (
134+ spark_session = spark , table_name = table_name
135+ )
116136 merge_builder = builder .alias ("target" ).merge (
117137 condition = "invalid_condition" , source = dummy_df .alias ("source" )
118138 )
@@ -137,7 +157,10 @@ def test_delta_new_table_merge(spark):
137157 [
138158 {"id" : 1 , "value" : "new_value" },
139159 {"id" : 2 , "value" : "new_value" },
140- {"id" : 3 , "value" : None }, # Will be not saved, because source.value IS NOT NULL,
160+ {
161+ "id" : 3 ,
162+ "value" : None ,
163+ }, # Will be not saved, because source.value IS NOT NULL,
141164 {"id" : 4 , "value" : "new_value" },
142165 {"id" : 5 , "value" : "new_value" },
143166 ]
@@ -154,7 +177,8 @@ def test_delta_new_table_merge(spark):
154177 df = source_df ,
155178 ).execute ()
156179 result = {
157- list (row .asDict ().values ())[0 ]: list (row .asDict ().values ())[1 ] for row in spark .read .table (table_name ).collect ()
180+ list (row .asDict ().values ())[0 ]: list (row .asDict ().values ())[1 ]
181+ for row in spark .read .table (table_name ).collect ()
158182 }
159183 assert result == expected
160184
@@ -270,7 +294,9 @@ def test_delta_with_options(spark):
270294 """
271295 sample_df = spark .createDataFrame ([{"id" : 1 , "value" : "test_value" }])
272296
273- with patch ("koheesio.spark.writers.delta.DeltaTableWriter.writer" , new_callable = MagicMock ) as mock_writer :
297+ with patch (
298+ "koheesio.spark.writers.delta.DeltaTableWriter.writer" , new_callable = MagicMock
299+ ) as mock_writer :
274300 delta_writer = DeltaTableWriter (
275301 table = "test_table" ,
276302 output_mode = BatchOutputMode .APPEND ,
@@ -279,9 +305,13 @@ def test_delta_with_options(spark):
279305 df = sample_df ,
280306 )
281307 delta_writer .execute ()
282- mock_writer .options .assert_called_once_with (testParam1 = "testValue1" , testParam2 = "testValue2" )
308+ mock_writer .options .assert_called_once_with (
309+ testParam1 = "testValue1" , testParam2 = "testValue2"
310+ )
283311
284- with patch ("koheesio.spark.writers.delta.DeltaTableWriter.writer" , new_callable = MagicMock ) as mock_writer :
312+ with patch (
313+ "koheesio.spark.writers.delta.DeltaTableWriter.writer" , new_callable = MagicMock
314+ ) as mock_writer :
285315 delta_writer = DeltaTableWriter (
286316 table = "test_table" ,
287317 output_mode = BatchOutputMode .OVERWRITE ,
@@ -290,7 +320,9 @@ def test_delta_with_options(spark):
290320 df = sample_df ,
291321 )
292322 delta_writer .execute ()
293- mock_writer .options .assert_called_once_with (testParam1 = "testValue1" , testParam2 = "testValue2" )
323+ mock_writer .options .assert_called_once_with (
324+ testParam1 = "testValue1" , testParam2 = "testValue2"
325+ )
294326
295327
296328def test_merge_from_args (spark , dummy_df ):
@@ -313,7 +345,11 @@ def test_merge_from_args(spark, dummy_df):
313345 output_mode = BatchOutputMode .MERGE ,
314346 output_mode_params = {
315347 "merge_builder" : [
316- {"clause" : "whenMatchedUpdate" , "set" : {"id" : "source.id" }, "condition" : "source.id=target.id" },
348+ {
349+ "clause" : "whenMatchedUpdate" ,
350+ "set" : {"id" : "source.id" },
351+ "condition" : "source.id=target.id" ,
352+ },
317353 {
318354 "clause" : "whenNotMatchedInsert" ,
319355 "values" : {"id" : "source.id" },
@@ -328,7 +364,9 @@ def test_merge_from_args(spark, dummy_df):
328364 with pytest .raises (SparkConnectDeltaTableException ) as exc_info :
329365 writer ._merge_builder_from_args ()
330366
331- assert str (exc_info .value ).startswith ("`DeltaTable.forName` is not supported due to delta calling _sc" )
367+ assert str (exc_info .value ).startswith (
368+ "`DeltaTable.forName` is not supported due to delta calling _sc"
369+ )
332370 else :
333371 writer ._merge_builder_from_args ()
334372 mock_delta_builder .whenMatchedUpdate .assert_called_once_with (
@@ -337,17 +375,21 @@ def test_merge_from_args(spark, dummy_df):
337375 mock_delta_builder .whenNotMatchedInsert .assert_called_once_with (
338376 values = {"id" : "source.id" }, condition = "source.id IS NOT NULL"
339377 )
340- assert ["clause" in c for c in writer .params ["merge_builder" ]] == [True ] * len (
341- writer . params [ "merge_builder" ]
342- )
378+ assert ["clause" in c for c in writer .params ["merge_builder" ]] == [
379+ True
380+ ] * len ( writer . params [ "merge_builder" ] )
343381
344382
345383@pytest .mark .parametrize (
346384 "output_mode_params" ,
347385 [
348386 {
349387 "merge_builder" : [
350- {"clause" : "NOT-SUPPORTED-MERGE-CLAUSE" , "set" : {"id" : "source.id" }, "condition" : "source.id=target.id" }
388+ {
389+ "clause" : "NOT-SUPPORTED-MERGE-CLAUSE" ,
390+ "set" : {"id" : "source.id" },
391+ "condition" : "source.id=target.id" ,
392+ }
351393 ],
352394 "merge_cond" : "source.id=target.id" ,
353395 },
@@ -368,7 +410,11 @@ def test_merge_no_table(spark):
368410
369411 table_name = "test_merge_no_table"
370412 target_df = spark .createDataFrame (
371- [{"id" : 1 , "value" : "no_merge" }, {"id" : 2 , "value" : "expected_merge" }, {"id" : 5 , "value" : "expected_merge" }]
413+ [
414+ {"id" : 1 , "value" : "no_merge" },
415+ {"id" : 2 , "value" : "expected_merge" },
416+ {"id" : 5 , "value" : "expected_merge" },
417+ ]
372418 )
373419 source_df = spark .createDataFrame (
374420 [
@@ -402,18 +448,26 @@ def test_merge_no_table(spark):
402448 "merge_cond" : "source.id=target.id" ,
403449 }
404450 writer1 = DeltaTableWriter (
405- df = target_df , table = table_name , output_mode = BatchOutputMode .MERGE , output_mode_params = params
451+ df = target_df ,
452+ table = table_name ,
453+ output_mode = BatchOutputMode .MERGE ,
454+ output_mode_params = params ,
406455 )
407456 writer2 = DeltaTableWriter (
408- df = source_df , table = table_name , output_mode = BatchOutputMode .MERGE , output_mode_params = params
457+ df = source_df ,
458+ table = table_name ,
459+ output_mode = BatchOutputMode .MERGE ,
460+ output_mode_params = params ,
409461 )
410462 if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session ():
411463 writer1 .execute ()
412464
413465 with pytest .raises (SparkConnectDeltaTableException ) as exc_info :
414466 writer2 .execute ()
415467
416- assert str (exc_info .value ).startswith ("`DeltaTable.forName` is not supported due to delta calling _sc" )
468+ assert str (exc_info .value ).startswith (
469+ "`DeltaTable.forName` is not supported due to delta calling _sc"
470+ )
417471 else :
418472 writer1 .execute ()
419473 writer2 .execute ()
@@ -435,10 +489,14 @@ def test_log_clauses(mocker):
435489
436490 mock_clause = mocker .MagicMock ()
437491 mock_clause .clauseType .return_value = "test"
438- mock_clause .actions .return_value .toList .return_value .apply .return_value .toString .return_value = "test_column"
492+ mock_clause .actions .return_value .toList .return_value .apply .return_value .toString .return_value = (
493+ "test_column"
494+ )
439495
440496 mock_condition = mocker .MagicMock ()
441- mock_condition .value .return_value .toString .return_value = "source_alias == target_alias"
497+ mock_condition .value .return_value .toString .return_value = (
498+ "source_alias == target_alias"
499+ )
442500 mock_condition .toString .return_value = "None"
443501 mock_clause .condition .return_value = mock_condition
444502
@@ -448,4 +506,62 @@ def test_log_clauses(mocker):
448506 result = log_clauses (mock_clauses , "source_alias" , "target_alias" )
449507
450508 # Assert the result
451- assert result == "Test will perform action:Test columns (test_column) if `source_alias == target_alias`"
509+ assert (
510+ result
511+ == "Test will perform action:Test columns (test_column) if `source_alias == target_alias`"
512+ )
513+
514+
def test_merge_builder_type__list_of_merge_builders(mocker, spark):
    """A list-typed ``merge_builder`` param must pass writer validation.

    The writer accepts either a ``DeltaMergeBuilder`` instance or a list of
    merge-clause dicts; here a mock with ``spec=list`` stands in for the latter.
    Constructing the writer is the assertion — no ``ValueError`` may be raised.
    """
    table_name = "test_merge_builder_type"
    df = spark.createDataFrame([{"id": 1, "value": "test"}])
    # Mock standing in for a list of merge-builder clause dicts.
    builder_list = mocker.MagicMock(spec=list)
    # No ValueError should be raised on construction.
    DeltaTableWriter(
        df=df,
        table=table_name,
        output_mode=BatchOutputMode.MERGE,
        output_mode_params={"merge_builder": builder_list},
    )
526+
527+
def test_merge_builder_type___delta_merge_builder(mocker, spark):
    """A genuine ``delta.tables.DeltaMergeBuilder`` must pass writer validation.

    Uses ``spec=DeltaMergeBuilder`` so ``isinstance`` checks inside the writer
    succeed. Constructing the writer is the assertion — no ``ValueError`` may
    be raised.
    """
    table_name = "test_merge_builder_type"
    df = spark.createDataFrame([{"id": 1, "value": "test"}])
    # Mock passing isinstance(..., DeltaMergeBuilder) checks.
    delta_builder = mocker.MagicMock(spec=DeltaMergeBuilder)
    # No ValueError should be raised on construction.
    DeltaTableWriter(
        df=df,
        table=table_name,
        output_mode=BatchOutputMode.MERGE,
        output_mode_params={"merge_builder": delta_builder},
    )
539+
540+
def test_merge_builder_type__connect_delta_merge_builder(mocker, spark):
    """A Spark-Connect merge builder must pass writer validation by class name.

    ``delta.connect.tables.DeltaMergeBuilder`` is not a subclass of
    ``delta.tables.DeltaMergeBuilder``, so the writer falls back to matching
    on the class *name*. Simulate that with a plain mock whose class is
    renamed to ``"DeltaMergeBuilder"``. Construction must not raise.
    """
    table_name = "test_merge_builder_type"
    df = spark.createDataFrame([{"id": 1, "value": "test"}])
    # Not a delta.tables.DeltaMergeBuilder instance, but named like one
    # (as in delta.connect.tables.DeltaMergeBuilder).
    connect_builder = mocker.MagicMock()
    connect_builder.__class__.__name__ = "DeltaMergeBuilder"
    DeltaTableWriter(
        df=df,
        table=table_name,
        output_mode=BatchOutputMode.MERGE,
        output_mode_params={"merge_builder": connect_builder},
    )
553+
554+
def test_merge_builder_type__invalid_merge_builder(mocker, spark):
    """An object that is neither a builder nor a list must be rejected.

    A mock with ``spec=str`` fails every accepted-type check, so constructing
    the writer must raise ``ValueError``.
    """
    table_name = "test_merge_builder_type"
    df = spark.createDataFrame([{"id": 1, "value": "test"}])
    # Neither a DeltaMergeBuilder instance nor a list of clause dicts.
    bad_builder = mocker.MagicMock(spec=str)
    with pytest.raises(ValueError):
        DeltaTableWriter(
            df=df,
            table=table_name,
            output_mode=BatchOutputMode.MERGE,
            output_mode_params={"merge_builder": bad_builder},
        )
0 commit comments