Skip to content

Commit 9b6503d

Browse files
authored
fix table.delete()/overwrite() with null values (#955)
* fix * naming * handle nan as well * naming as sung suggested * one more test to fix; one more comment to address * fix a test * refactor code organization * fix mouthful naming * restore usage of BoundTerm * small fixes for comments * small fix for typing * fix typing according to pr comment
1 parent 861c563 commit 9b6503d

File tree

5 files changed

+431
-12
lines changed

5 files changed

+431
-12
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def __repr__(self) -> str:
135135
def ref(self) -> BoundReference[L]:
136136
return self
137137

138+
def __hash__(self) -> int:
139+
"""Return hash value of the BoundReference class."""
140+
return hash(str(self))
141+
138142

139143
class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC):
140144
"""Represents an unbound term."""

pyiceberg/io/pyarrow.py

Lines changed: 145 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,7 @@
7373

7474
from pyiceberg.conversions import to_bytes
7575
from pyiceberg.exceptions import ResolveError
76-
from pyiceberg.expressions import (
77-
AlwaysTrue,
78-
BooleanExpression,
79-
BoundTerm,
80-
)
76+
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
8177
from pyiceberg.expressions.literals import Literal
8278
from pyiceberg.expressions.visitors import (
8379
BoundBooleanExpressionVisitor,
@@ -576,11 +572,11 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
576572

577573

578574
class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
579-
def visit_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
575+
def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
580576
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
581577
return pc.field(term.ref().field.name).isin(pyarrow_literals)
582578

583-
def visit_not_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
579+
def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
584580
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
585581
return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
586582

@@ -638,10 +634,152 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
638634
return left_result | right_result
639635

640636

637+
class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]):
638+
# BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
639+
is_null_or_not_bound_terms: set[BoundTerm[Any]]
640+
# The remaining BoundTerms appearing in the boolean expr.
641+
null_unmentioned_bound_terms: set[BoundTerm[Any]]
642+
# BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr.
643+
is_nan_or_not_bound_terms: set[BoundTerm[Any]]
644+
# The remaining BoundTerms appearing in the boolean expr.
645+
nan_unmentioned_bound_terms: set[BoundTerm[Any]]
646+
647+
def __init__(self) -> None:
648+
super().__init__()
649+
self.is_null_or_not_bound_terms = set()
650+
self.null_unmentioned_bound_terms = set()
651+
self.is_nan_or_not_bound_terms = set()
652+
self.nan_unmentioned_bound_terms = set()
653+
654+
def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
655+
"""Handle the predicate case where either is_null or is_not_null is included."""
656+
if term in self.null_unmentioned_bound_terms:
657+
self.null_unmentioned_bound_terms.remove(term)
658+
self.is_null_or_not_bound_terms.add(term)
659+
660+
def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None:
661+
"""Handle the predicate case where neither is_null nor is_not_null is included."""
662+
if term not in self.is_null_or_not_bound_terms:
663+
self.null_unmentioned_bound_terms.add(term)
664+
665+
def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None:
666+
"""Handle the predicate case where either is_nan or is_not_nan is included."""
667+
if term in self.nan_unmentioned_bound_terms:
668+
self.nan_unmentioned_bound_terms.remove(term)
669+
self.is_nan_or_not_bound_terms.add(term)
670+
671+
def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None:
672+
"""Handle the predicate case where neither is_nan nor is_not_nan is included."""
673+
if term not in self.is_nan_or_not_bound_terms:
674+
self.nan_unmentioned_bound_terms.add(term)
675+
676+
def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
677+
self._handle_null_unmentioned(term)
678+
self._handle_nan_unmentioned(term)
679+
680+
def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
681+
self._handle_null_unmentioned(term)
682+
self._handle_nan_unmentioned(term)
683+
684+
def visit_is_nan(self, term: BoundTerm[Any]) -> None:
685+
self._handle_null_unmentioned(term)
686+
self._handle_explicit_is_nan_or_not(term)
687+
688+
def visit_not_nan(self, term: BoundTerm[Any]) -> None:
689+
self._handle_null_unmentioned(term)
690+
self._handle_explicit_is_nan_or_not(term)
691+
692+
def visit_is_null(self, term: BoundTerm[Any]) -> None:
693+
self._handle_explicit_is_null_or_not(term)
694+
self._handle_nan_unmentioned(term)
695+
696+
def visit_not_null(self, term: BoundTerm[Any]) -> None:
697+
self._handle_explicit_is_null_or_not(term)
698+
self._handle_nan_unmentioned(term)
699+
700+
def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
701+
self._handle_null_unmentioned(term)
702+
self._handle_nan_unmentioned(term)
703+
704+
def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
705+
self._handle_null_unmentioned(term)
706+
self._handle_nan_unmentioned(term)
707+
708+
def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
709+
self._handle_null_unmentioned(term)
710+
self._handle_nan_unmentioned(term)
711+
712+
def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
713+
self._handle_null_unmentioned(term)
714+
self._handle_nan_unmentioned(term)
715+
716+
def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
717+
self._handle_null_unmentioned(term)
718+
self._handle_nan_unmentioned(term)
719+
720+
def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
721+
self._handle_null_unmentioned(term)
722+
self._handle_nan_unmentioned(term)
723+
724+
def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
725+
self._handle_null_unmentioned(term)
726+
self._handle_nan_unmentioned(term)
727+
728+
def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
729+
self._handle_null_unmentioned(term)
730+
self._handle_nan_unmentioned(term)
731+
732+
def visit_true(self) -> None:
733+
return
734+
735+
def visit_false(self) -> None:
736+
return
737+
738+
def visit_not(self, child_result: None) -> None:
739+
return
740+
741+
def visit_and(self, left_result: None, right_result: None) -> None:
742+
return
743+
744+
def visit_or(self, left_result: None, right_result: None) -> None:
745+
return
746+
747+
def collect(
748+
self,
749+
expr: BooleanExpression,
750+
) -> None:
751+
"""Collect the bound terms, categorizing each by whether at least one is_null/is_not_null (and, separately, is_nan/is_not_nan) appears for it in the expr, versus the remaining unmentioned terms."""
752+
boolean_expression_visit(expr, self)
753+
754+
641755
def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
642756
return boolean_expression_visit(expr, _ConvertToArrowExpression())
643757

