Skip to content

Commit b5168e1

Browse files
authored
bugfix - PR #179 - *Delta > Writer*: Fixed checking merge builder type in DeltaTableWriter
There was a small error in validating the merge builder type when using the `merge_builder` param with `DeltaTableWriter`. It required that the value passed was either a `list` or a `delta.tables.DeltaMergeBuilder` instance, so validation failed when an instance of `delta.connect.tables.DeltaMergeBuilder` was passed.

## Motivation and Context

One cannot pass a correct DeltaMergeBuilder instance as a merge builder for `DeltaTableWriter` if it is not a `delta.tables.DeltaMergeBuilder` instance or a `list`. Connect merge builders should also be supported.

## How Has This Been Tested?

4 dedicated test cases have been added to test the merge builder validation.
1 parent 163233b commit b5168e1

File tree

3 files changed

+155
-40
lines changed

3 files changed

+155
-40
lines changed

.github/workflows/test.yml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,18 @@ jobs:
3939
uses: actions/checkout@v4
4040
with:
4141
fetch-depth: 0
42-
ref: ${{ github.event.pull_request.head.ref }}
42+
ref: ${{ github.event.pull_request.head.sha }}
4343
repository: ${{ github.event.pull_request.head.repo.full_name }}
44-
- name: Fetch target branch
45-
run: git fetch origin ${{ github.event.pull_request.head.ref || 'main'}}:${{ github.event.pull_request.base.ref || 'main'}}
44+
4645
- name: Check changes
4746
id: check
4847
run: |
49-
# Set the base reference for the git diff
50-
BASE_REF=${{ github.event.pull_request.base.ref || 'main' }}
51-
52-
# Check for changes in this PR / commit
53-
git_diff_output=$(git diff --name-only $BASE_REF ${{ github.event.after }})
54-
48+
# Fetch the base branch
49+
git fetch origin ${{ github.event.pull_request.base.sha }}
50+
51+
# Get the diff between PR base and head
52+
git_diff_output=$(git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }})
53+
5554
# Count the number of changes to Python and TOML files
5655
python_changed=$(echo "$git_diff_output" | grep '\.py$' | wc -l)
5756
toml_changed=$(echo "$git_diff_output" | grep '\.toml$' | wc -l)

