Skip to content

Commit cd7d8c7

Browse files
ErigaraRoman ShaninkevinjqliuCopilotFokko
authored
Fix projected fields predicate evaluation (#2029)
<!-- Thanks for opening a pull request! --> <!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual Github issue id. --> Closes #2028 # Rationale for this change Provide expected result aligned with `spark` implementation. This PR fixes a bug where predicate evaluation for a column that is missing from the parquet file schema will return no result. This is due to `_ColumnNameTranslator` visitor returning `AlwaysFalse` when the column cannot be found in the file schema. The solution is to pass in the projected field value for evaluation. This follows the order of operation described in https://iceberg.apache.org/spec/#column-projection # Are these changes tested? I've checked it on script attached to issue + new test was added. Yes, added some unit tests for `_ColumnNameTranslator`/`translate_column_names` Added a test for predicate evaluation for projected columns. # Are there any user-facing changes? Kinda yes, because results of some scans now different. <!-- In the case of user-facing changes, please add the changelog label. --> --------- Co-authored-by: Roman Shanin <[email protected]> Co-authored-by: Kevin Liu <[email protected]> Co-authored-by: Kevin Liu <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 5c604c2 commit cd7d8c7

File tree

4 files changed

+317
-12
lines changed

4 files changed

+317
-12
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
861861
Args:
862862
file_schema (Schema): The schema of the file.
863863
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.
864+
projected_field_values (Dict[str, Any]): Values for projected fields not present in the data file.
864865
865866
Raises:
866867
TypeError: In the case of an UnboundPredicate.
@@ -869,10 +870,12 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
869870

870871
file_schema: Schema
871872
case_sensitive: bool
873+
projected_field_values: Dict[str, Any]
872874

873-
def __init__(self, file_schema: Schema, case_sensitive: bool) -> None:
875+
def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT) -> None:
874876
self.file_schema = file_schema
875877
self.case_sensitive = case_sensitive
878+
self.projected_field_values = projected_field_values or {}
876879

877880
def visit_true(self) -> BooleanExpression:
878881
return AlwaysTrue()
@@ -897,9 +900,8 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
897900
file_column_name = self.file_schema.find_column_name(field.field_id)
898901

899902
if file_column_name is None:
900-
# In the case of schema evolution, the column might not be present
901-
# we can use the default value as a constant and evaluate it against
902-
# the predicate
903+
# In the case of schema evolution or column projection, the field might not be present in the file schema.
904+
# we can use the projected value or the field's default value as a constant and evaluate it against the predicate
903905
pred: BooleanExpression
904906
if isinstance(predicate, BoundUnaryPredicate):
905907
pred = predicate.as_unbound(field.name)
@@ -910,6 +912,14 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
910912
else:
911913
raise ValueError(f"Unsupported predicate: {predicate}")
912914

915+
# In the order described by the "Column Projection" section of the Iceberg spec:
916+
# https://iceberg.apache.org/spec/#column-projection
917+
# Evaluate column projection first if it exists
918+
if projected_field_value := self.projected_field_values.get(field.name):
919+
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(projected_field_value)):
920+
return AlwaysTrue()
921+
922+
# Evaluate initial_default value
913923
return (
914924
AlwaysTrue()
915925
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default))
@@ -926,8 +936,10 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
926936
raise ValueError(f"Unsupported predicate: {predicate}")
927937

928938

929-
def translate_column_names(expr: BooleanExpression, file_schema: Schema, case_sensitive: bool) -> BooleanExpression:
930-
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive))
939+
def translate_column_names(
940+
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT
941+
) -> BooleanExpression:
942+
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
931943

932944

933945
class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]):

