Skip to content

Commit 9ae811c

Browse files
committed
proper handling of Sized(1) trait operators
1 parent 9aa2964 commit 9ae811c

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

squin_op_playground.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from bloqade import qasm2, squin
66
from bloqade.analysis import address
7+
from bloqade.squin.analysis import shape
78

89

910
def as_int(value: int):
@@ -18,21 +19,24 @@ def as_int(value: int):
1819
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
1920
# Get qubits out
2021
(idx0 := as_int(0)),
21-
(q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
22+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
2223
# Unwrap to get wires
23-
(w1 := squin.wire.Unwrap(qubit=q1.result)),
24-
# Put them in an ilist and return to prevent elimination
25-
# Put the wire into one operator
24+
(w1 := squin.wire.Unwrap(qubit=q0.result)),
25+
# Pass wire into operator
2626
(op := squin.op.stmts.H()),
2727
(v1 := squin.wire.Apply(op.result, w1.result)),
28-
(func.Return(v1.results[0])),
28+
# Test Identity
29+
(id := squin.op.stmts.Identity(size=1)),
30+
(v2 := squin.wire.Apply(id.result, v1.results[0])),
31+
# Keep Passing Operators
32+
(func.Return(v2.results[0])),
2933
]
3034

3135
block = ir.Block(stmts)
3236
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
3337
func_wrapper = func.Function(
3438
sym_name="main",
35-
signature=func.Signature(inputs=(), output=squin.wire.WireType),
39+
signature=func.Signature(inputs=(), output=ilist.IListType),
3640
body=ir.Region(blocks=block),
3741
)
3842

@@ -51,3 +55,14 @@ def as_int(value: int):
5155
frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis(
5256
constructed_method, no_raise=False
5357
)
58+
59+
frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis(
60+
constructed_method, no_raise=False
61+
)
62+
63+
"""
64+
frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis(
65+
constructed_method, no_raise=False
66+
"""
67+
68+
constructed_method.print(analysis=frame.entries)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .analysis import ShapeAnalysis as ShapeAnalysis

src/bloqade/squin/analysis/shape/analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from bloqade.squin.op.types import OpType
88
from bloqade.squin.op.traits import Sized, HasSize
99

10-
from .lattice import Shape, OpShape, AnyShape
10+
from .lattice import Shape, NoShape, OpShape
1111

1212

1313
class ShapeAnalysis(Forward[Shape]):
@@ -16,7 +16,7 @@ class ShapeAnalysis(Forward[Shape]):
1616
lattice = Shape
1717

1818
def initialize(self):
19-
super().initialize
19+
super().initialize()
2020
return self
2121

2222
# Take a page from const prop in Kirin,
@@ -25,15 +25,15 @@ def initialize(self):
2525
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
2626
if stmt.has_trait(Sized):
2727
size = stmt.get_trait(Sized)
28-
return (OpShape(size=size),)
28+
return (OpShape(size=size.data),)
2929
# Handle op.Identity
3030
elif stmt.has_trait(HasSize):
3131
# Caution! This can return None
3232
has_size_inst = stmt.get_trait(HasSize)
3333
size = has_size_inst.get_size(stmt)
3434
return (OpShape(size=size),)
3535
else:
36-
return (AnyShape(),)
36+
return (NoShape(),)
3737

3838
def eval_stmt_fallback(
3939
self, frame: ForwardFrame[Shape], stmt: ir.Statement

0 commit comments

Comments
 (0)