Skip to content

Commit 1bc9f87

Browse files
Serializer framework for Kirin (#492)
Serializer framework. With base Serializer, JSONSerializer, and BinarySerializer. --------- Co-authored-by: kaihsin <[email protected]>
1 parent 0e67ee5 commit 1bc9f87

File tree

26 files changed

+1666
-11
lines changed

26 files changed

+1666
-11
lines changed

src/kirin/analysis/const/lattice.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def is_subseteq(self, other: Result) -> bool:
5151
def __hash__(self) -> int:
5252
return id(self)
5353

54+
def is_structurally_equal(
55+
self, other: ir.Attribute, context: dict | None = None
56+
) -> bool:
57+
return isinstance(other, Unknown)
58+
5459

5560
@final
5661
@dataclass
@@ -63,6 +68,11 @@ def is_subseteq(self, other: Result) -> bool:
6368
def __hash__(self) -> int:
6469
return id(self)
6570

71+
def is_structurally_equal(
72+
self, other: ir.Attribute, context: dict | None = None
73+
) -> bool:
74+
return isinstance(other, Bottom)
75+
6676

6777
@final
6878
@dataclass
@@ -85,6 +95,15 @@ def __hash__(self) -> int:
8595
# the data is guaranteed to be unique.
8696
return id(self)
8797

98+
def is_structurally_equal(
99+
self, other: ir.Attribute, context: dict | None = None
100+
) -> bool:
101+
if not isinstance(other, Value):
102+
return False
103+
if isinstance(self.data, ir.Attribute) and isinstance(other.data, ir.Attribute):
104+
return self.data.is_structurally_equal(other.data, context=context)
105+
return self.data.is_structurally_equal(other.data, context=context)
106+
88107

89108
@dataclass
90109
class PartialConst(Result):
@@ -158,6 +177,18 @@ def is_subseteq_Value(self, other: Value) -> bool:
158177
def __hash__(self) -> int:
159178
return hash(self.data)
160179

180+
def is_structurally_equal(
181+
self, other: ir.Attribute, context: dict | None = None
182+
) -> bool:
183+
if not isinstance(other, PartialTuple):
184+
return False
185+
if len(self.data) != len(other.data):
186+
return False
187+
return all(
188+
x.is_structurally_equal(y, context=context)
189+
for x, y in zip(self.data, other.data)
190+
)
191+
161192

162193
@final
163194
@dataclass
@@ -230,3 +261,17 @@ def meet(self, other: Result) -> Result:
230261
tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
231262
self.argnames,
232263
)
264+
265+
def is_structurally_equal(
266+
self, other: ir.Attribute, context: dict | None = None
267+
) -> bool:
268+
return (
269+
isinstance(other, PartialLambda)
270+
and self.code.is_structurally_equal(other.code, context=context)
271+
and self.argnames == other.argnames
272+
and len(self.captured) == len(other.captured)
273+
and all(
274+
x.is_structurally_equal(y, context=context)
275+
for x, y in zip(self.captured, other.captured)
276+
)
277+
)

src/kirin/dialects/func/attrs.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
from typing import Generic, TypeVar
1+
from typing import TYPE_CHECKING, Generic, TypeVar
22
from dataclasses import dataclass
33

44
from kirin import types
55
from kirin.ir import Method, Attribute
66
from kirin.print.printer import Printer
7+
from kirin.serialization.base.serializationunit import SerializationUnit
8+
9+
if TYPE_CHECKING:
10+
from kirin.serialization.base.serializer import Serializer
11+
from kirin.serialization.base.deserializer import Deserializer
712

813
from ._dialect import dialect
914

@@ -38,3 +43,24 @@ def print_impl(self, printer: Printer) -> None:
3843
printer.print_seq(self.inputs, delim=", ", prefix="(", suffix=")")
3944
printer.plain_print(" -> ")
4045
printer.print(self.output)
46+
47+
def __eq__(self, value: object) -> bool:
48+
if not isinstance(value, Signature):
49+
return False
50+
return self.inputs == value.inputs and self.output == value.output
51+
52+
def serialize(self, serializer: "Serializer") -> "SerializationUnit":
53+
return serializer.serialize_signature(self)
54+
55+
def is_structurally_equal(
56+
self,
57+
other: Attribute,
58+
context: dict | None = None,
59+
) -> bool:
60+
return self == other
61+
62+
@classmethod
63+
def deserialize(
64+
cls, serUnit: "SerializationUnit", deserializer: "Deserializer"
65+
) -> "Signature":
66+
return deserializer.deserialize_signature(serUnit)

src/kirin/dialects/ilist/runtime.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
# TODO: replace with something faster
2-
from typing import Any, Generic, TypeVar, overload
2+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
33
from dataclasses import dataclass
44
from collections.abc import Sequence
55

66
from kirin import ir, types
77
from kirin.print.printer import Printer
8+
from kirin.serialization.base.serializationunit import SerializationUnit
9+
10+
if TYPE_CHECKING:
11+
from kirin.serialization.base.serializer import Serializer
12+
from kirin.serialization.base.deserializer import Deserializer
813

