Skip to content

Commit f1c2ef2

Browse files
FokkoHonahX
andauthored
Add Strict projection (#539)
* Add Strict projection * Update pyiceberg/expressions/visitors.py Co-authored-by: Honah J. <[email protected]> * Comments, thanks Honah! --------- Co-authored-by: Honah J. <[email protected]>
1 parent 0d12cf4 commit f1c2ef2

File tree

4 files changed

+1029
-10
lines changed

4 files changed

+1029
-10
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,6 +1433,30 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool
14331433
return ROWS_MIGHT_MATCH
14341434

14351435

1436+
def strict_projection(
1437+
schema: Schema, spec: PartitionSpec, case_sensitive: bool = True
1438+
) -> Callable[[BooleanExpression], BooleanExpression]:
1439+
return StrictProjection(schema, spec, case_sensitive).project
1440+
1441+
1442+
class StrictProjection(ProjectionEvaluator):
1443+
def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression:
1444+
parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id)
1445+
1446+
result: BooleanExpression = AlwaysFalse()
1447+
for part in parts:
1448+
# consider (ts > 2019-01-01T01:00:00) with day(ts) and hour(ts)
1449+
# projections: d >= 2019-01-02 and h >= 2019-01-01-02 (note the inclusive bounds).
1450+
# any timestamp where either projection predicate is true must match the original
1451+
# predicate. For example, ts = 2019-01-01T03:00:00 matches the hour projection but not
1452+
# the day, but does match the original predicate.
1453+
strict_projection = part.transform.strict_project(name=part.name, pred=predicate)
1454+
if strict_projection is not None:
1455+
result = Or(result, strict_projection)
1456+
1457+
return result
1458+
1459+
14361460
class _StrictMetricsEvaluator(_MetricsEvaluator):
14371461
struct: StructType
14381462
expr: BooleanExpression

pyiceberg/transforms.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
BoundLessThan,
3636
BoundLessThanOrEqual,
3737
BoundLiteralPredicate,
38+
BoundNotEqualTo,
3839
BoundNotIn,
3940
BoundNotStartsWith,
4041
BoundPredicate,
@@ -43,8 +44,11 @@
4344
BoundTerm,
4445
BoundUnaryPredicate,
4546
EqualTo,
47+
GreaterThan,
4648
GreaterThanOrEqual,
49+
LessThan,
4750
LessThanOrEqual,
51+
NotEqualTo,
4852
NotStartsWith,
4953
Reference,
5054
StartsWith,
@@ -144,6 +148,9 @@ def result_type(self, source: IcebergType) -> IcebergType: ...
144148
@abstractmethod
145149
def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]: ...
146150

151+
@abstractmethod
152+
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]: ...
153+
147154
@property
148155
def preserves_order(self) -> bool:
149156
return False
@@ -216,6 +223,21 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
216223
# For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected.
217224
return None
218225

