|
73 | 73 |
|
74 | 74 | from pyiceberg.conversions import to_bytes
|
75 | 75 | from pyiceberg.exceptions import ResolveError
|
76 |
| -from pyiceberg.expressions import ( |
77 |
| - AlwaysTrue, |
78 |
| - BooleanExpression, |
79 |
| - BoundTerm, |
80 |
| -) |
| 76 | +from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or |
81 | 77 | from pyiceberg.expressions.literals import Literal
|
82 | 78 | from pyiceberg.expressions.visitors import (
|
83 | 79 | BoundBooleanExpressionVisitor,
|
@@ -576,11 +572,11 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
|
576 | 572 |
|
577 | 573 |
|
578 | 574 | class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
|
579 |
| - def visit_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression: |
| 575 | + def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression: |
580 | 576 | pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
|
581 | 577 | return pc.field(term.ref().field.name).isin(pyarrow_literals)
|
582 | 578 |
|
583 |
| - def visit_not_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression: |
| 579 | + def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression: |
584 | 580 | pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
|
585 | 581 | return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
|
586 | 582 |
|
@@ -638,10 +634,152 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
|
638 | 634 | return left_result | right_result
|
639 | 635 |
|
640 | 636 |
|
| 637 | +class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]): |
| 638 | + # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr. |
| 639 | + is_null_or_not_bound_terms: set[BoundTerm[Any]] |
| 640 | + # The remaining BoundTerms appearing in the boolean expr. |
| 641 | + null_unmentioned_bound_terms: set[BoundTerm[Any]] |
| 642 | + # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr. |
| 643 | + is_nan_or_not_bound_terms: set[BoundTerm[Any]] |
| 644 | + # The remaining BoundTerms appearing in the boolean expr. |
| 645 | + nan_unmentioned_bound_terms: set[BoundTerm[Any]] |
| 646 | + |
| 647 | + def __init__(self) -> None: |
| 648 | + super().__init__() |
| 649 | + self.is_null_or_not_bound_terms = set() |
| 650 | + self.null_unmentioned_bound_terms = set() |
| 651 | + self.is_nan_or_not_bound_terms = set() |
| 652 | + self.nan_unmentioned_bound_terms = set() |
| 653 | + |
| 654 | + def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None: |
| 655 | + """Handle the predicate case where either is_null or is_not_null is included.""" |
| 656 | + if term in self.null_unmentioned_bound_terms: |
| 657 | + self.null_unmentioned_bound_terms.remove(term) |
| 658 | + self.is_null_or_not_bound_terms.add(term) |
| 659 | + |
| 660 | + def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None: |
| 661 | + """Handle the predicate case where neither is_null or is_not_null is included.""" |
| 662 | + if term not in self.is_null_or_not_bound_terms: |
| 663 | + self.null_unmentioned_bound_terms.add(term) |
| 664 | + |
| 665 | + def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None: |
| 666 | + """Handle the predicate case where either is_nan or is_not_nan is included.""" |
| 667 | + if term in self.nan_unmentioned_bound_terms: |
| 668 | + self.nan_unmentioned_bound_terms.remove(term) |
| 669 | + self.is_nan_or_not_bound_terms.add(term) |
| 670 | + |
| 671 | + def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None: |
| 672 | + """Handle the predicate case where neither is_nan or is_not_nan is included.""" |
| 673 | + if term not in self.is_nan_or_not_bound_terms: |
| 674 | + self.nan_unmentioned_bound_terms.add(term) |
| 675 | + |
| 676 | + def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None: |
| 677 | + self._handle_null_unmentioned(term) |
| 678 | + self._handle_nan_unmentioned(term) |
| 679 | + |
| 680 | + def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None: |
| 681 | + self._handle_null_unmentioned(term) |
| 682 | + self._handle_nan_unmentioned(term) |
| 683 | + |
| 684 | + def visit_is_nan(self, term: BoundTerm[Any]) -> None: |
| 685 | + self._handle_null_unmentioned(term) |
| 686 | + self._handle_explicit_is_nan_or_not(term) |
| 687 | + |
| 688 | + def visit_not_nan(self, term: BoundTerm[Any]) -> None: |
| 689 | + self._handle_null_unmentioned(term) |
| 690 | + self._handle_explicit_is_nan_or_not(term) |
| 691 | + |
| 692 | + def visit_is_null(self, term: BoundTerm[Any]) -> None: |
| 693 | + self._handle_explicit_is_null_or_not(term) |
| 694 | + self._handle_nan_unmentioned(term) |
| 695 | + |
| 696 | + def visit_not_null(self, term: BoundTerm[Any]) -> None: |
| 697 | + self._handle_explicit_is_null_or_not(term) |
| 698 | + self._handle_nan_unmentioned(term) |
| 699 | + |
| 700 | + def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 701 | + self._handle_null_unmentioned(term) |
| 702 | + self._handle_nan_unmentioned(term) |
| 703 | + |
| 704 | + def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 705 | + self._handle_null_unmentioned(term) |
| 706 | + self._handle_nan_unmentioned(term) |
| 707 | + |
| 708 | + def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 709 | + self._handle_null_unmentioned(term) |
| 710 | + self._handle_nan_unmentioned(term) |
| 711 | + |
| 712 | + def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 713 | + self._handle_null_unmentioned(term) |
| 714 | + self._handle_nan_unmentioned(term) |
| 715 | + |
| 716 | + def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 717 | + self._handle_null_unmentioned(term) |
| 718 | + self._handle_nan_unmentioned(term) |
| 719 | + |
| 720 | + def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 721 | + self._handle_null_unmentioned(term) |
| 722 | + self._handle_nan_unmentioned(term) |
| 723 | + |
| 724 | + def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 725 | + self._handle_null_unmentioned(term) |
| 726 | + self._handle_nan_unmentioned(term) |
| 727 | + |
| 728 | + def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: |
| 729 | + self._handle_null_unmentioned(term) |
| 730 | + self._handle_nan_unmentioned(term) |
| 731 | + |
| 732 | + def visit_true(self) -> None: |
| 733 | + return |
| 734 | + |
| 735 | + def visit_false(self) -> None: |
| 736 | + return |
| 737 | + |
| 738 | + def visit_not(self, child_result: None) -> None: |
| 739 | + return |
| 740 | + |
| 741 | + def visit_and(self, left_result: None, right_result: None) -> None: |
| 742 | + return |
| 743 | + |
| 744 | + def visit_or(self, left_result: None, right_result: None) -> None: |
| 745 | + return |
| 746 | + |
| 747 | + def collect( |
| 748 | + self, |
| 749 | + expr: BooleanExpression, |
| 750 | + ) -> None: |
| 751 | + """Collect the bound references categorized by having at least one is_null or is_not_null in the expr and the remaining.""" |
| 752 | + boolean_expression_visit(expr, self) |
| 753 | + |
| 754 | + |
641 | 755 | def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
|
642 | 756 | return boolean_expression_visit(expr, _ConvertToArrowExpression())
|
643 | 757 |
|
644 | 758 |
|
| 759 | +def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression: |
| 760 | + """Complementary filter conversion function of expression_to_pyarrow. |
| 761 | +
|
| 762 | + Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null. |
| 763 | + """ |
| 764 | + collector = _NullNaNUnmentionedTermsCollector() |
| 765 | + collector.collect(expr) |
| 766 | + |
| 767 | + # Convert the set of terms to a sorted list so that layout of the expression to build is deterministic. |
| 768 | + null_unmentioned_bound_terms: List[BoundTerm[Any]] = sorted( |
| 769 | + collector.null_unmentioned_bound_terms, key=lambda term: term.ref().field.name |
| 770 | + ) |
| 771 | + nan_unmentioned_bound_terms: List[BoundTerm[Any]] = sorted( |
| 772 | + collector.nan_unmentioned_bound_terms, key=lambda term: term.ref().field.name |
| 773 | + ) |
| 774 | + |
| 775 | + preserve_expr: BooleanExpression = Not(expr) |
| 776 | + for term in null_unmentioned_bound_terms: |
| 777 | + preserve_expr = Or(preserve_expr, BoundIsNull(term=term)) |
| 778 | + for term in nan_unmentioned_bound_terms: |
| 779 | + preserve_expr = Or(preserve_expr, BoundIsNaN(term=term)) |
| 780 | + return expression_to_pyarrow(preserve_expr) |
| 781 | + |
| 782 | + |
645 | 783 | @lru_cache
|
646 | 784 | def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
|
647 | 785 | if file_format == FileFormat.PARQUET:
|
|
0 commit comments