Skip to content

Commit 4ff1558

Browse files
tusharchou and Fokko authored
Add ResidualVisitor to compute residuals (#1388)
closes issue: Count rows as a metadata-only operation #1223 --------- Co-authored-by: Fokko Driesprong <[email protected]>
1 parent d6497a5 commit 4ff1558

File tree

4 files changed

+702
-3
lines changed

4 files changed

+702
-3
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 231 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Generic,
2525
List,
2626
Set,
27+
SupportsFloat,
2728
Tuple,
2829
TypeVar,
2930
Union,
@@ -60,9 +61,9 @@
6061
)
6162
from pyiceberg.expressions.literals import Literal
6263
from pyiceberg.manifest import DataFile, ManifestFile, PartitionFieldSummary
63-
from pyiceberg.partitioning import PartitionSpec
64+
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
6465
from pyiceberg.schema import Schema
65-
from pyiceberg.typedef import EMPTY_DICT, L, StructProtocol
66+
from pyiceberg.typedef import EMPTY_DICT, L, Record, StructProtocol
6667
from pyiceberg.types import (
6768
DoubleType,
6869
FloatType,
@@ -1731,3 +1732,231 @@ def _can_contain_nulls(self, field_id: int) -> bool:
17311732

17321733
def _can_contain_nans(self, field_id: int) -> bool:
    """Return whether the recorded NaN count for ``field_id`` is positive.

    A field with no entry in ``self.nan_counts`` yields ``False``.
    """
    nan_count = self.nan_counts.get(field_id)
    if nan_count is None:
        return False
    return nan_count > 0
1735+
1736+
1737+
class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC):
    """Finds the residuals for an Expression for the partitions in the given PartitionSpec.

    A residual expression is made by partially evaluating an expression using partition values.
    For example, if a table is partitioned by day(utc_timestamp) and is read with a filter expression
    utc_timestamp > a and utc_timestamp < b, then there are 4 possible residual expressions
    for the partition data, d:

    1. If d > day(a) and d < day(b), the residual is always true
    2. If d == day(a) and d != day(b), the residual is utc_timestamp > a
    3. If d == day(b) and d != day(a), the residual is utc_timestamp < b
    4. If d == day(a) == day(b), the residual is utc_timestamp > a and utc_timestamp < b

    Partition data is passed to ``eval`` as a Record; ``eval`` returns the residual expression.
    """

    schema: Schema
    spec: PartitionSpec
    case_sensitive: bool
    expr: BooleanExpression
    # Partition data currently being evaluated; assigned by eval() before visiting.
    struct: Record

    def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression) -> None:
        """Store the table schema, partition spec, binding case-sensitivity and the expression to reduce."""
        self.schema = schema
        self.spec = spec
        self.case_sensitive = case_sensitive
        self.expr = expr

    def eval(self, partition_data: Record) -> BooleanExpression:
        """Return the residual of ``self.expr`` for the given partition data.

        The partition data is stored on the instance (``self.struct``) so the
        ``visit_*`` callbacks can evaluate terms against it.
        """
        self.struct = partition_data
        return visit(self.expr, visitor=self)

    def _to_constant(self, condition: bool) -> BooleanExpression:
        """Map a Python boolean onto the visitor's always-true / always-false result."""
        return self.visit_true() if condition else self.visit_false()

    def visit_true(self) -> BooleanExpression:
        """Return the constant true expression."""
        return AlwaysTrue()

    def visit_false(self) -> BooleanExpression:
        """Return the constant false expression."""
        return AlwaysFalse()

    def visit_not(self, child_result: BooleanExpression) -> BooleanExpression:
        """Negate the residual of the child expression."""
        return Not(child_result)

    def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
        """Combine the residuals of both children with a logical AND."""
        return And(left_result, right_result)

    def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
        """Combine the residuals of both children with a logical OR."""
        return Or(left_result, right_result)

    def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression:
        """Return always-true when the term evaluates to None for the partition data."""
        return self._to_constant(term.eval(self.struct) is None)

    def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression:
        """Return always-true when the term evaluates to a non-None value for the partition data."""
        return self._to_constant(term.eval(self.struct) is not None)

    def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression:
        """Return always-true when the term evaluates to a float-like NaN value."""
        val = term.eval(self.struct)
        return self._to_constant(isinstance(val, SupportsFloat) and math.isnan(val))

    def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression:
        """Return always-true unless the term evaluates to a float-like NaN value."""
        val = term.eval(self.struct)
        # Exact complement of visit_is_nan: a null or non-float value is not NaN,
        # so it must satisfy not_nan rather than be rejected.
        return self._to_constant(not (isinstance(val, SupportsFloat) and math.isnan(val)))

    def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's value is strictly less than the literal."""
        return self._to_constant(term.eval(self.struct) < literal.value)

    def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's value is less than or equal to the literal."""
        return self._to_constant(term.eval(self.struct) <= literal.value)

    def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's value is strictly greater than the literal."""
        return self._to_constant(term.eval(self.struct) > literal.value)

    def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's value is greater than or equal to the literal."""
        return self._to_constant(term.eval(self.struct) >= literal.value)

    def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's value equals the literal."""
        return self._to_constant(term.eval(self.struct) == literal.value)

    def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's value differs from the literal."""
        return self._to_constant(term.eval(self.struct) != literal.value)

    def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression:
        """Return always-true when the term's value is a member of ``literals``."""
        return self._to_constant(term.eval(self.struct) in literals)

    def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression:
        """Return always-true when the term's value is not a member of ``literals``."""
        return self._to_constant(term.eval(self.struct) not in literals)

    def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return always-true when the term's non-None value, as a string, starts with the literal."""
        eval_res = term.eval(self.struct)
        return self._to_constant(eval_res is not None and str(eval_res).startswith(str(literal.value)))

    def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
        """Return the logical complement of ``visit_starts_with``."""
        starts_with = self.visit_starts_with(term, literal)
        # visit_starts_with returns an expression object, not a bool, so inspect
        # its type; relying on truthiness would treat AlwaysFalse as true.
        return self._to_constant(not isinstance(starts_with, AlwaysTrue))

    def _evaluate_projection(
        self, projection: Union[UnboundPredicate[Any], None], partition_schema: Schema
    ) -> Union[BooleanExpression, None]:
        """Bind a projected predicate to the partition schema and evaluate it on the partition data.

        Returns None when there is no projection. If binding yields something other than a
        bound predicate, that result is returned unchanged.
        """
        if projection is None:
            return None

        bound = projection.bind(partition_schema, case_sensitive=self.case_sensitive)
        if isinstance(bound, BoundPredicate):
            # Evaluate through the base visitor's handling of bound predicates,
            # not this class's override, which would project the predicate again.
            return super().visit_bound_predicate(bound)

        # if the result is not a predicate, then it must be a constant like alwaysTrue or alwaysFalse
        return bound

    def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression:
        """
        If there is no strict projection or if it evaluates to false, then return the predicate.

        Get the strict projection and inclusive projection of this predicate in partition data,
        then use them to determine whether to return the original predicate. The strict projection
        returns true iff the original predicate would have returned true, so the predicate can be
        eliminated if the strict projection evaluates to true. Similarly the inclusive projection
        returns false iff the original predicate would have returned false, so the predicate can
        also be eliminated if the inclusive projection evaluates to false.
        """
        parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id)
        if not parts:
            # The predicate's source column is not partitioned on: nothing can be eliminated.
            return predicate

        # Loop-invariant: projections are bound against the partition struct, built once.
        partition_schema = Schema(*self.spec.partition_type(self.schema).fields)

        for part in parts:
            strict_projection = part.transform.strict_project(part.name, predicate)
            if isinstance(self._evaluate_projection(strict_projection, partition_schema), AlwaysTrue):
                return AlwaysTrue()

            inclusive_projection = part.transform.project(part.name, predicate)
            if isinstance(self._evaluate_projection(inclusive_projection, partition_schema), AlwaysFalse):
                return AlwaysFalse()

        return predicate

    def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression:
        """Bind the predicate to the table schema and compute its residual.

        When the residual is still a predicate, the original unbound predicate is returned
        so the result stays unbound; constant residuals are returned as they are.
        """
        bound = predicate.bind(self.schema, case_sensitive=self.case_sensitive)

        if isinstance(bound, BoundPredicate):
            bound_residual = self.visit_bound_predicate(predicate=bound)
            if not isinstance(bound_residual, (AlwaysFalse, AlwaysTrue)):
                # replace inclusive original unbound predicate
                return predicate

            # use the non-predicate residual (e.g. alwaysTrue)
            return bound_residual

        # if binding didn't result in a Predicate, return the expression
        return bound