226+
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
227+
transformer = self.transform(pred.term.ref().field.field_type)
228+
229+
if isinstance(pred.term, BoundTransform):
230+
return _project_transform_predicate(self, name, pred)
231+
elif isinstance(pred, BoundUnaryPredicate):
232+
return pred.as_unbound(Reference(name))
233+
elif isinstance(pred, BoundNotEqualTo):
234+
return pred.as_unbound(Reference(name), _transform_literal(transformer, pred.literal))
235+
elif isinstance(pred, BoundNotIn):
236+
return pred.as_unbound(Reference(name), {_transform_literal(transformer, literal) for literal in pred.literals})
237+
else:
238+
# no strict projection for comparison or equality
239+
return None
240+
219241
def can_transform(self, source: IcebergType) -> bool:
220242
return isinstance(
221243
source,
@@ -306,6 +328,19 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
306328
else:
307329
return None
308330

331+
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
332+
transformer = self.transform(pred.term.ref().field.field_type)
333+
if isinstance(pred.term, BoundTransform):
334+
return _project_transform_predicate(self, name, pred)
335+
elif isinstance(pred, BoundUnaryPredicate):
336+
return pred.as_unbound(Reference(name))
337+
elif isinstance(pred, BoundLiteralPredicate):
338+
return _truncate_number_strict(name, pred, transformer)
339+
elif isinstance(pred, BoundNotIn):
340+
return _set_apply_transform(name, pred, transformer)
341+
else:
342+
return None
343+
309344
@property
310345
def dedup_name(self) -> str:
311346
return "time"
@@ -516,10 +551,20 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
516551
return pred.as_unbound(Reference(name))
517552
elif isinstance(pred, BoundLiteralPredicate):
518553
return pred.as_unbound(Reference(name), pred.literal)
519-
elif isinstance(pred, (BoundIn, BoundNotIn)):
554+
elif isinstance(pred, BoundSetPredicate):
555+
return pred.as_unbound(Reference(name), pred.literals)
556+
else:
557+
return None
558+
559+
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
560+
if isinstance(pred, BoundUnaryPredicate):
561+
return pred.as_unbound(Reference(name))
562+
elif isinstance(pred, BoundLiteralPredicate):
563+
return pred.as_unbound(Reference(name), pred.literal)
564+
elif isinstance(pred, BoundSetPredicate):
520565
return pred.as_unbound(Reference(name), pred.literals)
521566
else:
522-
raise ValueError(f"Could not project: {pred}")
567+
return None
523568

524569
@property
525570
def preserves_order(self) -> bool:
@@ -590,6 +635,47 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
590635
return _truncate_array(name, pred, self.transform(field_type))
591636
return None
592637

638+
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
639+
field_type = pred.term.ref().field.field_type
640+
641+
if isinstance(pred.term, BoundTransform):
642+
return _project_transform_predicate(self, name, pred)
643+
644+
if isinstance(field_type, (IntegerType, LongType, DecimalType)):
645+
if isinstance(pred, BoundUnaryPredicate):
646+
return pred.as_unbound(Reference(name))
647+
elif isinstance(pred, BoundLiteralPredicate):
648+
return _truncate_number_strict(name, pred, self.transform(field_type))
649+
elif isinstance(pred, BoundNotIn):
650+
return _set_apply_transform(name, pred, self.transform(field_type))
651+
else:
652+
return None
653+
654+
if isinstance(pred, BoundLiteralPredicate):
655+
if isinstance(pred, BoundStartsWith):
656+
literal_width = len(pred.literal.value)
657+
if literal_width < self.width:
658+
return pred.as_unbound(name, pred.literal.value)
659+
elif literal_width == self.width:
660+
return EqualTo(name, pred.literal.value)
661+
else:
662+
return None
663+
elif isinstance(pred, BoundNotStartsWith):
664+
literal_width = len(pred.literal.value)
665+
if literal_width < self.width:
666+
return pred.as_unbound(name, pred.literal.value)
667+
elif literal_width == self.width:
668+
return NotEqualTo(name, pred.literal.value)
669+
else:
670+
return pred.as_unbound(name, self.transform(field_type)(pred.literal.value))
671+
else:
672+
# ProjectionUtil.truncateArrayStrict(name, pred, this);
673+
return _truncate_array_strict(name, pred, self.transform(field_type))
674+
elif isinstance(pred, BoundNotIn):
675+
return _set_apply_transform(name, pred, self.transform(field_type))
676+
else:
677+
return None
678+
593679
@property
594680
def width(self) -> int:
595681
return self._width
@@ -714,6 +800,9 @@ def result_type(self, source: IcebergType) -> StringType:
714800
def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
715801
return None
716802

803+
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
804+
return None
805+
717806
def __repr__(self) -> str:
718807
"""Return the string representation of the UnknownTransform class."""
719808
return f"UnknownTransform(transform={repr(self._transform)})"
@@ -736,6 +825,9 @@ def result_type(self, source: IcebergType) -> IcebergType:
736825
def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
737826
return None
738827

828+
def strict_project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
829+
return None
830+
739831
def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
740832
return "null"
741833

@@ -766,6 +858,47 @@ def _truncate_number(
766858
return None
767859

768860

861+
def _truncate_number_strict(
862+
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
863+
) -> Optional[UnboundPredicate[Any]]:
864+
boundary = pred.literal
865+
866+
if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)):
867+
raise ValueError(f"Expected a numeric literal, got: {type(boundary)}")
868+
869+
if isinstance(pred, BoundLessThan):
870+
return LessThan(Reference(name), _transform_literal(transform, boundary))
871+
elif isinstance(pred, BoundLessThanOrEqual):
872+
return LessThan(Reference(name), _transform_literal(transform, boundary.increment())) # type: ignore
873+
elif isinstance(pred, BoundGreaterThan):
874+
return GreaterThan(Reference(name), _transform_literal(transform, boundary))
875+
elif isinstance(pred, BoundGreaterThanOrEqual):
876+
return GreaterThan(Reference(name), _transform_literal(transform, boundary.decrement())) # type: ignore
877+
elif isinstance(pred, BoundNotEqualTo):
878+
return EqualTo(Reference(name), _transform_literal(transform, boundary))
879+
elif isinstance(pred, BoundEqualTo):
880+
# there is no predicate that guarantees equality because adjacent longs transform to the
881+
# same value
882+
return None
883+
else:
884+
return None
885+
886+
887+
def _truncate_array_strict(
888+
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
889+
) -> Optional[UnboundPredicate[Any]]:
890+
boundary = pred.literal
891+
892+
if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
893+
return LessThan(Reference(name), _transform_literal(transform, boundary))
894+
elif isinstance(pred, (BoundGreaterThan, BoundGreaterThanOrEqual)):
895+
return GreaterThan(Reference(name), _transform_literal(transform, boundary))
896+
if isinstance(pred, BoundNotEqualTo):
897+
return NotEqualTo(Reference(name), _transform_literal(transform, boundary))
898+
else:
899+
return None
900+
901+
769902
def _truncate_array(
770903
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
771904
) -> Optional[UnboundPredicate[Any]]:
@@ -808,7 +941,8 @@ def _remove_transform(partition_name: str, pred: BoundPredicate[L]) -> UnboundPr
808941
def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: Callable[[L], L]) -> UnboundPredicate[Any]:
809942
literals = pred.literals
810943
if isinstance(pred, BoundSetPredicate):
811-
return pred.as_unbound(Reference(name), {_transform_literal(transform, literal) for literal in literals})
944+
transformed_literals = {_transform_literal(transform, literal) for literal in literals}
945+
return pred.as_unbound(Reference(name=name), literals=transformed_literals)
812946
else:
813947
raise ValueError(f"Unknown BoundSetPredicate: {pred}")
814948

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
NestedField,
8080
StringType,
8181
StructType,
82+
UUIDType,
8283
)
8384
from pyiceberg.utils.datetime import datetime_to_millis
8485

@@ -1928,6 +1929,16 @@ def bound_reference_str() -> BoundReference[str]:
19281929
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))
19291930

19301931

1932+
@pytest.fixture
1933+
def bound_reference_binary() -> BoundReference[str]:
1934+
return BoundReference(field=NestedField(1, "field", BinaryType(), required=False), accessor=Accessor(position=0, inner=None))
1935+
1936+
1937+
@pytest.fixture
1938+
def bound_reference_uuid() -> BoundReference[str]:
1939+
return BoundReference(field=NestedField(1, "field", UUIDType(), required=False), accessor=Accessor(position=0, inner=None))
1940+
1941+
19311942
@pytest.fixture(scope="session")
19321943
def session_catalog() -> Catalog:
19331944
return load_catalog(

0 commit comments

Comments
 (0)