Commit 62b527e

Sanitize special character column names when writing (#590)
* write with sanitized column names
* push down to when parquet writes
* add test for writing special character column name
* parameterize format_version
* use to_requested_schema
* refactor to_requested_schema
* more refactor
* test nested schema
* special character inside nested field
* comment on why arrow is enabled
* use existing variable
* move spark config to conftest
* pyspark arrow turns pandas df from tuple to dict
* Revert refactor to_requested_schema
* reorder args
* refactor
* pushdown schema
* only transform when necessary
1 parent 2ee2d19 commit 62b527e
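
In short: just before writing Parquet, the Iceberg schema's column names (including names nested inside structs) are sanitized, and the Arrow data is re-projected to match, but only when sanitization actually changes a name. A minimal sketch of that check on its own, using pyiceberg's schema types and the sanitize_column_names helper that the diff below imports (the field names and IDs here are illustrative):

from pyiceberg.schema import Schema, sanitize_column_names
from pyiceberg.types import NestedField, StringType, StructType

# A schema with a special-character name at the top level and inside a struct,
# mirroring the test case added in this commit.
schema = Schema(
    NestedField(field_id=1, name="letter/abc", field_type=StringType(), required=False),
    NestedField(
        field_id=2,
        name="address",
        field_type=StructType(
            NestedField(field_id=3, name="street", field_type=StringType(), required=False),
            NestedField(field_id=4, name="letter/abc", field_type=StringType(), required=False),
        ),
        required=False,
    ),
)

sanitized = sanitize_column_names(schema)
# Only re-project the data when a name actually changed.
needs_transform = sanitized != schema
print(needs_transform, sanitized)

Because Iceberg resolves columns by field ID, the table schema keeps the original names; only the physical Parquet schema carries the sanitized ones.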

File tree

4 files changed (+35, -13 lines):

* pyiceberg/io/pyarrow.py
* tests/conftest.py
* tests/integration/test_inspect_table.py
* tests/integration/test_writes/test_writes.py

pyiceberg/io/pyarrow.py

Lines changed: 15 additions & 9 deletions
@@ -122,6 +122,7 @@
     pre_order_visit,
     promote,
     prune_columns,
+    sanitize_column_names,
     visit,
     visit_with_partner,
 )
@@ -1016,7 +1017,6 @@ def _task_to_table(
 
     if len(arrow_table) < 1:
         return None
-
     return to_requested_schema(projected_schema, file_project_schema, arrow_table)
 
 
@@ -1769,27 +1769,33 @@ def data_file_statistics_from_parquet_metadata(
 
 
 def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
-    schema = table_metadata.schema()
-    arrow_file_schema = schema.as_arrow()
     parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
-
     row_group_size = PropertyUtil.property_as_int(
         properties=table_metadata.properties,
         property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
         default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
     )
 
     def write_parquet(task: WriteTask) -> DataFile:
+        table_schema = task.schema
+        arrow_table = pa.Table.from_batches(task.record_batches)
+        # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
+        # otherwise use the original schema
+        if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
+            file_schema = sanitized_schema
+            arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
+        else:
+            file_schema = table_schema
+
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
-                writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)
-
+            with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
+                writer.write(arrow_table, row_group_size=row_group_size)
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=writer.writer.metadata,
-            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
-            parquet_column_mapping=parquet_path_to_id_mapping(schema),
+            stats_columns=compute_statistics_plan(file_schema, table_metadata.properties),
+            parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
         )
         data_file = DataFile(
             content=DataFileContent.DATA,
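
Read in isolation, the gate added to write_parquet can be sketched as a standalone helper (the helper name is hypothetical; the calls and keyword arguments mirror the diff above, and to_requested_schema is an internal helper of pyiceberg.io.pyarrow imported here only for illustration):

from typing import Tuple

import pyarrow as pa

from pyiceberg.io.pyarrow import to_requested_schema
from pyiceberg.schema import Schema, sanitize_column_names


def _schema_and_table_for_write(table_schema: Schema, arrow_table: pa.Table) -> Tuple[Schema, pa.Table]:
    # Sanitize the names; if nothing changed, keep the original schema and
    # skip the re-projection of the Arrow table.
    if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
        # Rename top-level and nested columns in the Arrow data to match the
        # sanitized schema that the ParquetWriter will be given.
        return sanitized_schema, to_requested_schema(
            requested_schema=sanitized_schema, file_schema=table_schema, table=arrow_table
        )
    return table_schema, arrow_table

Note that the statistics plan and the Parquet column-to-ID mapping are also computed from the same file_schema, so the written footer, the collected stats, and the field-ID mapping all agree on the sanitized names.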

tests/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -2060,6 +2060,7 @@ def spark() -> "SparkSession":
         .config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
         .config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
         .config("spark.sql.catalog.hive.s3.path-style-access", "true")
+        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
         .getOrCreate()
     )
 
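The only change here is turning on Arrow-accelerated pandas conversion for the shared Spark fixture. Per the commit message, with Arrow enabled toPandas() hands struct columns back as plain Python dicts rather than Row objects, which is what the asDict(recursive=True) removals in tests/integration/test_inspect_table.py below rely on. A minimal standalone sketch (the app name is arbitrary):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("arrow-enabled")
    # Arrow-backed conversion: struct values come back as dicts in toPandas().
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)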

tests/integration/test_inspect_table.py

Lines changed: 0 additions & 3 deletions
@@ -171,7 +171,6 @@ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
     for column in df.column_names:
         for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
             if column == 'data_file':
-                right = right.asDict(recursive=True)
                 for df_column in left.keys():
                     if df_column == 'partition':
                         # Spark leaves out the partition if the table is unpartitioned
@@ -185,8 +184,6 @@ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
 
                     assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
             elif column == 'readable_metrics':
-                right = right.asDict(recursive=True)
-
                 assert list(left.keys()) == [
                     'bool',
                     'string',

tests/integration/test_writes/test_writes.py

Lines changed: 19 additions & 1 deletion
@@ -280,9 +280,27 @@ def test_python_writes_special_character_column_with_spark_reads(
     column_name_with_special_character = "letter/abc"
     TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
         column_name_with_special_character: ['a', None, 'z'],
+        'id': [1, 2, 3],
+        'name': ['AB', 'CD', 'EF'],
+        'address': [
+            {'street': '123', 'city': 'SFO', 'zip': 12345, column_name_with_special_character: 'a'},
+            {'street': '456', 'city': 'SW', 'zip': 67890, column_name_with_special_character: 'b'},
+            {'street': '789', 'city': 'Random', 'zip': 10112, column_name_with_special_character: 'c'},
+        ],
     }
     pa_schema = pa.schema([
-        (column_name_with_special_character, pa.string()),
+        pa.field(column_name_with_special_character, pa.string()),
+        pa.field('id', pa.int32()),
+        pa.field('name', pa.string()),
+        pa.field(
+            'address',
+            pa.struct([
+                pa.field('street', pa.string()),
+                pa.field('city', pa.string()),
+                pa.field('zip', pa.int32()),
+                pa.field(column_name_with_special_character, pa.string()),
+            ]),
+        ),
     ])
     arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
     tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)
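
Once written, Spark should see the original (unsanitized) names from the Iceberg table schema, both at the top level and inside the address struct, even though the Parquet files carry sanitized names. A rough sketch of the read-back side, assuming an already-configured Iceberg catalog and a table identifier like default.my_table (both assumptions; the test drives this through its fixtures):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # assumes the Iceberg catalog is already configured
# Backticks let Spark SQL reference the special-character column name directly.
spark.sql("SELECT `letter/abc`, address FROM default.my_table").show(truncate=False)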
