Skip to content

Commit 12802ad

Browse files
committed
Draft implementation done. Tests are pending, and the optimal approach still needs to be verified.
1 parent 14f3194 commit 12802ad

File tree

2 files changed

+77
-27
lines changed

2 files changed

+77
-27
lines changed

awswrangler/athena/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
_logger: logging.Logger = logging.getLogger(__name__)
4545

4646
class _MergeClause(TypedDict, total=False):
47-
when: Literal["MATCHED", "NOT_MATCHED", "NOT_MATCHED_BY_SOURCE"]
47+
when: Literal["MATCHED", "NOT MATCHED", "NOT MATCHED BY SOURCE"]
4848
condition: str | None
4949
action: Literal["UPDATE", "DELETE", "INSERT"]
5050
columns: list[str] | None

awswrangler/athena/_write_iceberg.py

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,10 @@ def _validate_args(
220220
mode: Literal["append", "overwrite", "overwrite_partitions"],
221221
partition_cols: list[str] | None,
222222
merge_cols: list[str] | None,
223-
merge_on_condition: str | None,
223+
merge_on_clause: str | None,
224224
merge_condition: Literal["update", "ignore", "conditional_merge"],
225225
merge_conditional_clauses: list[_MergeClause] | None,
226+
merge_match_nulls: bool,
226227
) -> None:
227228
if df.empty is True:
228229
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
@@ -232,17 +233,22 @@ def _validate_args(
232233
"Either path or workgroup path must be specified to store the temporary results."
233234
)
234235

235-
if merge_cols and merge_on_condition:
236+
if merge_cols and merge_on_clause:
236237
raise exceptions.InvalidArgumentCombination(
237-
"Cannot specify both merge_cols and merge_on_condition. Use either merge_cols for simple equality matching or merge_on_condition for custom logic."
238+
"Cannot specify both merge_cols and merge_on_clause. Use either merge_cols for simple equality matching or merge_on_clause for custom logic."
239+
)
240+
241+
if merge_on_clause and merge_match_nulls:
242+
raise exceptions.InvalidArgumentCombination(
243+
"merge_match_nulls can only be used together with merge_cols."
238244
)
239245

240246
if merge_conditional_clauses and merge_condition != "conditional_merge":
241247
raise exceptions.InvalidArgumentCombination(
242248
"merge_conditional_clauses can only be used when merge_condition is 'conditional_merge'."
243249
)
244250

