Skip to content

Commit f464c6c

Browse files
committed
Merge branch 'main' into david/185-pyqrack-squin
2 parents b559f84 + 65acb76 commit f464c6c

File tree

6 files changed

+158
-8
lines changed

6 files changed

+158
-8
lines changed

src/bloqade/pyqrack/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,28 @@ class PyQrackOptions(typing.TypedDict):
2626
isOpenCL: bool
2727

2828

29+
def _validate_pyqrack_options(options: PyQrackOptions) -> None:
30+
if options["isBinaryDecisionTree"] and options["isStabilizerHybrid"]:
31+
raise ValueError(
32+
"Cannot use both isBinaryDecisionTree and isStabilizerHybrid at the same time."
33+
)
34+
elif options["isTensorNetwork"] and options["isBinaryDecisionTree"]:
35+
raise ValueError(
36+
"Cannot use both isTensorNetwork and isBinaryDecisionTree at the same time."
37+
)
38+
elif options["isTensorNetwork"] and options["isStabilizerHybrid"]:
39+
raise ValueError(
40+
"Cannot use both isTensorNetwork and isStabilizerHybrid at the same time."
41+
)
42+
43+
2944
def _default_pyqrack_args() -> PyQrackOptions:
3045
return PyQrackOptions(
3146
qubitCount=-1,
3247
isTensorNetwork=False,
3348
isSchmidtDecomposeMulti=True,
3449
isSchmidtDecompose=True,
35-
isStabilizerHybrid=True,
50+
isStabilizerHybrid=False,
3651
isBinaryDecisionTree=True,
3752
isPaged=True,
3853
isCpuGpuHybrid=True,
@@ -45,6 +60,9 @@ class MemoryABC(abc.ABC):
4560
pyqrack_options: PyQrackOptions = field(default_factory=_default_pyqrack_args)
4661
sim_reg: "QrackSimulator" = field(init=False)
4762

63+
def __post_init__(self):
64+
_validate_pyqrack_options(self.pyqrack_options)
65+
4866
@abc.abstractmethod
4967
def allocate(self, n_qubits: int) -> tuple[int, ...]:
5068
"""Allocate `n_qubits` qubits and return their ids."""

src/bloqade/squin/op/stmts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,14 @@ class PauliOp(ConstantUnitary):
148148

149149

150150
@statement(dialect=dialect)
151-
class CliffordString(ConstantUnitary):
151+
class PauliString(ConstantUnitary):
152152
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
153153
string: str = info.attribute()
154154

155155
def verify(self) -> None:
156-
if not set("XYZHS").issuperset(self.string):
156+
if not set("XYZ").issuperset(self.string):
157157
raise ValueError(
158-
f"Invalid Clifford string: {self.string}. Must be a combination of 'X', 'Y', 'Z', 'H', and 'S'."
158+
f"Invalid Pauli string: {self.string}. Must be a combination of 'X', 'Y', and 'Z'."
159159
)
160160

161161

src/bloqade/squin/wire.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,25 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
6969
) # custom lowering required for wrapper to work here
7070

7171

72+
# Carry over from Qubit dialect
73+
@statement(dialect=dialect)
74+
class Broadcast(ir.Statement):
75+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
76+
operator: ir.SSAValue = info.argument(OpType)
77+
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)
78+
79+
def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
80+
result_types = tuple(WireType for _ in args)
81+
super().__init__(
82+
args=(operator,) + args,
83+
result_types=result_types,
84+
args_slice={
85+
"operator": 0,
86+
"inputs": slice(1, None),
87+
}, # pretty printing + syntax sugar
88+
) # custom lowering required for wrapper to work here
89+
90+
7291
# NOTE: measurement cannot be pure because they will collapse the state
7392
# of the qubit. The state is a hidden state that is not visible to
7493
# the user in the wire dialect.
@@ -99,6 +118,7 @@ class Reset(ir.Statement):
99118
class ConstPropWire(interp.MethodTable):
100119

101120
@interp.impl(Apply)
121+
@interp.impl(Broadcast)
102122
def apply(self, interp, frame, stmt: Apply):
103123

104124
return frame.get_values(stmt.inputs)

test/pyqrack/test_target.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def multiple_registers():
103103

104104
return q1
105105

106-
target = PyQrack(6)
106+
target = PyQrack(
107+
6, pyqrack_options={"isBinaryDecisionTree": False, "isStabilizerHybrid": True}
108+
)
107109
q1 = target.run(multiple_registers)
108110

109111
assert isinstance(q1, ilist.IList)

