Skip to content

Commit e92bee0

Browse files
committed
Handle nested field names when sanitizing tables in ParquetWriter with flavor='spark'
1 parent eb73bc6 commit e92bee0

File tree

2 files changed

+182
-24
lines changed

2 files changed

+182
-24
lines changed

python/pyarrow/parquet/core.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -715,32 +715,77 @@ def _sanitized_spark_field_name(name):
715715
return _SPARK_DISALLOWED_CHARS.sub('_', name)
716716

717717

718-
def _sanitize_schema(schema, flavor):
719-
if 'spark' in flavor:
720-
sanitized_fields = []
718+
def _sanitize_field_recursive(field):
    """
    Sanitize a field name, and every nested field name, for Spark.

    Child fields of struct, list, large_list, fixed_size_list and map
    types are visited recursively; any other type is left untouched.

    Parameters
    ----------
    field : pyarrow.Field
        Field whose name (and nested field names) should be sanitized.

    Returns
    -------
    tuple
        (sanitized_field, changed). ``changed`` is True if the name of the
        field or of any nested field was rewritten; otherwise the input
        field object itself is returned together with False.
    """
    new_name = _sanitized_spark_field_name(field.name)
    dtype = field.type
    new_type = dtype
    nested_changed = False

    if pa.types.is_struct(dtype):
        results = [_sanitize_field_recursive(child) for child in dtype]
        nested_changed = any(flag for _, flag in results)
        if nested_changed:
            new_type = pa.struct([child for child, _ in results])
    elif pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
        new_value, nested_changed = _sanitize_field_recursive(
            dtype.value_field)
        if nested_changed:
            # Rebuild with the constructor matching the original list kind.
            make_list = (pa.list_ if pa.types.is_list(dtype)
                         else pa.large_list)
            new_type = make_list(new_value)
    elif pa.types.is_fixed_size_list(dtype):
        new_value, nested_changed = _sanitize_field_recursive(
            dtype.value_field)
        if nested_changed:
            # Passing a list size to pa.list_ yields a fixed_size_list type.
            new_type = pa.list_(new_value, dtype.list_size)
    elif pa.types.is_map(dtype):
        # Both sides of the map may carry structs with offending names.
        new_key, key_changed = _sanitize_field_recursive(dtype.key_field)
        new_item, item_changed = _sanitize_field_recursive(dtype.item_field)
        nested_changed = key_changed or item_changed
        if nested_changed:
            new_type = pa.map_(new_key, new_item,
                               keys_sorted=dtype.keys_sorted)

    if new_name == field.name and not nested_changed:
        # Nothing to rewrite: hand back the very same field object.
        return field, False
    return pa.field(new_name, new_type, field.nullable, field.metadata), True
723770

724-
for field in schema:
725-
name = field.name
726-
sanitized_name = _sanitized_spark_field_name(name)
727771

728-
if sanitized_name != name:
729-
schema_changed = True
730-
sanitized_field = pa.field(sanitized_name, field.type,
731-
field.nullable, field.metadata)
732-
sanitized_fields.append(sanitized_field)
733-
else:
734-
sanitized_fields.append(field)
735-
736-
new_schema = pa.schema(sanitized_fields, metadata=schema.metadata)
737-
return new_schema, schema_changed
738-
else:
772+
def _sanitize_schema(schema, flavor):
    """
    Build a schema whose field names are acceptable to Spark.

    Parameters
    ----------
    schema : pyarrow.Schema
        Schema to sanitize.
    flavor : str or container of str
        Sanitization only happens when ``'spark' in flavor`` is true.

    Returns
    -------
    tuple
        (schema, changed). Without the spark flavor the input schema is
        returned unchanged with False. Otherwise a new schema is built from
        the recursively sanitized fields (schema metadata is carried over)
        and ``changed`` tells whether any field name was rewritten.
    """
    if 'spark' not in flavor:
        return schema, False

    results = [_sanitize_field_recursive(entry) for entry in schema]
    fields = [entry for entry, _ in results]
    any_changed = any(flag for _, flag in results)

    return pa.schema(fields, metadata=schema.metadata), any_changed
786+
741787

742788
def _sanitize_table(table, new_schema, flavor):
743-
# TODO: This will not handle prohibited characters in nested field names
744789
if 'spark' in flavor:
745790
column_data = [table[i] for i in range(table.num_columns)]
746791
return pa.Table.from_arrays(column_data, schema=new_schema)

python/pyarrow/tests/parquet/test_basic.py

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -613,14 +613,127 @@ def test_compression_level():
613613

614614

