Skip to content

Commit 80af27d

Browse files
authored
add webinar example (#445)
1 parent 95b478e commit 80af27d

File tree

2 files changed

+254
-0
lines changed

2 files changed

+254
-0
lines changed

example/quantum/script.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# type: ignore
2+
3+
# [section]
4+
from enum import Enum
5+
from typing import ClassVar
6+
from dataclasses import dataclass
7+
from qulacs import QuantumState
8+
9+
10+
# this could be your own class implementing the runtime in whatever way you want
11+
@dataclass
12+
class Qubit:
13+
count: ClassVar[int] = 0 # class variable to count qubits
14+
id: int
15+
16+
def __init__(self):
17+
self.id = Qubit.count
18+
Qubit.count += 1
19+
20+
21+
# some your own classes
22+
class Basis(Enum):
23+
X = "X"
24+
Y = "Y"
25+
Z = "Z"
26+
27+
28+
# [section]
29+
from kirin import ir, types, lowering
30+
from kirin.decl import statement, info
31+
from kirin.prelude import basic
32+
33+
# our language definitions and compiler begins
34+
dialect = ir.Dialect("quantum")
35+
QubitType = types.PyClass(Qubit)
36+
StateType = types.PyClass(QuantumState)
37+
38+
39+
@statement(dialect=dialect)
40+
class NewQubit(ir.Statement):
41+
traits = frozenset({lowering.FromPythonCall()})
42+
state: ir.SSAValue = info.argument(
43+
StateType
44+
) # we can use Python objects as arguments
45+
qubit: ir.ResultValue = info.result(QubitType)
46+
47+
48+
@statement(dialect=dialect)
49+
class X(ir.Statement):
50+
traits = frozenset({lowering.FromPythonCall()})
51+
state: ir.SSAValue = info.argument(StateType)
52+
qubit: ir.SSAValue = info.argument(QubitType)
53+
54+
55+
@statement(dialect=dialect)
56+
class H(ir.Statement):
57+
traits = frozenset({lowering.FromPythonCall()})
58+
state: ir.SSAValue = info.argument(StateType)
59+
qubit: ir.SSAValue = info.argument(QubitType)
60+
61+
62+
@statement(dialect=dialect)
63+
class CX(ir.Statement):
64+
traits = frozenset({lowering.FromPythonCall()})
65+
state: ir.SSAValue = info.argument(StateType)
66+
control: ir.SSAValue = info.argument(QubitType)
67+
target: ir.SSAValue = info.argument(QubitType)
68+
69+
70+
@statement(dialect=dialect)
71+
class CZ(ir.Statement):
72+
traits = frozenset({lowering.FromPythonCall()})
73+
state: ir.SSAValue = info.argument(StateType)
74+
control: ir.SSAValue = info.argument(QubitType)
75+
target: ir.SSAValue = info.argument(QubitType)
76+
77+
78+
@statement(dialect=dialect)
79+
class Measure(ir.Statement):
80+
traits = frozenset({lowering.FromPythonCall()})
81+
basis: Basis = (
82+
info.attribute()
83+
) # we can use Python objects as attributes (compile-time values)!
84+
state: ir.SSAValue = info.argument(StateType)
85+
qubit: ir.SSAValue = info.argument(QubitType)
86+
result: ir.ResultValue = info.result(types.Int)
87+
88+
89+
# now we have the miminim set of statements to represent a quantum circuit
90+
# the following defines a group of "dialects" so we can use it as a decorator
91+
@ir.dialect_group(basic.add(dialect))
92+
def quantum(self): # group self
93+
def run_default_pass(method, option_a=True):
94+
# default pass to run right after calling the decorator
95+
# a.k.a the default JIT compilation part of the compiler
96+
pass
97+
98+
return run_default_pass
99+
100+
101+
# Ok let's try it out
102+
@quantum
103+
def main(state: QuantumState):
104+
a = NewQubit(state)
105+
b = NewQubit(state)
106+
H(state, a)
107+
CX(state, control=a, target=b)
108+
return Measure(state, basis=Basis.Z, qubit=b)
109+
110+
111+
# well Linter is mad at us
112+
113+
# [section]
114+
# fortunately, Kirin provides a way to give hints to a standard Python linter
115+
# now let's make some lowering wrappers to make Python type hinting happy
116+
117+
118+
@lowering.wraps(NewQubit)
119+
def new_qubit(state: QuantumState) -> Qubit: ...
120+
121+
122+
@lowering.wraps(X)
123+
def x(state: QuantumState, qubit: Qubit) -> None: ...
124+
125+
126+
@lowering.wraps(H)
127+
def h(state: QuantumState, qubit: Qubit) -> None: ...
128+
129+
130+
@lowering.wraps(CX)
131+
def cx(state: QuantumState, control: Qubit, target: Qubit) -> None: ...
132+
133+
134+
@lowering.wraps(CZ)
135+
def cz(state: QuantumState, control: Qubit, target: Qubit) -> None: ...
136+
137+
138+
@lowering.wraps(Measure)
139+
def measure(state: QuantumState, basis: Basis, qubit: Qubit) -> None: ...
140+
141+
142+
# this is a lot nicer now!
143+
@quantum
144+
def main(state: QuantumState):
145+
a = new_qubit(state)
146+
b = new_qubit(state)
147+
h(state, a)
148+
h(state, b)
149+
cx(state, control=a, target=b)
150+
if measure(state, basis=Basis.Z, qubit=b):
151+
x(state, a) # we can use the result of Measure to conditionally apply X gate
152+
return
153+
154+
155+
main.print()
156+
157+
# Ok but this doesn't work yet, I cannot run it
158+
# main()
159+
160+
# [section]
161+
# we need to implement the runtime for the quantum circuit
162+
# let's just import qulacs a quantum circuit simulator
163+
164+
from kirin import interp
165+
from qulacs import gate, QuantumState
166+
167+
168+
@dialect.register
169+
class MethodTable(interp.MethodTable):
170+
@interp.impl(NewQubit)
171+
def impl_new_qubit(
172+
self, interp: interp.Interpreter, frame: interp.Frame, stmt: NewQubit
173+
) -> tuple[Qubit]:
174+
return (Qubit(),)
175+
176+
@interp.impl(X)
177+
def impl_x(self, interp: interp.Interpreter, frame: interp.Frame, stmt: X) -> None:
178+
state = frame.get_casted(
179+
stmt.state, QuantumState
180+
) # assume state is QuantumState at runtime
181+
qubit = frame.get_casted(
182+
stmt.qubit, Qubit
183+
) # we assume qubits are Qubit at runtime
184+
gate.X(qubit.id).update_quantum_state(state)
185+
186+
@interp.impl(H)
187+
def impl_h(self, interp: interp.Interpreter, frame: interp.Frame, stmt: H) -> None:
188+
state = frame.get_casted(stmt.state, QuantumState)
189+
qubit = frame.get_casted(stmt.qubit, Qubit)
190+
gate.H(qubit.id).update_quantum_state(state)
191+
192+
@interp.impl(CX)
193+
def impl_cx(
194+
self, interp: interp.Interpreter, frame: interp.Frame, stmt: CX
195+
) -> None:
196+
state = frame.get_casted(stmt.state, QuantumState)
197+
control = frame.get_casted(stmt.control, Qubit)
198+
target = frame.get_casted(stmt.target, Qubit)
199+
print(f"Applying CNOT gate with control {control.id} and target {target.id}")
200+
gate.CNOT(control.id, target.id).update_quantum_state(state)
201+
202+
@interp.impl(CZ)
203+
def impl_cz(
204+
self, interp: interp.Interpreter, frame: interp.Frame, stmt: CZ
205+
) -> None:
206+
state = frame.get_casted(stmt.state, QuantumState)
207+
control = frame.get_casted(stmt.control, Qubit)
208+
target = frame.get_casted(stmt.target, Qubit)
209+
print(f"Applying CZ gate with control {control.id} and target {target.id}")
210+
gate.CZ(control.id, target.id).update_quantum_state(state)
211+
212+
@interp.impl(Measure)
213+
def impl_measure(
214+
self, interp: interp.Interpreter, frame: interp.Frame, stmt: Measure
215+
) -> tuple[int]:
216+
state = frame.get_casted(stmt.state, QuantumState)
217+
qubit = frame.get_casted(stmt.qubit, Qubit)
218+
basis = stmt.basis.value # get the basis as a string
219+
result = gate.Measurement(qubit.id, qubit.id).update_quantum_state(state)
220+
return (
221+
state.get_classical_value(qubit.id),
222+
) # return the measurement result as an int
223+
224+
225+
state = QuantumState(2) # 2 qubits
226+
state.set_zero_state()
227+
main(state)
228+
print(state.get_vector())
229+
230+
# [section]
231+
# ok now we can run it, how about rewriting the program?
232+
233+
from kirin.rewrite.abc import RewriteRule, RewriteResult
234+
235+
236+
class CX2CZ(RewriteRule):
237+
238+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
239+
if not isinstance(node, CX):
240+
return RewriteResult()
241+
242+
H(node.state, node.target).insert_before(node)
243+
node.replace_by(
244+
cz_node := CZ(state=node.state, control=node.control, target=node.target)
245+
)
246+
H(node.state, node.target).insert_after(cz_node)
247+
return RewriteResult(has_done_something=True)
248+
249+
250+
from kirin.rewrite import Walk
251+
252+
Walk(CX2CZ()).rewrite(main.code)
253+
main.print()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ exclude = [
5454
"dist",
5555
"node_modules",
5656
"venv",
57+
"example/quantum/script.py", # Ignore specific file
5758
]
5859

5960
[tool.ruff.lint]

0 commit comments

Comments
 (0)