Skip to content

Commit 617e258

Browse files
authored
feat: make LiteralPredicate serializable via internal IcebergBaseModel (#2561)
<!-- Thanks for opening a pull request! -->
<!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual GitHub issue id. -->

Closes [#2523](#2523)

# Rationale for this change

### Spec alignment

`LiteralPredicate.type` uses the same enum as the REST OpenAPI `LiteralExpression.type`: `"lt" | "lt-eq" | "gt" | "gt-eq" | "eq" | "not-eq" | "starts-with" | "not-starts-with"`.

Source: OpenAPI spec (`LiteralExpression`). Ref: https://github.com/apache/iceberg/blob/b987e60bbd581d6e9e583107d5a85022261ff0d8/open-api/rest-catalog-open-api.yaml#L2264

## Are these changes tested?

Yes.

## Are there any user-facing changes?

<!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 9bff326 commit 617e258

File tree

4 files changed

+86
-29
lines changed

4 files changed

+86
-29
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -743,12 +743,18 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
743743
return BoundNotIn[L]
744744

745745

746-
class LiteralPredicate(UnboundPredicate[L], ABC):
747-
literal: Literal[L]
746+
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
747+
type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
748+
term: UnboundTerm[Any]
749+
value: Literal[L] = Field()
750+
model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True)
748751

749-
def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
750-
super().__init__(term)
751-
self.literal = _to_literal(literal) # pylint: disable=W0621
752+
def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]):
753+
super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg]
754+
755+
@property
756+
def literal(self) -> Literal[L]:
757+
return self.value
752758

753759
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
754760
bound_term = self.term.bind(schema, case_sensitive)
@@ -773,6 +779,10 @@ def __eq__(self, other: Any) -> bool:
773779
return self.term == other.term and self.literal == other.literal
774780
return False
775781

782+
def __str__(self) -> str:
783+
"""Return the string representation of the LiteralPredicate class."""
784+
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
785+
776786
def __repr__(self) -> str:
777787
"""Return the string representation of the LiteralPredicate class."""
778788
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
@@ -886,6 +896,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:
886896

887897

888898
class EqualTo(LiteralPredicate[L]):
899+
type: TypingLiteral["eq"] = Field(default="eq", alias="type")
900+
889901
def __invert__(self) -> NotEqualTo[L]:
890902
"""Transform the Expression into its negated version."""
891903
return NotEqualTo[L](self.term, self.literal)
@@ -896,6 +908,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:
896908

897909

898910
class NotEqualTo(LiteralPredicate[L]):
911+
type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type")
912+
899913
def __invert__(self) -> EqualTo[L]:
900914
"""Transform the Expression into its negated version."""
901915
return EqualTo[L](self.term, self.literal)
@@ -906,6 +920,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:
906920

907921

908922
class LessThan(LiteralPredicate[L]):
923+
type: TypingLiteral["lt"] = Field(default="lt", alias="type")
924+
909925
def __invert__(self) -> GreaterThanOrEqual[L]:
910926
"""Transform the Expression into its negated version."""
911927
return GreaterThanOrEqual[L](self.term, self.literal)
@@ -916,6 +932,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:
916932

917933

918934
class GreaterThanOrEqual(LiteralPredicate[L]):
935+
type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type")
936+
919937
def __invert__(self) -> LessThan[L]:
920938
"""Transform the Expression into its negated version."""
921939
return LessThan[L](self.term, self.literal)
@@ -926,6 +944,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:
926944

927945

928946
class GreaterThan(LiteralPredicate[L]):
947+
type: TypingLiteral["gt"] = Field(default="gt", alias="type")
948+
929949
def __invert__(self) -> LessThanOrEqual[L]:
930950
"""Transform the Expression into its negated version."""
931951
return LessThanOrEqual[L](self.term, self.literal)
@@ -936,6 +956,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:
936956

937957

938958
class LessThanOrEqual(LiteralPredicate[L]):
959+
type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type")
960+
939961
def __invert__(self) -> GreaterThan[L]:
940962
"""Transform the Expression into its negated version."""
941963
return GreaterThan[L](self.term, self.literal)
@@ -946,6 +968,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:
946968

947969

948970
class StartsWith(LiteralPredicate[L]):
971+
type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type")
972+
949973
def __invert__(self) -> NotStartsWith[L]:
950974
"""Transform the Expression into its negated version."""
951975
return NotStartsWith[L](self.term, self.literal)
@@ -956,6 +980,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:
956980

957981