615615
def test_sanitized_spark_field_names():
    # Writing with flavor='spark' must sanitize top-level and struct-nested
    # field names while keeping field and schema metadata intact.
    field_meta = {b'key': b'value'}
    schema_meta = {b'schema_key': b'schema_value'}

    inner_struct = pa.struct([
        pa.field('field(parens)', pa.float64()),
        pa.field('normal_field', pa.bool_()),
    ])
    outer_struct = pa.struct([
        pa.field('field,comma', pa.int32()),
        pa.field('deeply{nested}', inner_struct),
    ])
    schema = pa.schema(
        [
            pa.field('prohib; ,\t{}', pa.int32()),
            pa.field('field=with\nspecial', pa.string(),
                     metadata=field_meta),
            pa.field('nested_struct', outer_struct),
        ],
        metadata=schema_meta,
    )

    struct_rows = [
        {'field,comma': 10,
         'deeply{nested}': {'field(parens)': 1.5, 'normal_field': True}},
        {'field,comma': 20,
         'deeply{nested}': {'field(parens)': 2.5, 'normal_field': False}},
    ]
    columns = [
        pa.array([1, 2]),
        pa.array(['a', 'b']),
        pa.array(struct_rows, type=schema[2].type),
    ]
    table = pa.Table.from_arrays(columns, schema=schema)

    result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'})
    out_schema = result.schema

    # Top-level names.
    assert out_schema[0].name == 'prohib______'
    assert out_schema[1].name == 'field_with_special'

    # First nesting level.
    outer_result = out_schema[2].type
    assert outer_result[0].name == 'field_comma'
    assert outer_result[1].name == 'deeply_nested_'

    # Second nesting level; an already-valid name stays as it is.
    inner_result = outer_result[1].type
    assert inner_result[0].name == 'field_parens_'
    assert inner_result[1].name == 'normal_field'

    # Metadata and row count survive the roundtrip.
    assert out_schema[1].metadata == field_meta
    assert out_schema.metadata == schema_meta
    assert len(result) == 2
659+
660+
661+
def test_sanitized_spark_field_names_nested():
    # Struct field names buried inside list, large_list, fixed_size_list and
    # map columns must be sanitized too when writing with flavor='spark'.
    list_type = pa.list_(pa.field('item', pa.struct([
        pa.field('field,name', pa.int32()),
        pa.field('other{field}', pa.string()),
    ])))
    large_list_type = pa.large_list(pa.field('element', pa.struct([
        pa.field('nested(field)', pa.float64()),
    ])))
    # The list size of 2 makes this a fixed-size list.
    fixed_list_type = pa.list_(pa.field('item', pa.struct([
        pa.field('special field', pa.int32()),
    ])), 2)
    # Both the map key and the map value are structs with offending names.
    key_field = pa.field(
        'key', pa.struct([pa.field('key;field', pa.string())]),
        nullable=False)
    item_field = pa.field(
        'value', pa.struct([pa.field('value,field', pa.int32())]))
    map_type = pa.map_(key_field, item_field)

    schema = pa.schema([
        pa.field('list;field', list_type),
        pa.field('large=list', large_list_type),
        pa.field('fixed\tlist', fixed_list_type),
        pa.field('map field', map_type),
    ])

    columns = [
        pa.array([
            [{'field,name': 1, 'other{field}': 'a'}],
            [{'field,name': 2, 'other{field}': 'b'}],
        ], type=schema[0].type),
        pa.array([
            [{'nested(field)': 1.5}],
            [{'nested(field)': 2.5}],
        ], type=schema[1].type),
        pa.array([
            [{'special field': 10}, {'special field': 20}],
            [{'special field': 30}, {'special field': 40}],
        ], type=schema[2].type),
        pa.array([
            [({'key;field': 'k1'}, {'value,field': 100})],
            [({'key;field': 'k2'}, {'value,field': 200})],
        ], type=schema[3].type),
    ]
    table = pa.Table.from_arrays(columns, schema=schema)

    result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'})
    out_schema = result.schema

    # Top-level column names.
    assert out_schema[0].name == 'list_field'
    assert out_schema[1].name == 'large_list'
    assert out_schema[2].name == 'fixed_list'
    assert out_schema[3].name == 'map_field'

    # Struct inside the list column.
    list_struct = out_schema[0].type.value_type
    assert list_struct[0].name == 'field_name'
    assert list_struct[1].name == 'other_field_'

    # Struct inside the large list column.
    large_list_struct = out_schema[1].type.value_type
    assert large_list_struct[0].name == 'nested_field_'

    # Struct inside the fixed-size list column.
    fixed_list_struct = out_schema[2].type.value_type
    assert fixed_list_struct[0].name == 'special_field'

    # Structs on the key and item sides of the map column.
    key_struct = out_schema[3].type.key_type
    item_struct = out_schema[3].type.item_type
    assert key_struct[0].name == 'key_field'
    assert item_struct[0].name == 'value_field'
624737

625738

626739
@pytest.mark.pandas

0 commit comments

Comments
 (0)