644758

759+
def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
760+
"""Complementary filter conversion function of expression_to_pyarrow.
761+
762+
Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null.
763+
"""
764+
collector = _NullNaNUnmentionedTermsCollector()
765+
collector.collect(expr)
766+
767+
# Convert the set of terms to a sorted list so that layout of the expression to build is deterministic.
768+
null_unmentioned_bound_terms: List[BoundTerm[Any]] = sorted(
769+
collector.null_unmentioned_bound_terms, key=lambda term: term.ref().field.name
770+
)
771+
nan_unmentioned_bound_terms: List[BoundTerm[Any]] = sorted(
772+
collector.nan_unmentioned_bound_terms, key=lambda term: term.ref().field.name
773+
)
774+
775+
preserve_expr: BooleanExpression = Not(expr)
776+
for term in null_unmentioned_bound_terms:
777+
preserve_expr = Or(preserve_expr, BoundIsNull(term=term))
778+
for term in nan_unmentioned_bound_terms:
779+
preserve_expr = Or(preserve_expr, BoundIsNaN(term=term))
780+
return expression_to_pyarrow(preserve_expr)
781+
782+
645783
@lru_cache
646784
def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
647785
if file_format == FileFormat.PARQUET:

pyiceberg/table/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
And,
5959
BooleanExpression,
6060
EqualTo,
61-
Not,
6261
Or,
6362
Reference,
6463
)
@@ -576,7 +575,11 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
576575
delete_filter: A boolean expression to delete rows from a table
577576
snapshot_properties: Custom properties to be added to the snapshot summary
578577
"""
579-
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
578+
from pyiceberg.io.pyarrow import (
579+
_dataframe_to_data_files,
580+
_expression_to_complementary_pyarrow,
581+
project_table,
582+
)
580583

581584
if (
582585
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
@@ -593,7 +596,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
593596
# Check if there are any files that require an actual rewrite of a data file
594597
if delete_snapshot.rewrites_needed is True:
595598
bound_delete_filter = bind(self._table.schema(), delete_filter, case_sensitive=True)
596-
preserve_row_filter = expression_to_pyarrow(Not(bound_delete_filter))
599+
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
597600

598601
files = self._scan(row_filter=delete_filter).plan_files()
599602

tests/integration/test_deletes.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pyiceberg.manifest import ManifestEntryStatus
2828
from pyiceberg.schema import Schema
2929
from pyiceberg.table.snapshots import Operation, Summary
30-
from pyiceberg.types import IntegerType, NestedField
30+
from pyiceberg.types import FloatType, IntegerType, NestedField
3131

3232

3333
def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
@@ -105,6 +105,40 @@ def test_partitioned_table_rewrite(spark: SparkSession, session_catalog: RestCat
105105
assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [30, 30]}
106106

107107

108+
@pytest.mark.parametrize("format_version", [1, 2])
109+
def test_rewrite_partitioned_table_with_null(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
110+
identifier = "default.table_partitioned_delete"
111+
112+
run_spark_commands(
113+
spark,
114+
[
115+
f"DROP TABLE IF EXISTS {identifier}",
116+
f"""
117+
CREATE TABLE {identifier} (
118+
number_partitioned int,
119+
number int
120+
)
121+
USING iceberg
122+
PARTITIONED BY (number_partitioned)
123+
TBLPROPERTIES('format-version' = {format_version})
124+
""",
125+
f"""
126+
INSERT INTO {identifier} VALUES (10, 20), (10, 30)
127+
""",
128+
f"""
129+
INSERT INTO {identifier} VALUES (11, 20), (11, NULL)
130+
""",
131+
],
132+
)
133+
134+
tbl = session_catalog.load_table(identifier)
135+
tbl.delete(EqualTo("number", 20))
136+
137+
# We don't delete a whole partition, so there is only an overwrite
138+
assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "overwrite"]
139+
assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [None, 30]}
140+
141+
108142
@pytest.mark.integration
109143
@pytest.mark.parametrize("format_version", [1, 2])
110144
def test_partitioned_table_no_match(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
@@ -417,3 +451,105 @@ def test_delete_truncate(session_catalog: RestCatalog) -> None:
417451
assert len(entries) == 1
418452

419453
assert entries[0].status == ManifestEntryStatus.DELETED
454+
455+
456+
def test_delete_overwrite_table_with_null(session_catalog: RestCatalog) -> None:
457+
arrow_schema = pa.schema([pa.field("ints", pa.int32())])
458+
arrow_tbl = pa.Table.from_pylist(
459+
[{"ints": 1}, {"ints": 2}, {"ints": None}],
460+
schema=arrow_schema,
461+
)
462+
463+
iceberg_schema = Schema(NestedField(1, "ints", IntegerType()))
464+
465+
tbl_identifier = "default.test_delete_overwrite_with_null"
466+
467+
try:
468+
session_catalog.drop_table(tbl_identifier)
469+
except NoSuchTableError:
470+
pass
471+
472+
tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
473+
tbl.append(arrow_tbl)
474+
475+
assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]
476+
477+
arrow_tbl_overwrite = pa.Table.from_pylist(
478+
[
479+
{"ints": 3},
480+
{"ints": 4},
481+
],
482+
schema=arrow_schema,
483+
)
484+
tbl.overwrite(arrow_tbl_overwrite, "ints == 2") # Should rewrite one file
485+
486+
assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
487+
Operation.APPEND,
488+
Operation.OVERWRITE,
489+
Operation.APPEND,
490+
]
491+
492+
assert tbl.scan().to_arrow()["ints"].to_pylist() == [3, 4, 1, None]
493+
494+
495+
def test_delete_overwrite_table_with_nan(session_catalog: RestCatalog) -> None:
496+
arrow_schema = pa.schema([pa.field("floats", pa.float32())])
497+
498+
# Create Arrow Table with NaN values
499+
data = [pa.array([1.0, float("nan"), 2.0], type=pa.float32())]
500+
arrow_tbl = pa.Table.from_arrays(
501+
data,
502+
schema=arrow_schema,
503+
)
504+
505+
iceberg_schema = Schema(NestedField(1, "floats", FloatType()))
506+
507+
tbl_identifier = "default.test_delete_overwrite_with_nan"
508+
509+
try:
510+
session_catalog.drop_table(tbl_identifier)
511+
except NoSuchTableError:
512+
pass
513+
514+
tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
515+
tbl.append(arrow_tbl)
516+
517+
assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]
518+
519+
arrow_tbl_overwrite = pa.Table.from_pylist(
520+
[
521+
{"floats": 3.0},
522+
{"floats": 4.0},
523+
],
524+
schema=arrow_schema,
525+
)
526+
"""
527+
We want to test that the _expression_to_complementary_pyarrow function can generate a correct complementary filter
528+
for selecting records to remain in the new overwritten file.
529+
Compared with test_delete_overwrite_table_with_null which tests rows with null cells,
530+
nan testing is faced with a more tricky issue:
531+
A filter of (col == val) will not include rows where col is NaN, but (col != val) will.
532+
(Interestingly, neither == nor != will include null.)
533+
534+
This means if we set the test case as floats == 2.0 (equal predicate as in test_delete_overwrite_table_with_null),
535+
test will pass even without the logic under test
536+
in _NullNaNUnmentionedTermsCollector (a helper of _expression_to_complementary_pyarrow
537+
to handle revert of iceberg expression of is_null/not_null/is_nan/not_nan).
538+
Instead, we test the filter of !=, so that the revert is == which exposes the issue.
539+
"""
540+
tbl.overwrite(arrow_tbl_overwrite, "floats != 2.0") # Should rewrite one file
541+
542+
assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
543+
Operation.APPEND,
544+
Operation.OVERWRITE,
545+
Operation.APPEND,
546+
]
547+
548+
result = tbl.scan().to_arrow()["floats"].to_pylist()
549+
550+
from math import isnan
551+
552+
assert any(isnan(e) for e in result)
553+
assert 2.0 in result
554+
assert 3.0 in result
555+
assert 4.0 in result

0 commit comments

Comments
 (0)