src/koheesio/spark/writers/delta/batch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,11 +328,11 @@ def _validate_params(cls, params: dict) -> dict:
328328
clause = merge_conf.get("clause")
329329
if clause not in valid_clauses:
330330
raise ValueError(f"Invalid merge clause '{clause}' provided")
331-
elif (
332-
not isinstance(merge_builder, DeltaMergeBuilder)
333-
or not type(merge_builder).__name__ == "DeltaMergeBuilder"
331+
elif not (
332+
isinstance(merge_builder, DeltaMergeBuilder)
333+
or type(merge_builder).__name__ == "DeltaMergeBuilder"
334334
):
335-
raise ValueError("merge_builder must be a list or merge clauses or a DeltaMergeBuilder instance")
335+
raise ValueError("merge_builder must be a list of merge clauses or a DeltaMergeBuilder instance")
336336

337337
return params
338338

tests/spark/writers/delta/test_delta_writer.py

Lines changed: 143 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import MagicMock, patch
33

44
from conftest import await_job_completion
5+
from delta.tables import DeltaMergeBuilder
56
import pytest
67

78
from pydantic import ValidationError
@@ -24,7 +25,9 @@
2425

2526
def test_delta_table_writer(dummy_df, spark):
2627
table_name = "test_table"
27-
writer = DeltaTableWriter(table=table_name, output_mode=BatchOutputMode.APPEND, df=dummy_df)
28+
writer = DeltaTableWriter(
29+
table=table_name, output_mode=BatchOutputMode.APPEND, df=dummy_df
30+
)
2831
writer.execute()
2932
actual_count = spark.read.table(table_name).count()
3033
assert actual_count == 1
@@ -44,7 +47,10 @@ def test_delta_table_writer(dummy_df, spark):
4447
def test_delta_partitioning(spark, sample_df_to_partition):
4548
table_name = "partition_table"
4649
DeltaTableWriter(
47-
table=table_name, output_mode=BatchOutputMode.OVERWRITE, df=sample_df_to_partition, partition_by=["partition"]
50+
table=table_name,
51+
output_mode=BatchOutputMode.OVERWRITE,
52+
df=sample_df_to_partition,
53+
partition_by=["partition"],
4854
).execute()
4955
output_df = spark.read.table(table_name)
5056
assert output_df.count() == 2
@@ -55,7 +61,11 @@ def test_delta_table_merge_all(spark):
5561

5662
table_name = "test_merge_all_table"
5763
target_df = spark.createDataFrame(
58-
[{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "xxxx"}]
64+
[
65+
{"id": 1, "value": "no_merge"},
66+
{"id": 2, "value": "expected_merge"},
67+
{"id": 5, "value": "xxxx"},
68+
]
5969
)
6070
source_df = spark.createDataFrame(
6171
[
@@ -73,7 +83,9 @@ def test_delta_table_merge_all(spark):
7383
# No merge as old value is greater
7484
5: "xxxx",
7585
}
76-
DeltaTableWriter(table=table_name, output_mode=BatchOutputMode.APPEND, df=target_df).execute()
86+
DeltaTableWriter(
87+
table=table_name, output_mode=BatchOutputMode.APPEND, df=target_df
88+
).execute()
7789
merge_writer = DeltaTableWriter(
7890
table=table_name,
7991
output_mode=BatchOutputMode.MERGEALL,
@@ -89,7 +101,9 @@ def test_delta_table_merge_all(spark):
89101
with pytest.raises(SparkConnectDeltaTableException) as exc_info:
90102
merge_writer.execute()
91103

92-
assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc")
104+
assert str(exc_info.value).startswith(
105+
"`DeltaTable.forName` is not supported due to delta calling _sc"
106+
)
93107
else:
94108
merge_writer.execute()
95109
result = {
@@ -107,12 +121,18 @@ def test_deltatablewriter_with_invalid_conditions(spark, dummy_df):
107121

108122
if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session():
109123
with pytest.raises(SparkConnectDeltaTableException) as exc_info:
110-
builder = get_delta_table_for_name(spark_session=spark, table_name=table_name)
124+
builder = get_delta_table_for_name(
125+
spark_session=spark, table_name=table_name
126+
)
111127

112-
assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc")
128+
assert str(exc_info.value).startswith(
129+
"`DeltaTable.forName` is not supported due to delta calling _sc"
130+
)
113131
else:
114132
with pytest.raises(AnalysisException):
115-
builder = get_delta_table_for_name(spark_session=spark, table_name=table_name)
133+
builder = get_delta_table_for_name(
134+
spark_session=spark, table_name=table_name
135+
)
116136
merge_builder = builder.alias("target").merge(
117137
condition="invalid_condition", source=dummy_df.alias("source")
118138
)
@@ -137,7 +157,10 @@ def test_delta_new_table_merge(spark):
137157
[
138158
{"id": 1, "value": "new_value"},
139159
{"id": 2, "value": "new_value"},
140-
{"id": 3, "value": None}, # Will be not saved, because source.value IS NOT NULL,
160+
{
161+
"id": 3,
162+
"value": None,
163+
}, # Will be not saved, because source.value IS NOT NULL,
141164
{"id": 4, "value": "new_value"},
142165
{"id": 5, "value": "new_value"},
143166
]
@@ -154,7 +177,8 @@ def test_delta_new_table_merge(spark):
154177
df=source_df,
155178
).execute()
156179
result = {
157-
list(row.asDict().values())[0]: list(row.asDict().values())[1] for row in spark.read.table(table_name).collect()
180+
list(row.asDict().values())[0]: list(row.asDict().values())[1]
181+
for row in spark.read.table(table_name).collect()
158182
}
159183
assert result == expected
160184

@@ -270,7 +294,9 @@ def test_delta_with_options(spark):
270294
"""
271295
sample_df = spark.createDataFrame([{"id": 1, "value": "test_value"}])
272296

273-
with patch("koheesio.spark.writers.delta.DeltaTableWriter.writer", new_callable=MagicMock) as mock_writer:
297+
with patch(
298+
"koheesio.spark.writers.delta.DeltaTableWriter.writer", new_callable=MagicMock
299+
) as mock_writer:
274300
delta_writer = DeltaTableWriter(
275301
table="test_table",
276302
output_mode=BatchOutputMode.APPEND,
@@ -279,9 +305,13 @@ def test_delta_with_options(spark):
279305
df=sample_df,
280306
)
281307
delta_writer.execute()
282-
mock_writer.options.assert_called_once_with(testParam1="testValue1", testParam2="testValue2")
308+
mock_writer.options.assert_called_once_with(
309+
testParam1="testValue1", testParam2="testValue2"
310+
)
283311

284-
with patch("koheesio.spark.writers.delta.DeltaTableWriter.writer", new_callable=MagicMock) as mock_writer:
312+
with patch(
313+
"koheesio.spark.writers.delta.DeltaTableWriter.writer", new_callable=MagicMock
314+
) as mock_writer:
285315
delta_writer = DeltaTableWriter(
286316
table="test_table",
287317
output_mode=BatchOutputMode.OVERWRITE,
@@ -290,7 +320,9 @@ def test_delta_with_options(spark):
290320
df=sample_df,
291321
)
292322
delta_writer.execute()
293-
mock_writer.options.assert_called_once_with(testParam1="testValue1", testParam2="testValue2")
323+
mock_writer.options.assert_called_once_with(
324+
testParam1="testValue1", testParam2="testValue2"
325+
)
294326

295327

296328
def test_merge_from_args(spark, dummy_df):
@@ -313,7 +345,11 @@ def test_merge_from_args(spark, dummy_df):
313345
output_mode=BatchOutputMode.MERGE,
314346
output_mode_params={
315347
"merge_builder": [
316-
{"clause": "whenMatchedUpdate", "set": {"id": "source.id"}, "condition": "source.id=target.id"},
348+
{
349+
"clause": "whenMatchedUpdate",
350+
"set": {"id": "source.id"},
351+
"condition": "source.id=target.id",
352+
},
317353
{
318354
"clause": "whenNotMatchedInsert",
319355
"values": {"id": "source.id"},
@@ -328,7 +364,9 @@ def test_merge_from_args(spark, dummy_df):
328364
with pytest.raises(SparkConnectDeltaTableException) as exc_info:
329365
writer._merge_builder_from_args()
330366

331-
assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc")
367+
assert str(exc_info.value).startswith(
368+
"`DeltaTable.forName` is not supported due to delta calling _sc"
369+
)
332370
else:
333371
writer._merge_builder_from_args()
334372
mock_delta_builder.whenMatchedUpdate.assert_called_once_with(
@@ -337,17 +375,21 @@ def test_merge_from_args(spark, dummy_df):
337375
mock_delta_builder.whenNotMatchedInsert.assert_called_once_with(
338376
values={"id": "source.id"}, condition="source.id IS NOT NULL"
339377
)
340-
assert ["clause" in c for c in writer.params["merge_builder"]] == [True] * len(
341-
writer.params["merge_builder"]
342-
)
378+
assert ["clause" in c for c in writer.params["merge_builder"]] == [
379+
True
380+
] * len(writer.params["merge_builder"])
343381

344382

345383
@pytest.mark.parametrize(
346384
"output_mode_params",
347385
[
348386
{
349387
"merge_builder": [
350-
{"clause": "NOT-SUPPORTED-MERGE-CLAUSE", "set": {"id": "source.id"}, "condition": "source.id=target.id"}
388+
{
389+
"clause": "NOT-SUPPORTED-MERGE-CLAUSE",
390+
"set": {"id": "source.id"},
391+
"condition": "source.id=target.id",
392+
}
351393
],
352394
"merge_cond": "source.id=target.id",
353395
},
@@ -368,7 +410,11 @@ def test_merge_no_table(spark):
368410

369411
table_name = "test_merge_no_table"
370412
target_df = spark.createDataFrame(
371-
[{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "expected_merge"}]
413+
[
414+
{"id": 1, "value": "no_merge"},
415+
{"id": 2, "value": "expected_merge"},
416+
{"id": 5, "value": "expected_merge"},
417+
]
372418
)
373419
source_df = spark.createDataFrame(
374420
[
@@ -402,18 +448,26 @@ def test_merge_no_table(spark):
402448
"merge_cond": "source.id=target.id",
403449
}
404450
writer1 = DeltaTableWriter(
405-
df=target_df, table=table_name, output_mode=BatchOutputMode.MERGE, output_mode_params=params
451+
df=target_df,
452+
table=table_name,
453+
output_mode=BatchOutputMode.MERGE,
454+
output_mode_params=params,
406455
)
407456
writer2 = DeltaTableWriter(
408-
df=source_df, table=table_name, output_mode=BatchOutputMode.MERGE, output_mode_params=params
457+
df=source_df,
458+
table=table_name,
459+
output_mode=BatchOutputMode.MERGE,
460+
output_mode_params=params,
409461
)
410462
if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session():
411463
writer1.execute()
412464

413465
with pytest.raises(SparkConnectDeltaTableException) as exc_info:
414466
writer2.execute()
415467

416-
assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc")
468+
assert str(exc_info.value).startswith(
469+
"`DeltaTable.forName` is not supported due to delta calling _sc"
470+
)
417471
else:
418472
writer1.execute()
419473
writer2.execute()
@@ -435,10 +489,14 @@ def test_log_clauses(mocker):
435489

436490
mock_clause = mocker.MagicMock()
437491
mock_clause.clauseType.return_value = "test"
438-
mock_clause.actions.return_value.toList.return_value.apply.return_value.toString.return_value = "test_column"
492+
mock_clause.actions.return_value.toList.return_value.apply.return_value.toString.return_value = (
493+
"test_column"
494+
)
439495

440496
mock_condition = mocker.MagicMock()
441-
mock_condition.value.return_value.toString.return_value = "source_alias == target_alias"
497+
mock_condition.value.return_value.toString.return_value = (
498+
"source_alias == target_alias"
499+
)
442500
mock_condition.toString.return_value = "None"
443501
mock_clause.condition.return_value = mock_condition
444502

@@ -448,4 +506,62 @@ def test_log_clauses(mocker):
448506
result = log_clauses(mock_clauses, "source_alias", "target_alias")
449507

450508
# Assert the result
451-
assert result == "Test will perform action:Test columns (test_column) if `source_alias == target_alias`"
509+
assert (
510+
result
511+
== "Test will perform action:Test columns (test_column) if `source_alias == target_alias`"
512+
)
513+
514+
515+
def test_merge_builder_type__list_of_merge_builders(mocker, spark):
516+
table_name = "test_merge_builder_type"
517+
df = spark.createDataFrame([{"id": 1, "value": "test"}])
518+
merge_builder = mocker.MagicMock(spec=list) # mocks a list of merge builders
519+
# No ValueError should be raised
520+
DeltaTableWriter(
521+
df=df,
522+
table=table_name,
523+
output_mode=BatchOutputMode.MERGE,
524+
output_mode_params={"merge_builder": merge_builder},
525+
)
526+
527+
528+
def test_merge_builder_type___delta_merge_builder(mocker, spark):
529+
table_name = "test_merge_builder_type"
530+
df = spark.createDataFrame([{"id": 1, "value": "test"}])
531+
merge_builder = mocker.MagicMock(spec=DeltaMergeBuilder)
532+
# No ValueError should be raised
533+
DeltaTableWriter(
534+
df=df,
535+
table=table_name,
536+
output_mode=BatchOutputMode.MERGE,
537+
output_mode_params={"merge_builder": merge_builder},
538+
)
539+
540+
541+
def test_merge_builder_type__connect_delta_merge_builder(mocker, spark):
542+
table_name = "test_merge_builder_type"
543+
df = spark.createDataFrame([{"id": 1, "value": "test"}])
544+
# Not a delta.tables.DeltaMergeBuilder instance but the name is DeltaMergeBuilder (as in delta.connect.tables.DeltaMergeBuilder)
545+
merge_builder = mocker.MagicMock()
546+
merge_builder.__class__.__name__ = "DeltaMergeBuilder"
547+
DeltaTableWriter(
548+
df=df,
549+
table=table_name,
550+
output_mode=BatchOutputMode.MERGE,
551+
output_mode_params={"merge_builder": merge_builder},
552+
)
553+
554+
555+
def test_merge_builder_type__invalid_merge_builder(mocker, spark):
556+
table_name = "test_merge_builder_type"
557+
df = spark.createDataFrame([{"id": 1, "value": "test"}])
558+
merge_builder = mocker.MagicMock(
559+
spec=str
560+
) # Not a DeltaMergeBuilder instance nor a list
561+
with pytest.raises(ValueError):
562+
DeltaTableWriter(
563+
df=df,
564+
table=table_name,
565+
output_mode=BatchOutputMode.MERGE,
566+
output_mode_params={"merge_builder": merge_builder},
567+
)

0 commit comments

Comments (0)