Skip to content

Commit 5cce906

Browse files
Fokko and sungwy authored
Use VisitorWithPartner for name-mapping (apache#1014)
* Use `VisitorWithPartner` for name-mapping

  This will correctly handle fields with `.` in the name.

* Fix versions in deprecation

  Co-authored-by: Sung Yun <[email protected]>

* Use full path in error

---------

Co-authored-by: Sung Yun <[email protected]>
1 parent f05b1ae commit 5cce906

File tree

3 files changed

+189
-13
lines changed

3 files changed

+189
-13
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
visit_with_partner,
131131
)
132132
from pyiceberg.table.metadata import TableMetadata
133-
from pyiceberg.table.name_mapping import NameMapping
133+
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
134134
from pyiceberg.transforms import TruncateTransform
135135
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
136136
from pyiceberg.types import (
@@ -818,14 +818,14 @@ def pyarrow_to_schema(
818818
) -> Schema:
819819
has_ids = visit_pyarrow(schema, _HasIds())
820820
if has_ids:
821-
visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
821+
return visit_pyarrow(schema, _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))
822822
elif name_mapping is not None:
823-
visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
823+
schema_without_ids = _pyarrow_to_schema_without_ids(schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
824+
return apply_name_mapping(schema_without_ids, name_mapping)
824825
else:
825826
raise ValueError(
826827
"Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
827828
)
828-
return visit_pyarrow(schema, visitor)
829829

830830

831831
def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
@@ -1002,17 +1002,13 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
10021002
"""Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided."""
10031003

10041004
_field_names: List[str]
1005-
_name_mapping: Optional[NameMapping]
10061005

1007-
def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None:
1006+
def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None:
10081007
self._field_names = []
1009-
self._name_mapping = name_mapping
10101008
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
10111009

10121010
def _field_id(self, field: pa.Field) -> int:
1013-
if self._name_mapping:
1014-
return self._name_mapping.find(*self._field_names).field_id
1015-
elif (field_id := _get_field_id(field)) is not None:
1011+
if (field_id := _get_field_id(field)) is not None:
10161012
return field_id
10171013
else:
10181014
raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")

pyiceberg/table/name_mapping.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030

3131
from pydantic import Field, conlist, field_validator, model_serializer
3232

33-
from pyiceberg.schema import Schema, SchemaVisitor, visit
33+
from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner
3434
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
35-
from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
35+
from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType
36+
from pyiceberg.utils.deprecated import deprecated
3637

3738

3839
class MappedField(IcebergBaseModel):
@@ -74,6 +75,11 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
7475
def _field_by_name(self) -> Dict[str, MappedField]:
7576
return visit_name_mapping(self, _IndexByName())
7677

78+
@deprecated(
79+
deprecated_in="0.8.0",
80+
removed_in="0.9.0",
81+
help_message="Please use `apply_name_mapping` instead",
82+
)
7783
def find(self, *names: str) -> MappedField:
7884
name = ".".join(names)
7985
try:
@@ -248,3 +254,127 @@ def create_mapping_from_schema(schema: Schema) -> NameMapping:
248254

249255
def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
250256
return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds)))
257+
258+
259+
class NameMappingAccessor(PartnerAccessor[MappedField]):
    """Partner accessor that resolves the mapped field matching each schema node.

    Every lookup returns the first child whose ``names`` contain the requested
    name, or ``None`` when the parent is ``None`` or has no such child.
    """

    def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
        """Return the partner unchanged."""
        return partner

    def field_partner(
        self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
    ) -> Optional[MappedField]:
        """Return the child of ``partner_struct`` that lists ``field_name`` among its names."""
        if partner_struct is None:
            return None

        # A MappedField keeps its children under `.fields`; anything else is iterated directly.
        candidates = partner_struct.fields if isinstance(partner_struct, MappedField) else partner_struct
        return next((candidate for candidate in candidates if field_name in candidate.names), None)

    def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
        """Return the child of ``partner_list`` named ``element``."""
        if partner_list is None:
            return None
        return next((child for child in partner_list.fields if "element" in child.names), None)

    def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
        """Return the child of ``partner_map`` named ``key``."""
        if partner_map is None:
            return None
        return next((child for child in partner_map.fields if "key" in child.names), None)

    def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
        """Return the child of ``partner_map`` named ``value``."""
        if partner_map is None:
            return None
        return next((child for child in partner_map.fields if "value" in child.names), None)