245-
if (merge_cols or merge_on_condition) and merge_condition not in ["update", "ignore", "conditional_merge"]:
251+
if (merge_cols or merge_on_clause) and merge_condition not in ["update", "ignore", "conditional_merge"]:
246252
raise exceptions.InvalidArgumentValue(
247253
f"Invalid merge_condition: {merge_condition}. Valid values: ['update', 'ignore', 'conditional_merge']"
248254
)
@@ -262,9 +268,9 @@ def _validate_args(
262268
raise exceptions.InvalidArgumentValue(
263269
f"merge_conditional_clauses[{i}] must contain 'action' field."
264270
)
265-
if clause["when"] not in ['MATCHED', 'NOT_MATCHED', 'NOT_MATCHED_BY_SOURCE']:
271+
if clause["when"] not in ['MATCHED', 'NOT MATCHED', 'NOT MATCHED BY SOURCE']:
266272
raise exceptions.InvalidArgumentValue(
267-
f"merge_conditional_clauses[{i}]['when'] must be one of ['MATCHED', 'NOT_MATCHED', 'NOT_MATCHED_BY_SOURCE']."
273+
f"merge_conditional_clauses[{i}]['when'] must be one of ['MATCHED', 'NOT MATCHED', 'NOT MATCHED BY SOURCE']."
268274
)
269275
if clause["action"] not in ["UPDATE", "DELETE", "INSERT", "IGNORE"]:
270276
raise exceptions.InvalidArgumentValue(
@@ -287,7 +293,9 @@ def _merge_iceberg(
287293
table: str,
288294
source_table: str,
289295
merge_cols: list[str] | None = None,
290-
merge_condition: Literal["update", "ignore"] = "update",
296+
merge_on_clause: str | None = None,
297+
merge_condition: Literal["update", "ignore", "conditional_merge"] = "update",
298+
merge_conditional_clauses: list[_MergeClause] | None = None,
291299
merge_match_nulls: bool = False,
292300
kms_key: str | None = None,
293301
boto3_session: boto3.Session | None = None,
@@ -342,27 +350,66 @@ def _merge_iceberg(
342350
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
343351

344352
sql_statement: str
345-
if merge_cols:
346-
if merge_condition == "update":
347-
match_condition = f"""WHEN MATCHED THEN
348-
UPDATE SET {", ".join([f'"{x}" = source."{x}"' for x in df.columns])}"""
349-
else:
350-
match_condition = ""
351-
352-
if merge_match_nulls:
353-
merge_conditions = [f'(target."{x}" IS NOT DISTINCT FROM source."{x}")' for x in merge_cols]
353+
if merge_cols or merge_on_clause:
354+
if merge_on_clause:
355+
on_condition = merge_on_clause
354356
else:
355-
merge_conditions = [f'(target."{x}" = source."{x}")' for x in merge_cols]
356-
357+
if merge_match_nulls:
358+
merge_conditions = [f'(target."{x}" IS NOT DISTINCT FROM source."{x}")' for x in merge_cols]
359+
else:
360+
merge_conditions = [f'(target."{x}" = source."{x}")' for x in merge_cols]
361+
on_condition = " AND ".join(merge_conditions)
362+
363+
# Build WHEN clauses based on merge_condition
364+
when_clauses = []
365+
366+
if merge_condition == "update":
367+
when_clauses.append(f"""WHEN MATCHED THEN
368+
UPDATE SET {", ".join([f'"{x}" = source."{x}"' for x in df.columns])}""")
369+
when_clauses.append(f"""WHEN NOT MATCHED THEN
370+
INSERT ({", ".join([f'"{x}"' for x in df.columns])})
371+
VALUES ({", ".join([f'source."{x}"' for x in df.columns])})""")
372+
373+
elif merge_condition == "ignore":
374+
when_clauses.append(f"""WHEN NOT MATCHED THEN
375+
INSERT ({", ".join([f'"{x}"' for x in df.columns])})
376+
VALUES ({", ".join([f'source."{x}"' for x in df.columns])})""")
377+
378+
elif merge_condition == "conditional_merge":
379+
for clause in merge_conditional_clauses:
380+
when_type = clause["when"]
381+
action = clause["action"]
382+
condition = clause.get("condition")
383+
columns = clause.get("columns")
384+
385+
# Build WHEN clause
386+
when_part = f"WHEN {when_type}"
387+
if condition:
388+
when_part += f" AND {condition}"
389+
390+
# Build action
391+
if action == "UPDATE":
392+
update_columns = columns or df.columns.tolist()
393+
update_sets = [f'"{col}" = source."{col}"' for col in update_columns]
394+
when_part += f" THEN UPDATE SET {', '.join(update_sets)}"
395+
396+
elif action == "DELETE":
397+
when_part += " THEN DELETE"
398+
399+
elif action == "INSERT":
400+
insert_columns = columns or df.columns.tolist()
401+
column_list = ", ".join([f'"{col}"' for col in insert_columns])
402+
values_list = ", ".join([f'source."{col}"' for col in insert_columns])
403+
when_part += f" THEN INSERT ({column_list}) VALUES ({values_list})"
404+
405+
when_clauses.append(when_part)
406+
357407
sql_statement = f"""
358408
MERGE INTO "{database}"."{table}" target
359409
USING "{database}"."{source_table}" source
360-
ON {" AND ".join(merge_conditions)}
361-
{match_condition}
362-
WHEN NOT MATCHED THEN
363-
INSERT ({", ".join([f'"{x}"' for x in df.columns])})
364-
VALUES ({", ".join([f'source."{x}"' for x in df.columns])})
365-
"""
410+
ON {on_condition}
411+
{"\n ".join(when_clauses)}
412+
"""
366413
else:
367414
sql_statement = f"""
368415
INSERT INTO "{database}"."{table}" ({", ".join([f'"{x}"' for x in df.columns])})
@@ -397,7 +444,7 @@ def to_iceberg( # noqa: PLR0913
397444
table_location: str | None = None,
398445
partition_cols: list[str] | None = None,
399446
merge_cols: list[str] | None = None,
400-
merge_on_condition: str | None = None,
447+
merge_on_clause: str | None = None,
401448
merge_condition: Literal["update", "ignore", "conditional_merge"] = "update",
402449
merge_conditional_clauses: list[_MergeClause] | None = None,
403450
merge_match_nulls: bool = False,
@@ -536,9 +583,10 @@ def to_iceberg( # noqa: PLR0913
536583
mode=mode,
537584
partition_cols=partition_cols,
538585
merge_cols=merge_cols,
539-
merge_on_condition=merge_on_condition,
586+
merge_on_clause=merge_on_clause,
540587
merge_condition=merge_condition,
541588
merge_conditional_clauses=merge_conditional_clauses,
589+
merge_match_nulls=merge_match_nulls,
542590
)
543591

544592
glue_table_settings = glue_table_settings if glue_table_settings else {}
@@ -661,7 +709,9 @@ def to_iceberg( # noqa: PLR0913
661709
table=table,
662710
source_table=temp_table,
663711
merge_cols=merge_cols,
712+
merge_on_clause=merge_on_clause,
664713
merge_condition=merge_condition,
714+
merge_conditional_clauses=merge_conditional_clauses,
665715
merge_match_nulls=merge_match_nulls,
666716
kms_key=kms_key,
667717
boto3_session=boto3_session,

0 commit comments

Comments
 (0)