958982
class NotStartsWith(LiteralPredicate[L]):
983+
type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type")
984+
959985
def __invert__(self) -> StartsWith[L]:
960986
"""Transform the Expression into its negated version."""
961987
return StartsWith[L](self.term, self.literal)

pyiceberg/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.Mo
120120
raise NotInstalledError(msg) from None
121121

122122

123-
def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
123+
def _transform_literal(func: Callable[[Any], Any], lit: Literal[L]) -> Literal[L]:
124124
"""Small helper to upwrap the value from the literal, and wrap it again."""
125125
return literal(func(lit.value))
126126

tests/expressions/test_evaluator.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pyiceberg.conversions import to_bytes
2323
from pyiceberg.expressions import (
2424
And,
25+
BooleanExpression,
2526
EqualTo,
2627
GreaterThan,
2728
GreaterThanOrEqual,
@@ -30,6 +31,7 @@
3031
IsNull,
3132
LessThan,
3233
LessThanOrEqual,
34+
LiteralPredicate,
3335
Not,
3436
NotEqualTo,
3537
NotIn,
@@ -301,7 +303,7 @@ def test_missing_stats() -> None:
301303
upper_bounds=None,
302304
)
303305

304-
expressions = [
306+
expressions: list[BooleanExpression] = [
305307
LessThan("no_stats", 5),
306308
LessThanOrEqual("no_stats", 30),
307309
EqualTo("no_stats", 70),
@@ -324,7 +326,7 @@ def test_zero_record_file_stats(schema_data_file: Schema) -> None:
324326
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
325327
)
326328

327-
expressions = [
329+
expressions: list[BooleanExpression] = [
328330
LessThan("no_stats", 5),
329331
LessThanOrEqual("no_stats", 30),
330332
EqualTo("no_stats", 70),
@@ -683,26 +685,27 @@ def data_file_nan() -> DataFile:
683685

684686

685687
def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None:
686-
for operator in [LessThan, LessThanOrEqual]: # type: ignore
687-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
688+
operators: tuple[type[LiteralPredicate[Any]], ...] = (LessThan, LessThanOrEqual)
689+
for operator in operators:
690+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
688691
assert not should_read, "Should not match: all nan column doesn't contain number"
689692

690-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
693+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
691694
assert not should_read, "Should not match: 1 is smaller than lower bound"
692695

693-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
696+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
694697
assert should_read, "Should match: 10 is larger than lower bound"
695698

696-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
699+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
697700
assert should_read, "Should match: no visibility"
698701

699-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
702+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
700703
assert not should_read, "Should not match: all nan column doesn't contain number"
701704

702-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
705+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
703706
assert not should_read, "Should not match: 1 is smaller than lower bound"
704707

705-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
708+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
706709
data_file_nan
707710
)
708711
assert should_read, "Should match: 10 larger than lower bound"
@@ -711,31 +714,32 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f
711714
def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal(
712715
schema_data_file_nan: Schema, data_file_nan: DataFile
713716
) -> None:
714-
for operator in [GreaterThan, GreaterThanOrEqual]: # type: ignore
715-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
717+
operators: tuple[type[LiteralPredicate[Any]], ...] = (GreaterThan, GreaterThanOrEqual)
718+
for operator in operators:
719+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
716720
assert not should_read, "Should not match: all nan column doesn't contain number"
717721

718-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
722+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
719723
assert should_read, "Should match: upper bound is larger than 1"
720724

721-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
725+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
722726
assert should_read, "Should match: upper bound is larger than 10"
723727

724-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
728+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
725729
assert should_read, "Should match: no visibility"
726730

727-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
731+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
728732
assert not should_read, "Should not match: all nan column doesn't contain number"
729733

730-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
734+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
731735
assert should_read, "Should match: 1 is smaller than upper bound"
732736

733-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
737+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
734738
data_file_nan
735739
)
736740
assert should_read, "Should match: 10 is smaller than upper bound"
737741

738-
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan) # type: ignore[arg-type]
742+
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan)
739743
assert not should_read, "Should not match: 30 is greater than upper bound"
740744

741745

@@ -1162,7 +1166,7 @@ def test_strict_missing_stats(strict_data_file_schema: Schema, strict_data_file_
11621166
upper_bounds=None,
11631167
)
11641168