test/qasm2/passes/test_heuristic_noise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,11 @@ def test_method():
287287
expected_block = ir.Block(
288288
[
289289
n_qubits := constant.Constant(1),
290-
reg0 := core.QRegNew(n_qubits.result),
291-
zero := constant.Constant(0),
292-
q0 := core.QRegGet(reg0.result, zero.result),
293290
reg1 := core.QRegNew(n_qubits.result),
291+
zero := constant.Constant(0),
294292
q1 := core.QRegGet(reg1.result, zero.result),
293+
reg0 := core.QRegNew(n_qubits.result),
294+
q0 := core.QRegGet(reg0.result, zero.result),
295295
reg_list := ilist.New(
296296
values=[reg0.result, reg1.result], elem_type=reg0.result.type
297297
),

test/squin/test_constprop.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# There's a method table in the wire dialect statements
2+
# that handles the multiple return values from Apply and Broadcast
3+
# that can cause problems for constant propoagation's default implementation.
4+
# These tests just make sure there are corresponding lattice types per each
5+
# SSA value (as opposed to a bunch of "missing" entries despite multiple
6+
# return values from Broadcast and Apply)
7+
8+
from kirin import ir, types
9+
from kirin.passes import Fold
10+
from kirin.dialects import py, func
11+
from kirin.analysis.const import Propagate
12+
13+
from bloqade import squin
14+
15+
16+
def as_int(value: int):
17+
return py.constant.Constant(value=value)
18+
19+
20+
def as_float(value: float):
21+
return py.constant.Constant(value=value)
22+
23+
24+
def gen_func_from_stmts(stmts):
25+
26+
squin_with_py = squin.groups.wired.add(py)
27+
28+
block = ir.Block(stmts)
29+
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
30+
func_wrapper = func.Function(
31+
sym_name="main",
32+
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
33+
body=ir.Region(blocks=block),
34+
)
35+
36+
constructed_method = ir.Method(
37+
mod=None,
38+
py_func=None,
39+
sym_name="main",
40+
dialects=squin_with_py,
41+
code=func_wrapper,
42+
arg_names=[],
43+
)
44+
45+
fold_pass = Fold(squin_with_py)
46+
fold_pass(constructed_method)
47+
48+
return constructed_method
49+
50+
51+
def test_wire_apply_constprop():
52+
53+
stmts = [
54+
(n_qubits := as_int(2)),
55+
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
56+
(idx0 := as_int(0)),
57+
(q0 := py.GetItem(qreg.result, idx0.result)),
58+
(idx1 := as_int(1)),
59+
(q1 := py.GetItem(qreg.result, idx1.result)),
60+
# get wires
61+
(w0 := squin.wire.Unwrap(q0.result)),
62+
(w1 := squin.wire.Unwrap(q1.result)),
63+
# put wire through gates
64+
(x := squin.op.stmts.X()),
65+
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
66+
(a := squin.wire.Apply(cx.result, w0.result, w1.result)),
67+
(func.Return(a.results[0])),
68+
]
69+
constructed_method = gen_func_from_stmts(stmts)
70+
71+
prop_analysis = Propagate(constructed_method.dialects)
72+
frame, _ = prop_analysis.run_analysis(constructed_method)
73+
74+
assert len(frame.entries.values()) == 13
75+
76+
77+
def test_wire_broadcast_constprop():
78+
79+
stmts = [
80+
(n_qubits := as_int(4)),
81+
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
82+
(idx0 := as_int(0)),
83+
(q0 := py.GetItem(qreg.result, idx0.result)),
84+
(idx1 := as_int(1)),
85+
(q1 := py.GetItem(qreg.result, idx1.result)),
86+
(idx2 := as_int(2)),
87+
(q2 := py.GetItem(qreg.result, idx2.result)),
88+
(idx3 := as_int(3)),
89+
(q3 := py.GetItem(qreg.result, idx3.result)),
90+
# get wires
91+
(w0 := squin.wire.Unwrap(q0.result)),
92+
(w1 := squin.wire.Unwrap(q1.result)),
93+
(w2 := squin.wire.Unwrap(q2.result)),
94+
(w3 := squin.wire.Unwrap(q3.result)),
95+
# put wire through gates
96+
(x := squin.op.stmts.X()),
97+
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
98+
(
99+
a := squin.wire.Broadcast(
100+
cx.result, w0.result, w1.result, w2.result, w3.result
101+
)
102+
),
103+
(func.Return(a.results[0])),
104+
]
105+
constructed_method = gen_func_from_stmts(stmts)
106+
107+
prop_analysis = Propagate(constructed_method.dialects)
108+
frame, _ = prop_analysis.run_analysis(constructed_method)
109+
110+
assert len(frame.entries.values()) == 21

0 commit comments

Comments
 (0)