Skip to content

Commit ce28e00

Browse files
authored
Make Literal Pydantic serializeable (#2575)
# Rationale for this change Resolves #2572 ## Are these changes tested? ## Are there any user-facing changes? <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 2f9bb3e commit ce28e00

File tree

4 files changed

+63
-20
lines changed

4 files changed

+63
-20
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def __new__( # type: ignore # pylint: disable=W0221
696696
if count == 0:
697697
return AlwaysFalse()
698698
elif count == 1:
699-
return EqualTo(term, next(iter(literals))) # type: ignore
699+
return EqualTo(term, next(iter(literals)))
700700
else:
701701
return super().__new__(cls)
702702

pyiceberg/expressions/literals.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
from typing import Any, Generic, Type
3131
from uuid import UUID
3232

33-
from pyiceberg.typedef import L
33+
from pydantic import Field, model_serializer
34+
35+
from pyiceberg.typedef import IcebergRootModel, L
3436
from pyiceberg.types import (
3537
BinaryType,
3638
BooleanType,
@@ -52,7 +54,9 @@
5254
date_str_to_days,
5355
date_to_days,
5456
datetime_to_micros,
57+
days_to_date,
5558
micros_to_days,
59+
micros_to_timestamp,
5660
time_str_to_micros,
5761
time_to_micros,
5862
timestamp_to_micros,
@@ -64,21 +68,24 @@
6468
UUID_BYTES_LENGTH = 16
6569

6670

67-
class Literal(Generic[L], ABC):
71+
class Literal(IcebergRootModel[L], Generic[L], ABC): # type: ignore
6872
"""Literal which has a value and can be converted between types."""
6973

70-
_value: L
74+
root: L = Field()
75+
76+
def __init__(self, value: L, value_type: Type[L], /, **data): # type: ignore
77+
if value is None:
78+
raise TypeError("Invalid literal value: None")
7179

72-
def __init__(self, value: L, value_type: Type[L]):
80+
super().__init__(value)
7381
if value is None or not isinstance(value, value_type):
7482
raise TypeError(f"Invalid literal value: {value!r} (not a {value_type})")
7583
if isinstance(value, float) and isnan(value):
7684
raise ValueError("Cannot create expression literal from NaN.")
77-
self._value = value
7885

7986
@property
8087
def value(self) -> L:
81-
return self._value
88+
return self.root
8289

8390
@singledispatchmethod
8491
@abstractmethod
@@ -136,25 +143,25 @@ def literal(value: L) -> Literal[L]:
136143
LongLiteral(123)
137144
"""
138145
if isinstance(value, float):
139-
return DoubleLiteral(value) # type: ignore
146+
return DoubleLiteral(value)
140147
elif isinstance(value, bool):
141148
return BooleanLiteral(value)
142149
elif isinstance(value, int):
143150
return LongLiteral(value)
144151
elif isinstance(value, str):
145152
return StringLiteral(value)
146153
elif isinstance(value, UUID):
147-
return UUIDLiteral(value.bytes) # type: ignore
154+
return UUIDLiteral(value.bytes)
148155
elif isinstance(value, bytes):
149156
return BinaryLiteral(value)
150157
elif isinstance(value, Decimal):
151158
return DecimalLiteral(value)
152159
elif isinstance(value, datetime):
153-
return TimestampLiteral(datetime_to_micros(value)) # type: ignore
160+
return TimestampLiteral(datetime_to_micros(value))
154161
elif isinstance(value, date):
155-
return DateLiteral(date_to_days(value)) # type: ignore
162+
return DateLiteral(date_to_days(value))
156163
elif isinstance(value, time):
157-
return TimeLiteral(time_to_micros(value)) # type: ignore
164+
return TimeLiteral(time_to_micros(value))
158165
else:
159166
raise TypeError(f"Invalid literal value: {repr(value)}")
160167

@@ -411,6 +418,10 @@ class DateLiteral(Literal[int]):
411418
def __init__(self, value: int) -> None:
412419
super().__init__(value, int)
413420

421+
@model_serializer
422+
def ser_model(self) -> date:
423+
return days_to_date(self.root)
424+
414425
def increment(self) -> Literal[int]:
415426
return DateLiteral(self.value + 1)
416427

@@ -443,6 +454,10 @@ class TimestampLiteral(Literal[int]):
443454
def __init__(self, value: int) -> None:
444455
super().__init__(value, int)
445456

457+
@model_serializer
458+
def ser_model(self) -> str:
459+
return micros_to_timestamp(self.root).isoformat()
460+
446461
def increment(self) -> Literal[int]:
447462
return TimestampLiteral(self.value + 1)
448463

@@ -635,6 +650,10 @@ class UUIDLiteral(Literal[bytes]):
635650
def __init__(self, value: bytes) -> None:
636651
super().__init__(value, bytes)
637652

653+
@model_serializer
654+
def ser_model(self) -> UUID:
655+
return UUID(bytes=self.root)
656+
638657
@singledispatchmethod
639658
def to(self, type_var: IcebergType) -> Literal: # type: ignore
640659
raise TypeError(f"Cannot convert UUIDLiteral into {type_var}")
@@ -661,6 +680,10 @@ class FixedLiteral(Literal[bytes]):
661680
def __init__(self, value: bytes) -> None:
662681
super().__init__(value, bytes)
663682

683+
@model_serializer
684+
def ser_model(self) -> str:
685+
return self.root.hex()
686+
664687
@singledispatchmethod
665688
def to(self, type_var: IcebergType) -> Literal: # type: ignore
666689
raise TypeError(f"Cannot convert FixedLiteral into {type_var}")
@@ -692,6 +715,10 @@ class BinaryLiteral(Literal[bytes]):
692715
def __init__(self, value: bytes) -> None:
693716
super().__init__(value, bytes)
694717

718+
@model_serializer
719+
def ser_model(self) -> str:
720+
return self.root.hex()
721+
695722
@singledispatchmethod
696723
def to(self, type_var: IcebergType) -> Literal: # type: ignore
697724
raise TypeError(f"Cannot convert BinaryLiteral into {type_var}")

tests/expressions/test_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ def data_file_nan() -> DataFile:
683683

684684

685685
def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None:
686-
for operator in [LessThan, LessThanOrEqual]:
686+
for operator in [LessThan, LessThanOrEqual]: # type: ignore
687687
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
688688
assert not should_read, "Should not match: all nan column doesn't contain number"
689689

@@ -711,7 +711,7 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f
711711
def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal(
712712
schema_data_file_nan: Schema, data_file_nan: DataFile
713713
) -> None:
714-
for operator in [GreaterThan, GreaterThanOrEqual]:
714+
for operator in [GreaterThan, GreaterThanOrEqual]: # type: ignore
715715
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
716716
assert not should_read, "Should not match: all nan column doesn't contain number"
717717

tests/expressions/test_literals.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ def test_string_to_time_literal() -> None:
319319

320320
avro_val = 51661919000
321321

322-
assert isinstance(time_lit, TimeLiteral) # type: ignore
323-
assert avro_val == time_lit.value # type: ignore
322+
assert isinstance(time_lit, TimeLiteral)
323+
assert avro_val == time_lit.value
324324

325325

326326
def test_string_to_timestamp_literal() -> None:
@@ -428,8 +428,8 @@ def test_python_date_conversion() -> None:
428428

429429
from_str_lit = literal(one_day_str).to(DateType())
430430

431-
assert isinstance(from_str_lit, DateLiteral) # type: ignore
432-
assert from_str_lit.value == 19079 # type: ignore
431+
assert isinstance(from_str_lit, DateLiteral)
432+
assert from_str_lit.value == 19079
433433

434434

435435
@pytest.mark.parametrize(
@@ -911,15 +911,15 @@ def test_uuid_to_fixed() -> None:
911911
with pytest.raises(TypeError) as e:
912912
uuid_literal.to(FixedType(15))
913913
assert "Cannot convert UUIDLiteral into fixed[15], different length: 15 <> 16" in str(e.value)
914-
assert isinstance(fixed_literal, FixedLiteral) # type: ignore
914+
assert isinstance(fixed_literal, FixedLiteral)
915915

916916

917917
def test_uuid_to_binary() -> None:
918918
test_uuid = uuid.uuid4()
919919
uuid_literal = literal(test_uuid)
920920
binary_literal = uuid_literal.to(BinaryType())
921921
assert test_uuid.bytes == binary_literal.value
922-
assert isinstance(binary_literal, BinaryLiteral) # type: ignore
922+
assert isinstance(binary_literal, BinaryLiteral)
923923

924924

925925
def test_literal_from_datetime() -> None:
@@ -930,6 +930,22 @@ def test_literal_from_date() -> None:
930930
assert isinstance(literal(datetime.date.today()), DateLiteral)
931931

932932

933+
def test_to_json() -> None:
934+
assert literal(True).model_dump_json() == "true"
935+
assert literal(float(123)).model_dump_json() == "123.0"
936+
assert literal(123).model_dump_json() == "123"
937+
assert literal("vo").model_dump_json() == '"vo"'
938+
assert (
939+
literal(uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7")).model_dump_json() == '"f79c3e09-677c-4bbd-a479-3f349cb785e7"'
940+
)
941+
assert literal(bytes([0x01, 0x02, 0x03])).model_dump_json() == '"010203"'
942+
assert literal(Decimal("19.25")).model_dump_json() == '"19.25"'
943+
assert literal(datetime.date.fromisoformat("2022-03-28")).model_dump_json() == '"2022-03-28"'
944+
assert (
945+
literal(datetime.datetime.fromisoformat("1970-11-22T00:00:00.000000+00:00")).model_dump_json() == '"1970-11-22T00:00:00"'
946+
)
947+
948+
933949
# __ __ ___
934950
# | \/ |_ _| _ \_ _
935951
# | |\/| | || | _/ || |

0 commit comments

Comments
 (0)