1165-
expressions = [
1169+
expressions: list[BooleanExpression] = [
11661170
LessThan("no_stats", 5),
11671171
LessThanOrEqual("no_stats", 30),
11681172
EqualTo("no_stats", 70),
@@ -1185,7 +1189,7 @@ def test_strict_zero_record_file_stats(strict_data_file_schema: Schema) -> None:
11851189
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
11861190
)
11871191

1188-
expressions = [
1192+
expressions: list[BooleanExpression] = [
11891193
LessThan("no_stats", 5),
11901194
LessThanOrEqual("no_stats", 30),
11911195
EqualTo("no_stats", 70),

tests/expressions/test_expressions.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,22 @@
5050
IsNull,
5151
LessThan,
5252
LessThanOrEqual,
53+
LiteralPredicate,
5354
Not,
5455
NotEqualTo,
5556
NotIn,
5657
NotNaN,
5758
NotNull,
59+
NotStartsWith,
5860
Or,
5961
Reference,
62+
StartsWith,
6063
UnboundPredicate,
6164
)
6265
from pyiceberg.expressions.literals import Literal, literal
6366
from pyiceberg.expressions.visitors import _from_byte_buffer
6467
from pyiceberg.schema import Accessor, Schema
65-
from pyiceberg.typedef import Record
68+
from pyiceberg.typedef import L, Record
6669
from pyiceberg.types import (
6770
DecimalType,
6871
DoubleType,
@@ -915,6 +918,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:
915918

916919
def test_equal_to() -> None:
917920
equal_to = EqualTo(Reference("a"), literal("a"))
921+
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"a"}'
918922
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
919923
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
920924
assert equal_to == eval(repr(equal_to))
@@ -923,6 +927,7 @@ def test_equal_to() -> None:
923927

924928
def test_not_equal_to() -> None:
925929
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
930+
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"a"}'
926931
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
927932
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
928933
assert not_equal_to == eval(repr(not_equal_to))
@@ -931,6 +936,7 @@ def test_not_equal_to() -> None:
931936

932937
def test_greater_than_or_equal_to() -> None:
933938
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
939+
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"a"}'
934940
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
935941
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
936942
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
@@ -939,6 +945,7 @@ def test_greater_than_or_equal_to() -> None:
939945

940946
def test_greater_than() -> None:
941947
greater_than = GreaterThan(Reference("a"), literal("a"))
948+
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"a"}'
942949
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
943950
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
944951
assert greater_than == eval(repr(greater_than))
@@ -947,6 +954,7 @@ def test_greater_than() -> None:
947954

948955
def test_less_than() -> None:
949956
less_than = LessThan(Reference("a"), literal("a"))
957+
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"a"}'
950958
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
951959
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
952960
assert less_than == eval(repr(less_than))
@@ -955,12 +963,23 @@ def test_less_than() -> None:
955963

956964
def test_less_than_or_equal() -> None:
957965
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
966+
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"a"}'
958967
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
959968
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
960969
assert less_than_or_equal == eval(repr(less_than_or_equal))
961970
assert less_than_or_equal == pickle.loads(pickle.dumps(less_than_or_equal))
962971

963972

973+
def test_starts_with() -> None:
974+
starts_with = StartsWith(Reference("a"), literal("a"))
975+
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"a"}'
976+
977+
978+
def test_not_starts_with() -> None:
979+
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
980+
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"a"}'
981+
982+
964983
def test_bound_reference_eval(table_schema_simple: Schema) -> None:
965984
"""Test creating a BoundReference and evaluating it on a StructProtocol"""
966985
struct = Record("foovalue", 123, True)
@@ -1199,7 +1218,15 @@ def test_bind_ambiguous_name() -> None:
11991218
# |_| |_|\_, |_| \_, |
12001219
# |__/ |__/
12011220

1202-
assert_type(EqualTo("a", "b"), EqualTo[str])
1221+
1222+
def _assert_literal_predicate_type(expr: LiteralPredicate[L]) -> None:
1223+
assert_type(expr, LiteralPredicate[L])
1224+
1225+
1226+
_assert_literal_predicate_type(EqualTo("a", "b"))
1227+
_assert_literal_predicate_type(In("a", ("a", "b", "c")))
1228+
_assert_literal_predicate_type(In("a", (1, 2, 3)))
1229+
_assert_literal_predicate_type(NotIn("a", ("a", "b", "c")))
12031230
assert_type(In("a", ("a", "b", "c")), In[str])
12041231
assert_type(In("a", (1, 2, 3)), In[int])
12051232
assert_type(NotIn("a", ("a", "b", "c")), NotIn[str])

0 commit comments

Comments
 (0)