Skip to content

Commit 924ee73

Browse files
Aniketsy and Fokko
authored
Make UnaryPredicate JSON Serializable (#2598)
#2522 This PR makes the `UnaryPredicate` class and its subclasses (`IsNull`, `NotNull`, `IsNaN`, `NotNaN`) JSON serializable using Pydantic. - Adds unit tests to verify JSON serialization for `IsNull` and `NotNull`. Please let me know if my approach or fix needs any improvements. I'm open to feedback and happy to make changes based on suggestions. Thank you! --------- Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 9ce7619 commit 924ee73

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,20 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression
454454
def as_bound(self) -> Type[BoundPredicate[L]]: ...
455455

456456

457-
class UnaryPredicate(UnboundPredicate[Any], ABC):
457+
class UnaryPredicate(IcebergBaseModel, UnboundPredicate[Any], ABC):
458+
type: str
459+
460+
model_config = {"arbitrary_types_allowed": True}
461+
462+
def __init__(self, term: Union[str, UnboundTerm[Any]]):
463+
unbound = _to_unbound_term(term)
464+
super().__init__(term=unbound)
465+
466+
def __str__(self) -> str:
467+
"""Return the string representation of the UnaryPredicate class."""
468+
# Sort to make it deterministic
469+
return f"{str(self.__class__.__name__)}(term={str(self.term)})"
470+
458471
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]:
459472
bound_term = self.term.bind(schema, case_sensitive)
460473
return self.as_bound(bound_term)
@@ -513,6 +526,8 @@ def as_unbound(self) -> Type[NotNull]:
513526

514527

515528
class IsNull(UnaryPredicate):
529+
type: str = "is-null"
530+
516531
def __invert__(self) -> NotNull:
517532
"""Transform the Expression into its negated version."""
518533
return NotNull(self.term)
@@ -523,6 +538,8 @@ def as_bound(self) -> Type[BoundIsNull[L]]:
523538

524539

525540
class NotNull(UnaryPredicate):
541+
type: str = "not-null"
542+
526543
def __invert__(self) -> IsNull:
527544
"""Transform the Expression into its negated version."""
528545
return IsNull(self.term)
@@ -565,6 +582,8 @@ def as_unbound(self) -> Type[NotNaN]:
565582

566583

567584
class IsNaN(UnaryPredicate):
585+
type: str = "is-nan"
586+
568587
def __invert__(self) -> NotNaN:
569588
"""Transform the Expression into its negated version."""
570589
return NotNaN(self.term)
@@ -575,6 +594,8 @@ def as_bound(self) -> Type[BoundIsNaN[L]]:
575594

576595

577596
class NotNaN(UnaryPredicate):
597+
type: str = "not-nan"
598+
578599
def __invert__(self) -> IsNaN:
579600
"""Transform the Expression into its negated version."""
580601
return IsNaN(self.term)

tests/expressions/test_expressions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def test_and() -> None:
694694
assert and_ == pickle.loads(pickle.dumps(and_))
695695

696696
with pytest.raises(ValueError, match="Expected BooleanExpression, got: abc"):
697-
null & "abc" # type: ignore
697+
null & "abc"
698698

699699

700700
def test_or() -> None:
@@ -711,7 +711,7 @@ def test_or() -> None:
711711
assert or_ == pickle.loads(pickle.dumps(or_))
712712

713713
with pytest.raises(ValueError, match="Expected BooleanExpression, got: abc"):
714-
null | "abc" # type: ignore
714+
null | "abc"
715715

716716

717717
def test_or_serialization() -> None:
@@ -791,6 +791,16 @@ def test_not_null() -> None:
791791
assert non_null == pickle.loads(pickle.dumps(non_null))
792792

793793

794+
def test_serialize_is_null() -> None:
795+
pred = IsNull(term="foo")
796+
assert pred.model_dump_json() == '{"term":"foo","type":"is-null"}'
797+
798+
799+
def test_serialize_not_null() -> None:
800+
pred = NotNull(term="foo")
801+
assert pred.model_dump_json() == '{"term":"foo","type":"not-null"}'
802+
803+
794804
def test_bound_is_nan(accessor: Accessor) -> None:
795805
# We need a FloatType here
796806
term = BoundReference[float](

0 commit comments

Comments
 (0)