296+
297+
298+
class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]):
    """Rebuild a schema, taking every field ID from the partnering name mapping.

    The visitor keeps the path of the node currently being visited so that a
    missing mapping entry can be reported with its full dotted path.
    """

    current_path: List[str]

    def __init__(self) -> None:
        # For keeping track where we are in case when a field cannot be found
        self.current_path = []

    def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
        self.current_path.append(field.name)

    def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
        self.current_path.pop()

    def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
        self.current_path.append("element")

    def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
        self.current_path.pop()

    def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
        self.current_path.append("key")

    def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
        self.current_path.pop()

    def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
        self.current_path.append("value")

    def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
        self.current_path.pop()

    def _child_field_id(self, partner: MappedField, child_name: str) -> int:
        """Return the field ID of the child of ``partner`` that lists ``child_name`` among its names.

        Raises:
            ValueError: If ``partner`` has no such child. A bare ``next()`` would
                leak a ``StopIteration`` here, which says nothing about which
                entry is missing.
        """
        for child in partner.fields:
            if child_name in child.names:
                return child.field_id
        raise ValueError(f"Could not find field with name: {'.'.join([*self.current_path, child_name])}")

    def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
        """Wrap the rebuilt top-level fields in a Schema that keeps the original schema ID."""
        return Schema(*struct_result.fields, schema_id=schema.schema_id)

    def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
        """Assemble a struct from the already rebuilt fields."""
        return StructType(*field_results)

    def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
        """Copy ``field`` with the ID taken from its mapped partner.

        Raises:
            ValueError: If the name mapping has no entry for this field.
        """
        if field_partner is None:
            raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")

        return NestedField(
            field_id=field_partner.field_id,
            name=field.name,
            field_type=field_result,
            required=field.required,
            doc=field.doc,
            initial_default=field.initial_default,
            # The keyword is `write_default`; `initial_write` is not a NestedField
            # parameter, so the write default was not carried over.
            write_default=field.write_default,
        )

    def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
        """Rebuild a list type with the element ID taken from the mapping.

        Raises:
            ValueError: If the list or its ``element`` entry is missing from the mapping.
        """
        if list_partner is None:
            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

        element_id = self._child_field_id(list_partner, "element")
        return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required)

    def map(
        self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
    ) -> IcebergType:
        """Rebuild a map type with the key and value IDs taken from the mapping.

        Raises:
            ValueError: If the map or its ``key``/``value`` entry is missing from the mapping.
        """
        if map_partner is None:
            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

        key_id = self._child_field_id(map_partner, "key")
        value_id = self._child_field_id(map_partner, "value")
        return MapType(
            key_id=key_id,
            key_type=key_result,
            value_id=value_id,
            value_type=value_result,
            value_required=map_type.value_required,
        )

    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
        """Return the primitive unchanged once its mapping entry is confirmed to exist.

        Raises:
            ValueError: If the name mapping has no entry for this primitive.
        """
        if primitive_partner is None:
            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

        return primitive
377+
378+
379+
def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema:
    """Visit ``schema_without_ids`` with ``name_mapping`` as its partner and return the visitor's result."""
    projection_visitor = NameMappingProjectionVisitor()
    mapping_accessor = NameMappingAccessor()
    return visit_with_partner(schema_without_ids, name_mapping, projection_visitor, mapping_accessor)  # type: ignore

tests/table/test_name_mapping.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from pyiceberg.table.name_mapping import (
2121
MappedField,
2222
NameMapping,
23+
apply_name_mapping,
2324
create_mapping_from_schema,
2425
parse_mapping_from_json,
2526
update_mapping,
2627
)
27-
from pyiceberg.types import NestedField, StringType
28+
from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, MapType, NestedField, StringType, StructType
2829

2930

3031
@pytest.fixture(scope="session")
@@ -321,3 +322,52 @@ def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
321322
MappedField(field_id=18, names=["add_18"]),
322323
])
323324
assert update_mapping(table_name_mapping_nested, updates, adds) == expected
325+
326+
327+
def test_mapping_using_by_visitor(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None:
    """Applying the nested name mapping to a schema with placeholder IDs must yield the nested table schema's fields."""
    # Every ID below is the placeholder 0; the assertion compares against the fixture schema.
    qux_type = ListType(element_id=0, element_type=StringType(), element_required=True)

    quux_inner_type = MapType(key_id=0, key_type=StringType(), value_id=0, value_type=IntegerType(), value_required=True)
    quux_type = MapType(
        key_id=0,
        key_type=StringType(),
        value_id=0,
        value_type=quux_inner_type,
        value_required=True,
    )

    coordinates_type = StructType(
        NestedField(field_id=0, name="latitude", field_type=FloatType(), required=False),
        NestedField(field_id=0, name="longitude", field_type=FloatType(), required=False),
    )
    location_type = ListType(element_id=0, element_type=coordinates_type, element_required=True)

    person_type = StructType(
        NestedField(field_id=0, name="name", field_type=StringType(), required=False),
        NestedField(field_id=0, name="age", field_type=IntegerType(), required=True),
    )

    schema_without_ids = Schema(
        NestedField(field_id=0, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=0, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=0, name="baz", field_type=BooleanType(), required=False),
        NestedField(field_id=0, name="qux", field_type=qux_type, required=True),
        NestedField(field_id=0, name="quux", field_type=quux_type, required=True),
        NestedField(field_id=0, name="location", field_type=location_type, required=True),
        NestedField(field_id=0, name="person", field_type=person_type, required=False),
    )

    mapped_schema = apply_name_mapping(schema_without_ids, table_name_mapping_nested)
    assert mapped_schema.fields == table_schema_nested.fields

0 commit comments

Comments
 (0)