
Commit 8042d82

Convert _get_column_projection_values to use Field-IDs (#2293)
# Rationale for this change

This is a refactor of `_get_column_projection_values` to rely on field IDs rather than names. Field IDs never change, while partition and column names can be updated over a table's lifetime.

# Are these changes tested?

# Are there any user-facing changes?
1 parent: b6a45ed
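To make the rationale concrete, here is a minimal sketch (hypothetical two-field schema, illustrative values only) of why a mapping keyed by field ID survives a column rename while a name-keyed mapping does not:

```python
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType

# A table starts with this schema; field IDs are assigned once and never reused.
original = Schema(
    NestedField(field_id=1, name="other_field", field_type=StringType(), required=False),
    NestedField(field_id=2, name="partition_id", field_type=IntegerType(), required=False),
)

by_name = {"partition_id": 42}  # keyed by the current column name
by_id = {2: 42}                 # keyed by the permanent field ID

# After a rename, the field keeps its ID but not its name.
renamed = Schema(
    NestedField(field_id=1, name="other_field", field_type=StringType(), required=False),
    NestedField(field_id=2, name="part_id", field_type=IntegerType(), required=False),
)

field = renamed.find_field(2)
assert by_id.get(field.field_id) == 42  # field-ID lookup still resolves
assert by_name.get(field.name) is None  # name lookup silently misses
```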

File tree

6 files changed (+87 / -66 lines)

pyiceberg/expressions/visitors.py

Lines changed: 11 additions & 8 deletions

@@ -861,7 +861,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
     Args:
         file_schema (Schema): The schema of the file.
         case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.
-        projected_field_values (Dict[str, Any]): Values for projected fields not present in the data file.
+        projected_field_values (Dict[int, Any]): Values for projected fields not present in the data file.

     Raises:
         TypeError: In the case of an UnboundPredicate.
@@ -870,12 +870,12 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):

     file_schema: Schema
     case_sensitive: bool
-    projected_field_values: Dict[str, Any]
+    projected_field_values: Dict[int, Any]

-    def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT) -> None:
+    def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT) -> None:
         self.file_schema = file_schema
         self.case_sensitive = case_sensitive
-        self.projected_field_values = projected_field_values or {}
+        self.projected_field_values = projected_field_values

     def visit_true(self) -> BooleanExpression:
         return AlwaysTrue()
@@ -897,7 +897,8 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr

     def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
         field = predicate.term.ref().field
-        file_column_name = self.file_schema.find_column_name(field.field_id)
+        field_id = field.field_id
+        file_column_name = self.file_schema.find_column_name(field_id)

         if file_column_name is None:
             # In the case of schema evolution or column projection, the field might not be present in the file schema.
@@ -915,8 +916,10 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
             # In the order described by the "Column Projection" section of the Iceberg spec:
             # https://iceberg.apache.org/spec/#column-projection
             # Evaluate column projection first if it exists
-            if projected_field_value := self.projected_field_values.get(field.name):
-                if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(projected_field_value)):
+            if field_id in self.projected_field_values:
+                if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
+                    Record(self.projected_field_values[field_id])
+                ):
                     return AlwaysTrue()

             # Evaluate initial_default value
@@ -937,7 +940,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi


 def translate_column_names(
-    expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT
+    expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT
 ) -> BooleanExpression:
     return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
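One behavioral detail in the `visit_bound_predicate` hunk: the old walrus pattern `if projected_field_value := self.projected_field_values.get(field.name):` silently skipped values that are present but falsy (`None`, `0`, `""`), whereas the new membership test evaluates them — exactly what the new `projected_field_values={2: None}` test below exercises. A tiny sketch of the difference (illustrative dict):

```python
# A projected value that is present but falsy.
projected_field_values = {2: None}

# Old pattern: a truthiness test, so None/0/"" never reach the evaluator.
if value := projected_field_values.get(2):
    raise AssertionError("unreachable for falsy values")

# New pattern: a presence test, so an explicit None is still evaluated.
assert 2 in projected_field_values
```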

pyiceberg/io/pyarrow.py

Lines changed: 23 additions & 42 deletions

@@ -131,7 +131,6 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
-    Accessor,
     PartnerAccessor,
     PreOrderSchemaVisitor,
     Schema,
@@ -1402,41 +1401,23 @@ def _field_id(self, field: pa.Field) -> int:

 def _get_column_projection_values(
     file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
-) -> Tuple[bool, Dict[str, Any]]:
+) -> Dict[int, Any]:
     """Apply Column Projection rules to File Schema."""
     project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
-    should_project_columns = len(project_schema_diff) > 0
-    projected_missing_fields: Dict[str, Any] = {}
+    if len(project_schema_diff) == 0 or partition_spec is None:
+        return EMPTY_DICT

-    if not should_project_columns:
-        return False, {}
-
-    partition_schema: StructType
-    accessors: Dict[int, Accessor]
-
-    if partition_spec is not None:
-        partition_schema = partition_spec.partition_type(projected_schema)
-        accessors = build_position_accessors(partition_schema)
-    else:
-        return False, {}
+    partition_schema = partition_spec.partition_type(projected_schema)
+    accessors = build_position_accessors(partition_schema)

+    projected_missing_fields = {}
     for field_id in project_schema_diff:
         for partition_field in partition_spec.fields_by_source_id(field_id):
             if isinstance(partition_field.transform, IdentityTransform):
-                accessor = accessors.get(partition_field.field_id)
-
-                if accessor is None:
-                    continue
+                if partition_value := accessors[partition_field.field_id].get(file.partition):
+                    projected_missing_fields[field_id] = partition_value

-                # The partition field may not exist in the partition record of the data file.
-                # This can happen when new partition fields are introduced after the file was written.
-                try:
-                    if partition_value := accessor.get(file.partition):
-                        projected_missing_fields[partition_field.name] = partition_value
-                except IndexError:
-                    continue
-
-    return True, projected_missing_fields
+    return projected_missing_fields


 def _task_to_record_batches(
@@ -1460,9 +1441,8 @@ def _task_to_record_batches(
     # the table format version.
     file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

-    # Apply column projection rules
-    # https://iceberg.apache.org/spec/#column-projection
-    should_project_columns, projected_missing_fields = _get_column_projection_values(
+    # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
+    projected_missing_fields = _get_column_projection_values(
         task.file, projected_schema, partition_spec, file_schema.field_ids
     )

@@ -1517,16 +1497,9 @@ def _task_to_record_batches(
                 file_project_schema,
                 current_batch,
                 downcast_ns_timestamp_to_us=True,
+                projected_missing_fields=projected_missing_fields,
             )

-            # Inject projected column values if available
-            if should_project_columns:
-                for name, value in projected_missing_fields.items():
-                    index = result_batch.schema.get_field_index(name)
-                    if index != -1:
-                        arr = pa.repeat(value, result_batch.num_rows)
-                        result_batch = result_batch.set_column(index, name, arr)
-
             yield result_batch


@@ -1696,7 +1669,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
                 self._table_metadata.name_mapping(),
-                self._table_metadata.spec(),
+                self._table_metadata.specs().get(task.file.spec_id),
             )
             for batch in batches:
                 if self._limit is not None:
@@ -1714,12 +1687,15 @@ def _to_requested_schema(
     batch: pa.RecordBatch,
     downcast_ns_timestamp_to_us: bool = False,
     include_field_ids: bool = False,
+    projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
 ) -> pa.RecordBatch:
     # We could reuse some of these visitors
     struct_array = visit_with_partner(
         requested_schema,
         batch,
-        ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
+        ArrowProjectionVisitor(
+            file_schema, downcast_ns_timestamp_to_us, include_field_ids, projected_missing_fields=projected_missing_fields
+        ),
         ArrowAccessor(file_schema),
     )
     return pa.RecordBatch.from_struct_array(struct_array)
@@ -1730,18 +1706,21 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
     _include_field_ids: bool
     _downcast_ns_timestamp_to_us: bool
     _use_large_types: Optional[bool]
+    _projected_missing_fields: Dict[int, Any]

     def __init__(
         self,
         file_schema: Schema,
         downcast_ns_timestamp_to_us: bool = False,
         include_field_ids: bool = False,
         use_large_types: Optional[bool] = None,
+        projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
     ) -> None:
         self._file_schema = file_schema
         self._include_field_ids = include_field_ids
         self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
         self._use_large_types = use_large_types
+        self._projected_missing_fields = projected_missing_fields

         if use_large_types is not None:
             deprecation_message(
@@ -1821,7 +1800,9 @@ def struct(
             elif field.optional or field.initial_default is not None:
                 # When an optional field is added, or when a required field with a non-null initial default is added
                 arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
-                if field.initial_default is None:
+                if projected_value := self._projected_missing_fields.get(field.field_id):
+                    field_arrays.append(pa.repeat(pa.scalar(projected_value, type=arrow_type), len(struct_array)))
+                elif field.initial_default is None:
                     field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
                 else:
                     field_arrays.append(pa.repeat(field.initial_default, len(struct_array)))
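Two consequences of this file's changes are worth spelling out. First, the partition spec is now resolved per data file (`self._table_metadata.specs().get(task.file.spec_id)`) instead of assuming the table's current spec, which is why the test fixtures below start pinning `spec_id` on their `DataFile`s. Second, the projected constant column is built inside `ArrowProjectionVisitor.struct` with the field's declared Iceberg type, rather than patched into the finished batch afterwards. A standalone pyarrow sketch of that `pa.repeat(pa.scalar(...))` construction (illustrative values only):

```python
import pyarrow as pa

# Illustrative stand-ins: IntegerType's declared Arrow type and a partition
# value recovered from file metadata for an identity-transform field.
arrow_type = pa.int32()
projected_value = 42
num_rows = 3

# Same construction as ArrowProjectionVisitor.struct above: a constant column
# typed by the schema's declared type, not by Python-value type inference.
column = pa.repeat(pa.scalar(projected_value, type=arrow_type), num_rows)
assert column.equals(pa.array([42, 42, 42], type=pa.int32()))
```

Typing the column from the declared schema (`IntegerType` → `int32`) rather than letting `pa.repeat` infer `int64` from a Python int appears to be what the `int64` → `int32` assertion changes in tests/io/test_pyarrow.py below reflect.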

tests/conftest.py

Lines changed: 3 additions & 1 deletion

@@ -2375,8 +2375,10 @@ def data_file(table_schema_simple: Schema, tmp_path: str) -> str:

 @pytest.fixture
 def example_task(data_file: str) -> FileScanTask:
+    datafile = DataFile.from_args(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925)
+    datafile.spec_id = 0
     return FileScanTask(
-        data_file=DataFile.from_args(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925),
+        data_file=datafile,
     )


tests/expressions/test_visitors.py

Lines changed: 31 additions & 4 deletions

@@ -1730,6 +1730,33 @@ def test_translate_column_names_missing_column_match_null() -> None:
     assert translated_expr == AlwaysTrue()


+def test_translate_column_names_missing_column_match_explicit_null() -> None:
+    """Test translate_column_names when missing column matches null."""
+    # Original schema
+    original_schema = Schema(
+        NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
+        schema_id=1,
+    )
+
+    # Create bound expression for the missing column
+    unbound_expr = IsNull("missing_col")
+    bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
+
+    # File schema only has the existing column (field_id=1), missing field_id=2
+    file_schema = Schema(
+        NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+
+    # Translate column names
+    translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
+
+    # Should evaluate to AlwaysTrue because the missing column is treated as null
+    # missing_col's default initial_default (None) satisfies the IsNull predicate
+    assert translated_expr == AlwaysTrue()
+
+
 def test_translate_column_names_missing_column_with_initial_default() -> None:
     """Test translate_column_names when missing column's initial_default matches expression."""
     # Original schema
@@ -1801,7 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() ->
     )

     # Projected column that is missing in the file schema
-    projected_field_values = {"missing_col": 42}
+    projected_field_values = {2: 42}

     # Translate column names
     translated_expr = translate_column_names(
@@ -1833,7 +1860,7 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -
     )

     # Projected column that is missing in the file schema
-    projected_field_values = {"missing_col": 1}
+    projected_field_values = {2: 1}

     # Translate column names
     translated_expr = translate_column_names(
@@ -1864,7 +1891,7 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init
     )

     # Projected field value that differs from both the expression literal and initial_default
-    projected_field_values = {"missing_col": 10}  # This doesn't match expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match expression literal (42)

     # Translate column names
     translated_expr = translate_column_names(
@@ -1895,7 +1922,7 @@ def test_translate_column_names_missing_column_projected_field_matches_initial_d
     )

     # Projected field value that matches the expression literal
-    projected_field_values = {"missing_col": 10}  # This doesn't match expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match expression literal (42)

     # Translate column names
     translated_expr = translate_column_names(

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 3 additions & 1 deletion

@@ -711,8 +711,10 @@ def test_dynamic_partition_overwrite_evolve_partition(spark: SparkSession, sessi
     )

     identifier = f"default.partitioned_{format_version}_test_dynamic_partition_overwrite_evolve_partition"
-    with pytest.raises(NoSuchTableError):
+    try:
         session_catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass

     tbl = session_catalog.create_table(
         identifier=identifier,
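A small but meaningful setup fix: the old `with pytest.raises(NoSuchTableError):` asserted that the drop must fail, so the test errored whenever the table was left over from a previous run. The new form is the usual idempotent-cleanup pattern, extracted here as a hypothetical helper:

```python
from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError


def drop_table_if_exists(catalog: Catalog, identifier: str) -> None:
    """Drop the table when present; succeed quietly when it is not."""
    try:
        catalog.drop_table(identifier)
    except NoSuchTableError:
        pass
```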

tests/io/test_pyarrow.py

Lines changed: 16 additions & 10 deletions

@@ -970,6 +970,10 @@ def file_map(schema_map: Schema, tmpdir: str) -> str:
 def project(
     schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None
 ) -> pa.Table:
+    def _set_spec_id(datafile: DataFile) -> DataFile:
+        datafile.spec_id = 0
+        return datafile
+
     return ArrowScan(
         table_metadata=TableMetadataV2(
             location="file://a/b/",
@@ -985,13 +989,15 @@ def project(
     ).to_table(
         tasks=[
             FileScanTask(
-                DataFile.from_args(
-                    content=DataFileContent.DATA,
-                    file_path=file,
-                    file_format=FileFormat.PARQUET,
-                    partition={},
-                    record_count=3,
-                    file_size_in_bytes=3,
+                _set_spec_id(
+                    DataFile.from_args(
+                        content=DataFileContent.DATA,
+                        file_path=file,
+                        file_format=FileFormat.PARQUET,
+                        partition={},
+                        record_count=3,
+                        file_size_in_bytes=3,
+                    )
                 )
             )
             for file in files
@@ -1189,7 +1195,7 @@ def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCa
     with transaction.update_snapshot().overwrite() as update:
         update.append_data_file(unpartitioned_file)

-    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int64())])
+    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int32())])
     assert table.scan().to_arrow() == pa.table(
         {
             "other_field": ["foo", "bar", "baz"],
@@ -1264,8 +1270,8 @@ def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryC
         str(table.scan().to_arrow())
         == """pyarrow.Table
 field_1: string
-field_2: int64
-field_3: int64
+field_2: int32
+field_3: int32
 ----
 field_1: [["foo"]]
 field_2: [[2]]