Commit a626bc2

Fokko and kevinjqliu authored
Update schema projection to support initial-defaults (#1644)
Add the projection piece of the initial defaults. Closes #1836

Co-authored-by: Kevin Liu <[email protected]>
1 parent fa71498 commit a626bc2

File tree

4 files changed: +65 −8 lines
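As a quick orientation before the diffs, here is a minimal sketch of the behavior this commit enables, assuming an existing format-version 3 table `tbl` (the column name and filter are illustrative): a column added with a default value is projected onto data files written before the column existed, so old rows read the default back instead of null, and filters on the new column are folded to a constant per file.

from pyiceberg.types import BooleanType

# Files written before this schema change do not contain the column on disk.
with tbl.update_schema() as upd:
    upd.add_column("flag", BooleanType(), required=False, default_value=True)

# Old rows now read back flag == True; the filter below matches all of them,
# because the predicate is evaluated against the initial default.
tbl.scan().filter("flag == True").to_arrow()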

pyiceberg/expressions/visitors.py

Lines changed: 18 additions & 5 deletions
@@ -893,15 +893,28 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr
         raise TypeError(f"Expected Bound Predicate, got: {predicate.term}")

     def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
-        file_column_name = self.file_schema.find_column_name(predicate.term.ref().field.field_id)
+        field = predicate.term.ref().field
+        file_column_name = self.file_schema.find_column_name(field.field_id)

         if file_column_name is None:
             # In the case of schema evolution, the column might not be present
-            # in the file schema when reading older data
-            if isinstance(predicate, BoundIsNull):
-                return AlwaysTrue()
+            # we can use the default value as a constant and evaluate it against
+            # the predicate
+            pred: BooleanExpression
+            if isinstance(predicate, BoundUnaryPredicate):
+                pred = predicate.as_unbound(field.name)
+            elif isinstance(predicate, BoundLiteralPredicate):
+                pred = predicate.as_unbound(field.name, predicate.literal)
+            elif isinstance(predicate, BoundSetPredicate):
+                pred = predicate.as_unbound(field.name, predicate.literals)
             else:
-                return AlwaysFalse()
+                raise ValueError(f"Unsupported predicate: {predicate}")
+
+            return (
+                AlwaysTrue()
+                if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default))
+                else AlwaysFalse()
+            )

         if isinstance(predicate, BoundUnaryPredicate):
             return predicate.as_unbound(file_column_name)
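The key step above is the `expression_evaluator(...)(Record(...))` call: the predicate is rebound against a schema containing only the missing field and evaluated once against that field's initial default, and the single answer stands in for every row of the file. A minimal sketch of the same call, with an illustrative field name and a default of 22 (borrowed from the unit test below):

from pyiceberg.expressions import EqualTo
from pyiceberg.expressions.visitors import expression_evaluator
from pyiceberg.schema import Schema
from pyiceberg.typedef import Record
from pyiceberg.types import LongType, NestedField

# The column exists only in the table schema; older files were written without it.
field = NestedField(1, "added_col", LongType(), required=True, initial_default=22)

# Evaluate the predicate once against a single record holding the default;
# a match means the visitor returns AlwaysTrue() for the whole file.
evaluator = expression_evaluator(Schema(field), EqualTo("added_col", 22), case_sensitive=True)
assert evaluator(Record(22))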

pyiceberg/io/pyarrow.py

Lines changed: 7 additions & 3 deletions
@@ -1814,9 +1814,13 @@ def struct(
                 array = self._cast_if_needed(field, field_array)
                 field_arrays.append(array)
                 fields.append(self._construct_field(field, array.type))
-            elif field.optional:
+            elif field.optional or field.initial_default is not None:
+                # When an optional field is added, or when a required field with a non-null initial default is added
                 arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
-                field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
+                if field.initial_default is None:
+                    field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
+                else:
+                    field_arrays.append(pa.repeat(field.initial_default, len(struct_array)))
                 fields.append(self._construct_field(field, arrow_type))
             else:
                 raise ResolveError(f"Field is required, and could not be found in the file: {field}")
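In plain PyArrow, the two branches above differ as follows (a quick illustration, not part of the change): a missing optional field without a default is backfilled with nulls, while a field carrying an initial default is materialized as a constant column.

import pyarrow as pa

num_rows = 3
pa.nulls(num_rows, type=pa.bool_())  # no default: [null, null, null]
pa.repeat(True, num_rows)            # initial default: [true, true, true]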
@@ -2249,7 +2253,7 @@ def parquet_path_to_id_mapping(
     Compute the mapping of parquet column path to Iceberg ID.

     For each column, the parquet file metadata has a path_in_schema attribute that follows
-    a specific naming scheme for nested columnds. This function computes a mapping of
+    a specific naming scheme for nested columns. This function computes a mapping of
     the full paths to the corresponding Iceberg IDs.

     Args:

tests/integration/test_reads.py

Lines changed: 29 additions & 0 deletions
@@ -29,6 +29,7 @@
 from hive_metastore.ttypes import LockRequest, LockResponse, LockState, UnlockRequest
 from pyarrow.fs import S3FileSystem
 from pydantic_core import ValidationError
+from pyspark.sql import SparkSession

 from pyiceberg.catalog import Catalog
 from pyiceberg.catalog.hive import HiveCatalog, _HiveClient
@@ -1024,3 +1025,31 @@ def test_scan_with_datetime(catalog: Catalog) -> None:

     df = table.scan(row_filter=LessThan("datetime", yesterday)).to_pandas()
     assert len(df) == 0
+
+
+@pytest.mark.integration
+# TODO: For Hive we require writing V3
+# @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog")])
+def test_initial_default(catalog: Catalog, spark: SparkSession) -> None:
+    identifier = "default.test_initial_default"
+    try:
+        catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass
+
+    one_column = pa.table([pa.nulls(10, pa.int32())], names=["some_field"])
+
+    tbl = catalog.create_table(identifier, schema=one_column.schema, properties={"format-version": "2"})
+
+    tbl.append(one_column)
+
+    # Do the version bump through Spark, since PyIceberg does not support this (yet)
+    spark.sql(f"ALTER TABLE {identifier} SET TBLPROPERTIES('format-version'='3')")
+
+    with tbl.update_schema() as upd:
+        upd.add_column("so_true", BooleanType(), required=False, default_value=True)
+
+    result_table = tbl.scan().filter("so_true == True").to_arrow()
+
+    assert len(result_table) == 10

tests/io/test_pyarrow.py

Lines changed: 11 additions & 0 deletions
@@ -2398,6 +2398,17 @@ def test_identity_partition_on_multi_columns() -> None:
     ) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")])


+def test_initial_value() -> None:
+    # Have some fake data, otherwise it will generate a table without records
+    data = pa.record_batch([pa.nulls(10, pa.int64())], names=["some_field"])
+    result = _to_requested_schema(
+        Schema(NestedField(1, "we-love-22", LongType(), required=True, initial_default=22)), Schema(), data
+    )
+    assert result.column_names == ["we-love-22"]
+    for val in result[0]:
+        assert val.as_py() == 22
+
+
 def test__to_requested_schema_timestamps(
     arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
     arrow_table_with_all_timestamp_precisions: pa.Table,
