Skip to content

Commit 0c20b95

Browse files
committed
Completed implementation but not quite working
1 parent 9ae811c commit 0c20b95

File tree

3 files changed

+73
-38
lines changed

3 files changed

+73
-38
lines changed

squin_op_playground.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from kirin import ir, types
22
from kirin.passes import Fold
3-
from kirin.dialects import py, func, ilist
3+
from kirin.dialects import py, func
44

5-
from bloqade import qasm2, squin
5+
from bloqade import squin
66
from bloqade.analysis import address
77
from bloqade.squin.analysis import shape
88

@@ -11,32 +11,20 @@ def as_int(value: int):
1111
return py.constant.Constant(value=value)
1212

1313

14-
squin_with_qasm_core = squin.groups.wired.add(qasm2.core).add(ilist)
14+
squin_with_qasm_core = squin.groups.wired
1515

1616
stmts: list[ir.Statement] = [
17-
# Create qubit register
18-
(n_qubits := as_int(1)),
19-
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
20-
# Get qubits out
21-
(idx0 := as_int(0)),
22-
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
23-
# Unwrap to get wires
24-
(w1 := squin.wire.Unwrap(qubit=q0.result)),
25-
# Pass wire into operator
26-
(op := squin.op.stmts.H()),
27-
(v1 := squin.wire.Apply(op.result, w1.result)),
28-
# Test Identity
29-
(id := squin.op.stmts.Identity(size=1)),
30-
(v2 := squin.wire.Apply(id.result, v1.results[0])),
31-
# Keep Passing Operators
32-
(func.Return(v2.results[0])),
17+
(h0 := squin.op.stmts.H()),
18+
(h1 := squin.op.stmts.H()),
19+
(hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)),
20+
(func.Return(hh.result)),
3321
]
3422

3523
block = ir.Block(stmts)
3624
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
3725
func_wrapper = func.Function(
3826
sym_name="main",
39-
signature=func.Signature(inputs=(), output=ilist.IListType),
27+
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
4028
body=ir.Region(blocks=block),
4129
)
4230

src/bloqade/squin/analysis/shape/analysis.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ def initialize(self):
2323
# I can get the data I want from the SizedTrait
2424
# and go from there
2525
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
26-
if stmt.has_trait(Sized):
27-
size = stmt.get_trait(Sized)
28-
return (OpShape(size=size.data),)
29-
# Handle op.Identity
26+
method = self.lookup_registry(frame, stmt)
27+
if method is not None:
28+
return method(self, frame, stmt)
3029
elif stmt.has_trait(HasSize):
3130
# Caution! This can return None
3231
has_size_inst = stmt.get_trait(HasSize)
3332
size = has_size_inst.get_size(stmt)
3433
return (OpShape(size=size),)
34+
elif stmt.has_trait(Sized):
35+
size = stmt.get_trait(Sized)
36+
return (OpShape(size=size.data),)
3537
else:
3638
return (NoShape(),)
3739

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,69 @@
1-
from kirin import interp
1+
from typing import cast
22

3-
from bloqade import squin
3+
from kirin import ir, interp
44

5-
""" from .lattice import (
6-
Shape,
5+
from bloqade.squin import op
6+
7+
from .lattice import (
78
NoShape,
89
OpShape,
910
)
10-
11-
from .analysis import ShapeAnalysis """
11+
from .analysis import ShapeAnalysis
1212

1313

14-
@squin.op.dialect.register(key="op.shape")
14+
@op.dialect.register(key="op.shape")
1515
class SquinOp(interp.MethodTable):
16-
pass
1716

18-
# Should be using the Sized trait
19-
# that the statements have
17+
@interp.impl(op.stmts.Kron)
18+
def kron(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
19+
lhs = frame.get(stmt.lhs)
20+
rhs = frame.get(stmt.rhs)
21+
if isinstance(lhs, OpShape) and isinstance(rhs, OpShape):
22+
new_size = lhs.size + rhs.size
23+
return (OpShape(size=new_size),)
24+
else:
25+
return (NoShape(),)
26+
27+
@interp.impl(op.stmts.Mult)
28+
def mult(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
29+
lhs = frame.get(stmt.lhs)
30+
rhs = frame.get(stmt.rhs)
31+
32+
if isinstance(lhs, OpShape) and isinstance(rhs, OpShape):
33+
lhs_size = lhs.size
34+
rhs_size = rhs.size
35+
# Sized trait implicitly enforces that
36+
# all operators are square matrices,
37+
# not sure if it's worth raising an exception here
38+
# or just letting this propagate...
39+
if lhs_size != rhs_size:
40+
return (NoShape(),)
41+
else:
42+
return (OpShape(size=lhs_size + rhs_size),)
43+
else:
44+
return (NoShape(),)
45+
46+
@interp.impl(op.stmts.Control)
47+
def control(
48+
self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Control
49+
):
50+
op_shape = frame.get(stmt.op)
51+
52+
if isinstance(op_shape, OpShape):
53+
op_size = op_shape.size
54+
n_controls_attr = stmt.get_attr_or_prop("n_controls")
55+
# raise exception if attribute is NOne
56+
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
57+
return (OpShape(size=op_size + n_controls),)
58+
else:
59+
return (NoShape(),)
60+
61+
@interp.impl(op.stmts.Rot)
62+
def rot(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
63+
op_shape = frame.get(stmt.axis)
64+
return op_shape
2065

21-
# Need to keep in mind that Identity
22-
# has a HasSize() trait with "size:int"
23-
# as the corresponding attribute to query
24-
# @interp.impl(squin.op.stmts.ConstantUnitary)
66+
@interp.impl(op.stmts.Scale)
67+
def scale(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
68+
op_shape = frame.get(stmt.op)
69+
return op_shape

0 commit comments

Comments
 (0)