1938+
1939+
1940+
class ResidualEvaluator(ResidualVisitor):
    """Residual visitor exposing ``residual_for`` as its entry point."""

    def residual_for(self, partition_data: Record) -> BooleanExpression:
        """Return the residual expression for ``partition_data`` by delegating to ``eval``."""
        residual = self.eval(partition_data)
        return residual
1943+
1944+
1945+
class UnpartitionedResidualEvaluator(ResidualEvaluator):
    """Residual evaluator bound to ``UNPARTITIONED_PARTITION_SPEC``.

    ``residual_for`` returns the original expression unchanged, whatever partition data is passed.
    """

    def __init__(self, schema: Schema, expr: BooleanExpression):
        """Initialise the parent with the unpartitioned spec; the parent stores ``expr``."""
        super().__init__(
            schema=schema,
            spec=UNPARTITIONED_PARTITION_SPEC,
            case_sensitive=False,
            expr=expr,
        )

    def residual_for(self, partition_data: Record) -> BooleanExpression:
        """Return the stored expression as-is; ``partition_data`` is not consulted."""
        return self.expr
1953+
1954+
1955+
def residual_evaluator_of(
    spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
) -> ResidualEvaluator:
    """Build the residual evaluator matching the given partition spec.

    An unpartitioned spec gets an ``UnpartitionedResidualEvaluator`` (``case_sensitive``
    is not forwarded in that case); any other spec gets a ``ResidualEvaluator``.
    """
    if spec.is_unpartitioned():
        return UnpartitionedResidualEvaluator(schema=schema, expr=expr)

    return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive)

0 commit comments

Comments
 (0)