914
T = TypeVar("T")
1015
L = TypeVar("L")
@@ -85,8 +90,25 @@ def unwrap(self) -> Sequence[T]:
8590
return self
8691

8792
def print_impl(self, printer: Printer) -> None:
88-
printer.plain_print("IList(")
8993
printer.print_seq(
9094
self.data, delim=", ", prefix="[", suffix="]", emit=printer.plain_print
9195
)
9296
printer.plain_print(")")
97+
98+
def is_structurally_equal(
99+
self, other: ir.Attribute, context: dict | None = None
100+
) -> bool:
101+
return (
102+
isinstance(other, IList)
103+
and self.data == other.data
104+
and self.elem.is_equal(other.elem)
105+
)
106+
107+
def serialize(self, serializer: "Serializer") -> "SerializationUnit":
108+
return serializer.serialize_ilist(self)
109+
110+
@classmethod
111+
def deserialize(
112+
cls, serUnit: "SerializationUnit", deserializer: "Deserializer"
113+
) -> "IList":
114+
return deserializer.deserialize_ilist(serUnit)

src/kirin/dialects/py/slice.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ def __hash__(self):
8989
def print_impl(self, printer: Printer) -> None:
9090
return printer.plain_print(f"slice({self.start}, {self.stop}, {self.step})")
9191

92+
def is_structurally_equal(
93+
self, other: ir.Attribute, context: dict | None = None
94+
) -> bool:
95+
return (
96+
isinstance(other, SliceAttribute)
97+
and self.start == other.start
98+
and self.stop == other.stop
99+
and self.step == other.step
100+
)
101+
92102

93103
@dialect.register
94104
class Concrete(interp.MethodTable):

src/kirin/idtable.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,14 @@ def __getitem__(self, value: T) -> str:
5555
return self.table[value]
5656
else:
5757
return self.add(value)
58+
59+
def __len__(self) -> int:
60+
return len(self.table)
61+
62+
def size(self) -> int:
63+
return len(self)
64+
65+
def clear(self) -> None:
66+
self.table.clear()
67+
self.name_count.clear()
68+
self.next_id = 0

src/kirin/ir/attrs/abc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import TYPE_CHECKING, TypeVar, ClassVar, Optional
33
from dataclasses import field, dataclass
44

5+
from typing_extensions import Self
6+
57
from kirin.print import Printable
68
from kirin.ir.traits import Trait
79
from kirin.lattice.abc import LatticeMeta, SingletonMeta
@@ -53,6 +55,10 @@ def __hash__(self) -> int: ...
5355
@abstractmethod
5456
def __eq__(self, value: object) -> bool: ...
5557

58+
@abstractmethod
59+
def is_structurally_equal(self, other: Self, context: dict | None = None) -> bool:
60+
return self == other
61+
5662
@classmethod
5763
def has_trait(cls, trait_type: type[Trait["Attribute"]]) -> bool:
5864
"""Check if the Statement has a specific trait.

src/kirin/ir/attrs/py.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1-
from typing import TypeVar
1+
from typing import TYPE_CHECKING, Any, Type, TypeVar
22
from dataclasses import dataclass
33

4+
from typing_extensions import Protocol, runtime_checkable
5+
46
from kirin.print import Printer
7+
from kirin.ir.attrs.abc import Attribute
8+
from kirin.serialization.base.serializationunit import SerializationUnit
9+
10+
if TYPE_CHECKING:
11+
from kirin.serialization.base.serializer import Serializer
12+
from kirin.serialization.base.deserializer import Deserializer
513

614
from .data import Data
715
from .types import PyClass, TypeAttribute
@@ -51,3 +59,32 @@ def print_impl(self, printer: Printer) -> None:
5159

5260
def unwrap(self) -> T:
5361
return self.data
62+
63+
def is_structurally_equal(
64+
self, other: Attribute, context: dict | None = None
65+
) -> bool:
66+
if not isinstance(other, PyAttr):
67+
return False
68+
if self.type != other.type:
69+
return False
70+
if isinstance(self.data, StructurallyEqual) and isinstance(
71+
other.data, StructurallyEqual
72+
):
73+
return self.data.is_structurally_equal(other.data, context=context)
74+
return self.data == other.data
75+
76+
def serialize(self, serializer: "Serializer") -> "SerializationUnit":
77+
return serializer.serialize_pyattr(self)
78+
79+
@classmethod
80+
def deserialize(
81+
cls: Type["PyAttr"], serUnit: "SerializationUnit", deserializer: "Deserializer"
82+
) -> "PyAttr":
83+
return deserializer.deserialize_pyattr(serUnit)
84+
85+
86+
@runtime_checkable
87+
class StructurallyEqual(Protocol):
88+
def is_structurally_equal(
89+
self, other: Any, context: dict | None = None
90+
) -> bool: ...

0 commit comments

Comments
 (0)