Skip to content

Commit 831a72a

Browse files
feat: Add schema evolution to athena.to_iceberg (#2465)
1 parent 0d1eede commit 831a72a

File tree

2 files changed

+267
-2
lines changed

2 files changed

+267
-2
lines changed

awswrangler/athena/_write_iceberg.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Amazon Athena Module containing all to_* write functions."""
22

33
import logging
4+
import typing
45
import uuid
5-
from typing import Any, Dict, List, Optional
6+
from typing import Any, Dict, List, Optional, Set, TypedDict
67

78
import boto3
89
import pandas as pd
910

10-
from awswrangler import _utils, catalog, exceptions, s3
11+
from awswrangler import _data_types, _utils, catalog, exceptions, s3
1112
from awswrangler._config import apply_configs
1213
from awswrangler.athena._executions import wait_query
1314
from awswrangler.athena._utils import (
@@ -67,6 +68,111 @@ def _create_iceberg_table(
6768
wait_query(query_execution_id=query_execution_id, boto3_session=boto3_session)
6869

6970

71+
class _SchemaChanges(TypedDict):
    """Differences between an incoming DataFrame schema and the existing catalog table schema."""

    # Columns present in the DataFrame but not in the table, mapped to their Athena type.
    to_add: Dict[str, str]
    # Columns present in both, but whose Athena type differs, mapped to the new type.
    to_change: Dict[str, str]
    # Columns present in the table but missing from the DataFrame.
    to_remove: Set[str]
75+
76+
77+
def _determine_differences(
    df: pd.DataFrame,
    database: str,
    table: str,
    index: bool,
    partition_cols: Optional[List[str]],
    boto3_session: Optional[boto3.Session],
    dtype: Optional[Dict[str, str]],
    catalog_id: Optional[str],
) -> _SchemaChanges:
    """Compare the DataFrame schema against the Glue catalog table and report the differences."""
    columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
        df=df, index=index, partition_cols=partition_cols, dtype=dtype
    )
    # Treat data columns and partition columns as one flat schema for the comparison.
    columns_types.update(partitions_types)

    table_types = typing.cast(
        Dict[str, str],
        catalog.get_table_types(database=database, table=table, catalog_id=catalog_id, boto3_session=boto3_session),
    )

    existing = set(table_types)
    incoming = set(columns_types)

    return _SchemaChanges(
        to_add={name: columns_types[name] for name in incoming - existing},
        to_change={
            name: columns_types[name]
            for name in existing & incoming
            if columns_types[name] != table_types[name]
        },
        to_remove=existing - incoming,
    )
111+
112+
113+
def _alter_iceberg_table(
    database: str,
    table: str,
    schema_changes: _SchemaChanges,
    wg_config: _WorkGroupConfig,
    data_source: Optional[str] = None,
    workgroup: Optional[str] = None,
    encryption: Optional[str] = None,
    kms_key: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> None:
    """Apply the given schema changes to an Iceberg table via ALTER TABLE statements in Athena."""
    statements: List[str] = []

    if schema_changes["to_add"]:
        statements.extend(
            _alter_iceberg_table_add_columns_sql(table=table, columns_to_add=schema_changes["to_add"])
        )

    if schema_changes["to_change"]:
        statements.extend(
            _alter_iceberg_table_change_columns_sql(table=table, columns_to_change=schema_changes["to_change"])
        )

    if schema_changes["to_remove"]:
        # No statements have been executed yet, so rejecting column removal here is safe.
        raise exceptions.InvalidArgumentCombination("Removing columns of Iceberg tables is not currently supported.")

    # Run each ALTER statement sequentially, waiting for completion before starting the next.
    for sql in statements:
        execution_id: str = _start_query_execution(
            sql=sql,
            workgroup=workgroup,
            wg_config=wg_config,
            database=database,
            data_source=data_source,
            encryption=encryption,
            kms_key=kms_key,
            boto3_session=boto3_session,
        )
        wait_query(query_execution_id=execution_id, boto3_session=boto3_session)
153+
154+
155+
def _alter_iceberg_table_add_columns_sql(
156+
table: str,
157+
columns_to_add: Dict[str, str],
158+
) -> List[str]:
159+
add_cols_str = ", ".join([f"{col_name} {columns_to_add[col_name]}" for col_name in columns_to_add])
160+
161+
return [f"ALTER TABLE {table} ADD COLUMNS ({add_cols_str})"]
162+
163+
164+
def _alter_iceberg_table_change_columns_sql(
165+
table: str,
166+
columns_to_change: Dict[str, str],
167+
) -> List[str]:
168+
sql_statements = []
169+
170+
for col_name, col_type in columns_to_change.items():
171+
sql_statements.append(f"ALTER TABLE {table} CHANGE COLUMN {col_name} {col_name} {col_type}")
172+
173+
return sql_statements
174+
175+
70176
@apply_configs
71177
@_utils.validate_distributed_kwargs(
72178
unsupported_kwargs=["boto3_session", "s3_additional_kwargs"],
@@ -89,6 +195,7 @@ def to_iceberg(
89195
additional_table_properties: Optional[Dict[str, Any]] = None,
90196
dtype: Optional[Dict[str, str]] = None,
91197
catalog_id: Optional[str] = None,
198+
schema_evolution: bool = False,
92199
) -> None:
93200
"""
94201
Insert into Athena Iceberg table using INSERT INTO ... SELECT. Will create Iceberg table if it does not exist.
@@ -143,6 +250,8 @@ def to_iceberg(
143250
catalog_id : str, optional
144251
The ID of the Data Catalog from which to retrieve Databases.
145252
If none is provided, the AWS account ID is used by default
253+
schema_evolution: bool
254+
If True allows schema evolution for new columns or changes in column types.
146255
147256
Returns
148257
-------
@@ -206,6 +315,31 @@ def to_iceberg(
206315
boto3_session=boto3_session,
207316
dtype=dtype,
208317
)
318+
else:
319+
schema_differences = _determine_differences(
320+
df=df,
321+
database=database,
322+
table=table,
323+
index=index,
324+
partition_cols=partition_cols,
325+
boto3_session=boto3_session,
326+
dtype=dtype,
327+
catalog_id=catalog_id,
328+
)
329+
if schema_evolution is False and any([schema_differences[x] for x in schema_differences]): # type: ignore[literal-required]
330+
raise exceptions.InvalidArgumentValue(f"Schema change detected: {schema_differences}")
331+
332+
_alter_iceberg_table(
333+
database=database,
334+
table=table,
335+
schema_changes=schema_differences,
336+
wg_config=wg_config,
337+
data_source=data_source,
338+
workgroup=workgroup,
339+
encryption=encryption,
340+
kms_key=kms_key,
341+
boto3_session=boto3_session,
342+
)
209343

210344
# Create temporary external table, write the results
211345
s3.to_parquet(

tests/unit/test_athena.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,137 @@ def test_athena_to_iceberg(path, path2, glue_database, glue_table, partition_col
15251525
assert df.equals(df_out)
15261526

15271527

1528+
def test_athena_to_iceberg_schema_evolution_add_columns(
    path: str, path2: str, glue_database: str, glue_table: str
) -> None:
    """New columns are appended when schema_evolution=True; schema drift raises when it is False."""
    frame = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
    wr.athena.to_iceberg(
        df=frame,
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
        schema_evolution=True,
    )

    # Second write carries an extra column; evolution is enabled, so it should succeed.
    frame["c2"] = [6, 7, 8]
    wr.athena.to_iceberg(
        df=frame,
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
        schema_evolution=True,
    )

    column_types = wr.catalog.get_table_types(glue_database, glue_table)
    assert len(column_types) == len(frame.columns)

    result = wr.athena.read_sql_table(
        table=glue_table,
        database=glue_database,
        ctas_approach=False,
        unload_approach=False,
    )
    # Both writes appended rows (INSERT INTO semantics).
    assert len(result) == len(frame) * 2

    # With schema_evolution=False, any detected schema difference must be rejected.
    frame["c3"] = [9, 10, 11]
    with pytest.raises(wr.exceptions.InvalidArgumentValue):
        wr.athena.to_iceberg(
            df=frame,
            database=glue_database,
            table=glue_table,
            table_location=path,
            temp_path=path2,
            keep_files=False,
            schema_evolution=False,
        )
1575+
1576+
1577+
def test_athena_to_iceberg_schema_evolution_modify_columns(
    path: str, path2: str, glue_database: str, glue_table: str
) -> None:
    """Writing a frame with widened column types evolves the Iceberg table schema in place."""
    # Version 1: narrow types (float32 / int32).
    frame_v1 = pd.DataFrame({"c1": pd.Series([1.0, 2.0], dtype="float32"), "c2": pd.Series([-1, -2], dtype="int32")})

    wr.athena.to_iceberg(
        df=frame_v1,
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
        schema_evolution=True,
    )

    result_v1 = wr.athena.read_sql_table(
        table=glue_table,
        database=glue_database,
        ctas_approach=False,
        unload_approach=False,
    )

    assert len(result_v1) == 2
    assert len(result_v1.columns) == 2
    assert str(result_v1["c1"].dtype).startswith("float32")
    assert str(result_v1["c2"].dtype).startswith("Int32")

    # Version 2: same columns widened to float64 / int64.
    frame_v2 = pd.DataFrame({"c1": pd.Series([3.0, 4.0], dtype="float64"), "c2": pd.Series([-3, -4], dtype="int64")})

    wr.athena.to_iceberg(
        df=frame_v2,
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
        schema_evolution=True,
    )

    result_v2 = wr.athena.read_sql_table(
        table=glue_table,
        database=glue_database,
        ctas_approach=False,
        unload_approach=False,
    )

    # Rows from both writes are present and the table now reports the widened types.
    assert len(result_v2) == 4
    assert len(result_v2.columns) == 2
    assert str(result_v2["c1"].dtype).startswith("float64")
    assert str(result_v2["c2"].dtype).startswith("Int64")
1629+
1630+
1631+
def test_athena_to_iceberg_schema_evolution_remove_columns_error(
    path: str, path2: str, glue_database: str, glue_table: str
) -> None:
    """Dropping a column is unsupported and must raise even with schema_evolution=True."""
    wr.athena.to_iceberg(
        df=pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]}),
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
        schema_evolution=True,
    )

    # Second frame is missing "c1"; removal is not supported for Iceberg tables.
    narrower = pd.DataFrame({"c0": [6, 7, 8]})

    with pytest.raises(wr.exceptions.InvalidArgumentCombination):
        wr.athena.to_iceberg(
            df=narrower,
            database=glue_database,
            table=glue_table,
            table_location=path,
            temp_path=path2,
            keep_files=False,
            schema_evolution=True,
        )
1657+
1658+
15281659
def test_to_iceberg_cast(path, path2, glue_table, glue_database):
15291660
df = pd.DataFrame(
15301661
{

0 commit comments

Comments
 (0)