|
| 1 | +# There's a method table in the wire dialect statements |
| 2 | +# that handles the multiple return values from Apply and Broadcast |
| 3 | +# that can cause problems for constant propoagation's default implementation. |
| 4 | +# These tests just make sure there are corresponding lattice types per each |
| 5 | +# SSA value (as opposed to a bunch of "missing" entries despite multiple |
| 6 | +# return values from Broadcast and Apply) |
| 7 | + |
| 8 | +from kirin import ir, types |
| 9 | +from kirin.passes import Fold |
| 10 | +from kirin.dialects import py, func |
| 11 | +from kirin.analysis.const import Propagate |
| 12 | + |
| 13 | +from bloqade import squin |
| 14 | + |
| 15 | + |
| 16 | +def as_int(value: int): |
| 17 | + return py.constant.Constant(value=value) |
| 18 | + |
| 19 | + |
| 20 | +def as_float(value: float): |
| 21 | + return py.constant.Constant(value=value) |
| 22 | + |
| 23 | + |
| 24 | +def gen_func_from_stmts(stmts): |
| 25 | + |
| 26 | + squin_with_py = squin.groups.wired.add(py) |
| 27 | + |
| 28 | + block = ir.Block(stmts) |
| 29 | + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") |
| 30 | + func_wrapper = func.Function( |
| 31 | + sym_name="main", |
| 32 | + signature=func.Signature(inputs=(), output=squin.op.types.OpType), |
| 33 | + body=ir.Region(blocks=block), |
| 34 | + ) |
| 35 | + |
| 36 | + constructed_method = ir.Method( |
| 37 | + mod=None, |
| 38 | + py_func=None, |
| 39 | + sym_name="main", |
| 40 | + dialects=squin_with_py, |
| 41 | + code=func_wrapper, |
| 42 | + arg_names=[], |
| 43 | + ) |
| 44 | + |
| 45 | + fold_pass = Fold(squin_with_py) |
| 46 | + fold_pass(constructed_method) |
| 47 | + |
| 48 | + return constructed_method |
| 49 | + |
| 50 | + |
| 51 | +def test_wire_apply_constprop(): |
| 52 | + |
| 53 | + stmts = [ |
| 54 | + (n_qubits := as_int(2)), |
| 55 | + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), |
| 56 | + (idx0 := as_int(0)), |
| 57 | + (q0 := py.GetItem(qreg.result, idx0.result)), |
| 58 | + (idx1 := as_int(1)), |
| 59 | + (q1 := py.GetItem(qreg.result, idx1.result)), |
| 60 | + # get wires |
| 61 | + (w0 := squin.wire.Unwrap(q0.result)), |
| 62 | + (w1 := squin.wire.Unwrap(q1.result)), |
| 63 | + # put wire through gates |
| 64 | + (x := squin.op.stmts.X()), |
| 65 | + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), |
| 66 | + (a := squin.wire.Apply(cx.result, w0.result, w1.result)), |
| 67 | + (func.Return(a.results[0])), |
| 68 | + ] |
| 69 | + constructed_method = gen_func_from_stmts(stmts) |
| 70 | + |
| 71 | + prop_analysis = Propagate(constructed_method.dialects) |
| 72 | + frame, _ = prop_analysis.run_analysis(constructed_method) |
| 73 | + |
| 74 | + assert len(frame.entries.values()) == 13 |
| 75 | + |
| 76 | + |
| 77 | +def test_wire_broadcast_constprop(): |
| 78 | + |
| 79 | + stmts = [ |
| 80 | + (n_qubits := as_int(4)), |
| 81 | + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), |
| 82 | + (idx0 := as_int(0)), |
| 83 | + (q0 := py.GetItem(qreg.result, idx0.result)), |
| 84 | + (idx1 := as_int(1)), |
| 85 | + (q1 := py.GetItem(qreg.result, idx1.result)), |
| 86 | + (idx2 := as_int(2)), |
| 87 | + (q2 := py.GetItem(qreg.result, idx2.result)), |
| 88 | + (idx3 := as_int(3)), |
| 89 | + (q3 := py.GetItem(qreg.result, idx3.result)), |
| 90 | + # get wires |
| 91 | + (w0 := squin.wire.Unwrap(q0.result)), |
| 92 | + (w1 := squin.wire.Unwrap(q1.result)), |
| 93 | + (w2 := squin.wire.Unwrap(q2.result)), |
| 94 | + (w3 := squin.wire.Unwrap(q3.result)), |
| 95 | + # put wire through gates |
| 96 | + (x := squin.op.stmts.X()), |
| 97 | + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), |
| 98 | + ( |
| 99 | + a := squin.wire.Broadcast( |
| 100 | + cx.result, w0.result, w1.result, w2.result, w3.result |
| 101 | + ) |
| 102 | + ), |
| 103 | + (func.Return(a.results[0])), |
| 104 | + ] |
| 105 | + constructed_method = gen_func_from_stmts(stmts) |
| 106 | + |
| 107 | + prop_analysis = Propagate(constructed_method.dialects) |
| 108 | + frame, _ = prop_analysis.run_analysis(constructed_method) |
| 109 | + |
| 110 | + assert len(frame.entries.values()) == 21 |
0 commit comments