Skip to content

Commit 1fb1de2

Browse files
david-plRoger-luo
andauthored
Add Pauli algebra dialect example to docs (#365)
There's still an issue with a reference to `Fold`, that mkdocs can't resolve for some reason. --------- Co-authored-by: Xiu-zhe (Roger) Luo <[email protected]>
1 parent 2f4fb12 commit 1fb1de2

File tree

10 files changed

+875
-0
lines changed

10 files changed

+875
-0
lines changed

docs/cookbook/paulilang/pauli.md

Lines changed: 683 additions & 0 deletions
Large diffs are not rendered by default.

docs/scripts/katex.js

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
document$.subscribe(({ body }) => {
2+
renderMathInElement(body, {
3+
delimiters: [
4+
{ left: "$$", right: "$$", display: true },
5+
{ left: "$", right: "$", display: false },
6+
{ left: "\\(", right: "\\)", display: false },
7+
{ left: "\\[", right: "\\]", display: true }
8+
],
9+
})
10+
})

example/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ This folder contains examples of how to use the Kirin library. Each example is a
66

77
- `simple.py`: A simple example that demonstrates how to create a simple Kirin dialect group and its kernel.
88
- `food`: A more sophisticated example but without any domain specifics. It demonstrates how to create a new Kirin dialect and combine it with existing dialects with custom analysis and rewrites.
9+
- `pauli`: An example that implements a dialect with rewrites that simplifies products of Pauli matrices.
910

1011
## Examples outside this folder with more domain-specific contents
1112

example/pauli/dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
_dialect = ir.Dialect("pauli")

example/pauli/group.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from dialect import _dialect
2+
3+
from kirin import ir
4+
from kirin.prelude import basic_no_opt
5+
6+
7+
@ir._dialectgroup(basic_no_opt.add(dialect=_dialect))
8+
def pauli(self):
9+
def run_pass(mt):
10+
# TODO
11+
pass
12+
13+
return run_pass

example/pauli/interp.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
3+
from kirin.interp import MethodTable, impl
4+
5+
from .stmts import X, Y, Z, Id
6+
from .dialect import _dialect
7+
8+
9+
@_dialect.register
10+
class PauliMethods(MethodTable):
11+
X_mat = np.array([[0, 1], [1, 0]])
12+
Y_mat = np.array([[0, -1j], [1j, 0]])
13+
Z_mat = np.array([[1, 0], [0, -1]])
14+
Id_mat = np.array([[1, 0], [0, 1]])
15+
16+
@impl(X) # (1)!
17+
def x(self, interp, frame, stmt: X):
18+
return (stmt.pre_factor * self.X_mat,)
19+
20+
@impl(Y)
21+
def y(self, interp, frame, stmt: Y):
22+
return (self.Y_mat * stmt.pre_factor,)
23+
24+
@impl(Z)
25+
def z(self, interp, frame, stmt: Z):
26+
return (self.Z_mat * stmt.pre_factor,)
27+
28+
@impl(Id)
29+
def id(self, interp, frame, stmt: Id):
30+
return (self.Id_mat * stmt.pre_factor,)

example/pauli/rewrite.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from dataclasses import dataclass
2+
3+
from kirin import ir
4+
from kirin.rewrite import abc, result
5+
from kirin.dialects import py
6+
7+
from .stmts import X, Y, Z, Id, PauliOperator
8+
9+
10+
@dataclass
11+
class RewritePauliMult(abc.RewriteRule):
12+
def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
13+
if not isinstance(node, py.binop.Mult):
14+
return result.RewriteResult()
15+
16+
if not isinstance(node.lhs.owner, PauliOperator) and not isinstance(
17+
node.rhs.owner, PauliOperator
18+
):
19+
return result.RewriteResult()
20+
21+
if isinstance(node.lhs.owner, py.Constant):
22+
assert isinstance(node.rhs.owner, PauliOperator) # make the linter happy
23+
new_op = self.number_pauli_mult(node.lhs.owner, node.rhs.owner)
24+
node.replace_by(new_op)
25+
return result.RewriteResult(has_done_something=True)
26+
elif isinstance(node.rhs.owner, py.Constant):
27+
assert isinstance(node.lhs.owner, PauliOperator) # make the linter happy
28+
new_op = self.number_pauli_mult(node.rhs.owner, node.lhs.owner)
29+
node.replace_by(new_op)
30+
return result.RewriteResult(has_done_something=True)
31+
32+
if not isinstance(node.lhs.owner, PauliOperator) or not isinstance(
33+
node.rhs.owner, PauliOperator
34+
):
35+
return result.RewriteResult()
36+
37+
new_op = self.pauli_pauli_mult(node.lhs.owner, node.rhs.owner)
38+
node.replace_by(new_op)
39+
return result.RewriteResult(has_done_something=True)
40+
41+
@staticmethod
42+
def number_pauli_mult(lhs: py.Constant, rhs: PauliOperator) -> PauliOperator:
43+
num = lhs.value.unwrap() * rhs.pre_factor
44+
return type(rhs)(pre_factor=num)
45+
46+
@staticmethod
47+
def pauli_pauli_mult(lhs: PauliOperator, rhs: PauliOperator) -> PauliOperator:
48+
num = rhs.pre_factor * lhs.pre_factor
49+
50+
if isinstance(lhs, type(rhs)):
51+
return Id(pre_factor=num)
52+
53+
if isinstance(lhs, type(rhs)):
54+
return Id(pre_factor=num)
55+
56+
if isinstance(lhs, Id):
57+
return type(rhs)(pre_factor=num)
58+
59+
if isinstance(rhs, Id):
60+
return type(lhs)(pre_factor=num)
61+
62+
if isinstance(lhs, X):
63+
if isinstance(rhs, Y):
64+
return Z(pre_factor=1j * num)
65+
elif isinstance(rhs, Z):
66+
return Y(pre_factor=-1j * num)
67+
68+
if isinstance(lhs, Y):
69+
if isinstance(rhs, X):
70+
return Z(pre_factor=-1j * num)
71+
elif isinstance(rhs, Z):
72+
return X(pre_factor=1j * num)
73+
74+
if isinstance(lhs, Z):
75+
if isinstance(rhs, Y):
76+
return X(pre_factor=-1j * num)
77+
elif isinstance(rhs, X):
78+
return Y(pre_factor=1j * num)
79+
80+
raise RuntimeError("How on earth did we end up here?")

example/pauli/script.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .group import pauli
2+
from .stmts import X, Y, Z
3+
4+
5+
@pauli
6+
def main():
7+
ex = (X() + 2 * Y()) * Z()
8+
return ex
9+
10+
11+
main.print()

example/pauli/stmts.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from numbers import Number
2+
3+
import numpy as np
4+
5+
from kirin import ir, types, lowering
6+
from kirin.decl import info, statement
7+
8+
from .dialect import _dialect
9+
10+
11+
@statement
12+
class PauliOperator(ir.Statement):
13+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
14+
pre_factor: Number = info.attribute(default=1)
15+
result: ir.ResultValue = info.result(types.PyClass(np.matrix))
16+
17+
18+
@statement(dialect=_dialect)
19+
class X(PauliOperator):
20+
pass
21+
22+
23+
@statement(dialect=_dialect)
24+
class Y(PauliOperator):
25+
pass
26+
27+
28+
@statement(dialect=_dialect)
29+
class Z(PauliOperator):
30+
pass
31+
32+
33+
@statement(dialect=_dialect)
34+
class Id(PauliOperator):
35+
pass

mkdocs.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ nav:
3535
- Advanced Rewriting: cookbook/foodlang/cf_rewrite.md
3636
- Food Pricing Analysis Pass: cookbook/foodlang/analysis.md
3737
- Receipt Codegen: cookbook/foodlang/codegen.md
38+
- Pauli Algebra: cookbook/paulilang/pauli.md
3839

3940
- Blog:
4041
- blog/index.md
@@ -144,6 +145,7 @@ plugins:
144145

145146
extra_css:
146147
- stylesheets/extra.css
148+
- https://unpkg.com/katex@0/dist/katex.min.css
147149

148150
markdown_extensions:
149151
- abbr
@@ -158,6 +160,8 @@ markdown_extensions:
158160
- pymdownx.tilde
159161
- pymdownx.tabbed:
160162
alternate_style: true
163+
- pymdownx.arithmatex:
164+
generic: true
161165

162166
copyright: Copyright &copy; 2024 Kirin contributors
163167

@@ -169,3 +173,8 @@ extra:
169173
link: https://x.com/QueraComputing
170174
- icon: fontawesome/brands/linkedin
171175
link: https://www.linkedin.com/company/quera-computing-inc/
176+
177+
extra_javascript:
178+
- scripts/katex.js
179+
- https://unpkg.com/katex@0/dist/katex.min.js
180+
- https://unpkg.com/katex@0/dist/contrib/auto-render.min.js

0 commit comments

Comments
 (0)