pyiceberg/io/pyarrow.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,18 +1460,20 @@ def _task_to_record_batches(
14601460
# the table format version.
14611461
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
14621462

1463-
pyarrow_filter = None
1464-
if bound_row_filter is not AlwaysTrue():
1465-
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
1466-
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
1467-
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
1468-
14691463
# Apply column projection rules
14701464
# https://iceberg.apache.org/spec/#column-projection
14711465
should_project_columns, projected_missing_fields = _get_column_projection_values(
14721466
task.file, projected_schema, partition_spec, file_schema.field_ids
14731467
)
14741468

1469+
pyarrow_filter = None
1470+
if bound_row_filter is not AlwaysTrue():
1471+
translated_row_filter = translate_column_names(
1472+
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields
1473+
)
1474+
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
1475+
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
1476+
14751477
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
14761478

14771479
fragment_scanner = ds.Scanner.from_fragment(

tests/expressions/test_visitors.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@
7272
expression_to_plain_format,
7373
rewrite_not,
7474
rewrite_to_dnf,
75+
translate_column_names,
7576
visit,
7677
visit_bound_predicate,
7778
)
7879
from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
7980
from pyiceberg.schema import Accessor, Schema
8081
from pyiceberg.typedef import Record
8182
from pyiceberg.types import (
83+
BooleanType,
8284
DoubleType,
8385
FloatType,
8486
IcebergType,
@@ -1623,3 +1625,282 @@ def test_expression_evaluator_null() -> None:
16231625
assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False
16241626
assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False
16251627
assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True
1628+
1629+
1630+
def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None:
1631+
"""Test translate_column_names with matching column names."""
1632+
# Create a bound expression using the original schema
1633+
unbound_expr = EqualTo("foo", "test_value")
1634+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
1635+
1636+
# File schema has the same column names
1637+
file_schema = Schema(
1638+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
1639+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
1640+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
1641+
schema_id=1,
1642+
)
1643+
1644+
# Translate column names
1645+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1646+
1647+
# Should return an unbound expression with the same column name since they match
1648+
assert isinstance(translated_expr, EqualTo)
1649+
assert translated_expr.term == Reference("foo")
1650+
assert translated_expr.literal == literal("test_value")
1651+
1652+
1653+
def test_translate_column_names_different_column_names() -> None:
1654+
"""Test translate_column_names with different column names in file schema."""
1655+
# Original schema
1656+
original_schema = Schema(
1657+
NestedField(field_id=1, name="original_name", field_type=StringType(), required=False),
1658+
schema_id=1,
1659+
)
1660+
1661+
# Create bound expression
1662+
unbound_expr = EqualTo("original_name", "test_value")
1663+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1664+
1665+
# File schema has different column name but same field ID
1666+
file_schema = Schema(
1667+
NestedField(field_id=1, name="file_column_name", field_type=StringType(), required=False),
1668+
schema_id=1,
1669+
)
1670+
1671+
# Translate column names
1672+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1673+
1674+
# Should use the file schema's column name
1675+
assert isinstance(translated_expr, EqualTo)
1676+
assert translated_expr.term == Reference("file_column_name")
1677+
assert translated_expr.literal == literal("test_value")
1678+
1679+
1680+
def test_translate_column_names_missing_column() -> None:
1681+
"""Test translate_column_names when column is missing from file schema (such as in schema evolution)."""
1682+
# Original schema
1683+
original_schema = Schema(
1684+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1685+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1686+
schema_id=1,
1687+
)
1688+
1689+
# Create bound expression for the missing column
1690+
unbound_expr = EqualTo("missing_col", 42)
1691+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1692+
1693+
# File schema only has the existing column (field_id=1), missing field_id=2
1694+
file_schema = Schema(
1695+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1696+
schema_id=1,
1697+
)
1698+
1699+
# Translate column names
1700+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1701+
1702+
# missing_col's default initial_default (None) does not match the expression literal (42)
1703+
assert translated_expr == AlwaysFalse()
1704+
1705+
1706+
def test_translate_column_names_missing_column_match_null() -> None:
1707+
"""Test translate_column_names when missing column matches null."""
1708+
# Original schema
1709+
original_schema = Schema(
1710+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1711+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1712+
schema_id=1,
1713+
)
1714+
1715+
# Create bound expression for the missing column
1716+
unbound_expr = IsNull("missing_col")
1717+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1718+
1719+
# File schema only has the existing column (field_id=1), missing field_id=2
1720+
file_schema = Schema(
1721+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1722+
schema_id=1,
1723+
)
1724+
1725+
# Translate column names
1726+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1727+
1728+
# Should evaluate to AlwaysTrue because the missing column is treated as null
1729+
# missing_col's default initial_default (None) satisfies the IsNull predicate
1730+
assert translated_expr == AlwaysTrue()
1731+
1732+
1733+
def test_translate_column_names_missing_column_with_initial_default() -> None:
1734+
"""Test translate_column_names when missing column's initial_default matches expression."""
1735+
# Original schema
1736+
original_schema = Schema(
1737+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1738+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42),
1739+
schema_id=1,
1740+
)
1741+
1742+
# Create bound expression for the missing column
1743+
unbound_expr = EqualTo("missing_col", 42)
1744+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1745+
1746+
# File schema only has the existing column (field_id=1), missing field_id=2
1747+
file_schema = Schema(
1748+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1749+
schema_id=1,
1750+
)
1751+
1752+
# Translate column names
1753+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1754+
1755+
# Should evaluate to AlwaysTrue because the initial_default value (42) matches the literal (42)
1756+
assert translated_expr == AlwaysTrue()
1757+
1758+
1759+
def test_translate_column_names_missing_column_with_initial_default_mismatch() -> None:
1760+
"""Test translate_column_names when missing column's initial_default doesn't match expression."""
1761+
# Original schema
1762+
original_schema = Schema(
1763+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10),
1764+
schema_id=1,
1765+
)
1766+
1767+
# Create bound expression that won't match the default value
1768+
unbound_expr = EqualTo("missing_col", 42)
1769+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1770+
1771+
# File schema doesn't have this column
1772+
file_schema = Schema(
1773+
NestedField(field_id=1, name="other_col", field_type=StringType(), required=False),
1774+
schema_id=1,
1775+
)
1776+
1777+
# Translate column names
1778+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1779+
1780+
# Should evaluate to AlwaysFalse because initial_default value (10) doesn't match literal (42)
1781+
assert translated_expr == AlwaysFalse()
1782+
1783+
1784+
def test_translate_column_names_missing_column_with_projected_field_matches() -> None:
1785+
"""Test translate_column_names with projected field value that matches expression."""
1786+
# Original schema with a field that has no initial_default (defaults to None)
1787+
original_schema = Schema(
1788+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1789+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1790+
schema_id=1,
1791+
)
1792+
1793+
# Create bound expression for the missing column
1794+
unbound_expr = EqualTo("missing_col", 42)
1795+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1796+
1797+
# File schema only has the existing column (field_id=1), missing field_id=2
1798+
file_schema = Schema(
1799+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1800+
schema_id=1,
1801+
)
1802+
1803+
# Projected column that is missing in the file schema
1804+
projected_field_values = {"missing_col": 42}
1805+
1806+
# Translate column names
1807+
translated_expr = translate_column_names(
1808+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1809+
)
1810+
1811+
# Should evaluate to AlwaysTrue since projected field value matches the expression literal
1812+
# even though the field is missing in the file schema
1813+
assert translated_expr == AlwaysTrue()
1814+
1815+
1816+
def test_translate_column_names_missing_column_with_projected_field_mismatch() -> None:
1817+
"""Test translate_column_names with projected field value that doesn't match expression."""
1818+
# Original schema with a field that has no initial_default (defaults to None)
1819+
original_schema = Schema(
1820+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1821+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1822+
schema_id=1,
1823+
)
1824+
1825+
# Create bound expression for the missing column
1826+
unbound_expr = EqualTo("missing_col", 42)
1827+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1828+
1829+
# File schema only has the existing column (field_id=1), missing field_id=2
1830+
file_schema = Schema(
1831+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1832+
schema_id=1,
1833+
)
1834+
1835+
# Projected column that is missing in the file schema
1836+
projected_field_values = {"missing_col": 1}
1837+
1838+
# Translate column names
1839+
translated_expr = translate_column_names(
1840+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1841+
)
1842+
1843+
# Should evaluate to AlwaysFalse since projected field value does not match the expression literal
1844+
assert translated_expr == AlwaysFalse()
1845+
1846+
1847+
def test_translate_column_names_missing_column_projected_field_fallbacks_to_initial_default() -> None:
1848+
"""Test translate_column_names when projected field value doesn't match but initial_default does."""
1849+
# Original schema with a field that has an initial_default
1850+
original_schema = Schema(
1851+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1852+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42),
1853+
schema_id=1,
1854+
)
1855+
1856+
# Create bound expression for the missing column that would match initial_default
1857+
unbound_expr = EqualTo("missing_col", 42)
1858+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1859+
1860+
# File schema only has the existing column (field_id=1), missing field_id=2
1861+
file_schema = Schema(
1862+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1863+
schema_id=1,
1864+
)
1865+
1866+
# Projected field value that differs from both the expression literal and initial_default
1867+
projected_field_values = {"missing_col": 10} # This doesn't match expression literal (42)
1868+
1869+
# Translate column names
1870+
translated_expr = translate_column_names(
1871+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1872+
)
1873+
1874+
# Should evaluate to AlwaysTrue since projected field value doesn't match but initial_default does
1875+
assert translated_expr == AlwaysTrue()
1876+
1877+
1878+
def test_translate_column_names_missing_column_projected_field_matches_initial_default_mismatch() -> None:
1879+
"""Test translate_column_names when both projected field value and initial_default doesn't match."""
1880+
# Original schema with a field that has an initial_default that doesn't match the expression
1881+
original_schema = Schema(
1882+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1883+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10),
1884+
schema_id=1,
1885+
)
1886+
1887+
# Create bound expression for the missing column
1888+
unbound_expr = EqualTo("missing_col", 42)
1889+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1890+
1891+
# File schema only has the existing column (field_id=1), missing field_id=2
1892+
file_schema = Schema(
1893+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1894+
schema_id=1,
1895+
)
1896+
1897+
# Projected field value that matches the expression literal
1898+
projected_field_values = {"missing_col": 10} # This doesn't match expression literal (42)
1899+
1900+
# Translate column names
1901+
translated_expr = translate_column_names(
1902+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1903+
)
1904+
1905+
# Should evaluate to AlwaysFalse since both projected field value and initial_default does not match
1906+
assert translated_expr == AlwaysFalse()

0 commit comments

Comments
 (0)