Skip to content

Commit 04480a3

Browse files
authored
Merge branch 'main' into david/229-pyqrack-qubit-reset
2 parents 995e903 + 676ed3a commit 04480a3

File tree

15 files changed

+352
-40
lines changed

15 files changed

+352
-40
lines changed

src/bloqade/noise/native/_wrappers.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,28 @@
88

99

1010
@wraps(native.AtomLossChannel)
11-
def atom_loss_channel(
12-
qargs: ilist.IList[Qubit, Any] | list, *, prob: float
13-
) -> None: ...
11+
def atom_loss_channel(qargs: ilist.IList[Qubit, Any] | list, *, prob: float) -> None:
12+
"""Apply an atom loss channel to a list of qubits.
13+
14+
Args:
15+
qargs (ilist.IList[Qubit, Any] | list): List of qubits to apply the noise to.
16+
prob (float): The loss probability.
17+
"""
18+
...
1419

1520

1621
@wraps(native.PauliChannel)
1722
def pauli_channel(
1823
qargs: ilist.IList[Qubit, Any] | list, *, px: float, py: float, pz: float
19-
) -> None: ...
24+
) -> None:
25+
"""Apply a Pauli channel to a list of qubits.
26+
27+
Args:
28+
qargs (ilist.IList[Qubit, Any] | list): List of qubits to apply the noise to.
29+
px (float): Probability of X error.
30+
py (float): Probability of Y error.
31+
pz (float): Probability of Z error.
32+
"""
2033

2134

2235
@wraps(native.CZPauliChannel)
@@ -31,4 +44,20 @@ def cz_pauli_channel(
3144
py_qarg: float,
3245
pz_qarg: float,
3346
paired: bool,
34-
) -> None: ...
47+
) -> None:
48+
"""Insert noise for a CZ gate with a Pauli channel on qubits.
49+
50+
Args:
51+
ctrls: List of control qubits.
52+
qarg2: List of target qubits.
53+
px_ctrl: Probability of X error on control qubits.
54+
py_ctrl: Probability of Y error on control qubits.
55+
pz_ctrl: Probability of Z error on control qubits.
56+
px_qarg: Probability of X error on target qubits.
57+
py_qarg: Probability of Y error on target qubits.
58+
pz_qarg: Probability of Z error on target qubits.
59+
paired: If True, the noise is applied to both control and target qubits
60+
are not lost otherwise skip this error. If False Apply the noise on
61+
the whatever qubit is not lost.
62+
"""
63+
...

src/bloqade/qasm2/_qasm_loading.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pathlib
44
from typing import Any
55

6-
from kirin import ir, types
6+
from kirin import ir, lowering
77
from kirin.dialects import func
88

99
from . import parse
@@ -16,6 +16,7 @@ def loads(
1616
*,
1717
kernel_name: str = "main",
1818
dialects: ir.DialectGroup | None = None,
19+
returns: str | None = None,
1920
globals: dict[str, Any] | None = None,
2021
file: str | None = None,
2122
lineno_offset: int = 0,
@@ -54,7 +55,7 @@ def loads(
5455
# TODO: add source info
5556
stmt = parse.loads(qasm)
5657
qasm2_lowering = QASM2(dialects or main)
57-
body = qasm2_lowering.run(
58+
frame = qasm2_lowering.get_frame(
5859
stmt,
5960
source=qasm,
6061
file=file,
@@ -63,13 +64,21 @@ def loads(
6364
col_offset=col_offset,
6465
compactify=compactify,
6566
)
66-
return_value = func.ConstantNone()
67-
body.blocks[0].stmts.append(return_value)
68-
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
6967

68+
if returns is not None:
69+
return_value = frame.get(returns)
70+
if return_value is None:
71+
raise lowering.BuildError(f"Cannot find return value {returns}")
72+
else:
73+
return_value = func.ConstantNone()
74+
frame.push(return_value)
75+
76+
return_node = frame.push(func.Return(value_or_stmt=return_value))
77+
78+
body = frame.curr_region
7079
code = func.Function(
7180
sym_name=kernel_name,
72-
signature=func.Signature((), types.NoneType),
81+
signature=func.Signature((), return_node.value.type),
7382
body=body,
7483
)
7584

@@ -88,6 +97,7 @@ def loadfile(
8897
*,
8998
kernel_name: str = "main",
9099
dialects: ir.DialectGroup | None = None,
100+
returns: str | None = None,
91101
globals: dict[str, Any] | None = None,
92102
file: str | None = None,
93103
lineno_offset: int = 0,
@@ -132,6 +142,7 @@ def loadfile(
132142
source,
133143
kernel_name=kernel_name,
134144
dialects=dialects,
145+
returns=returns,
135146
globals=globals,
136147
file=file,
137148
lineno_offset=lineno_offset,

src/bloqade/qasm2/emit/target.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
allow_parallel: bool = False,
2828
allow_global: bool = False,
2929
custom_gate: bool = True,
30+
unroll_ifs: bool = True,
3031
) -> None:
3132
"""Initialize the QASM2 target.
3233
@@ -43,9 +44,14 @@ def __init__(
4344
qelib1 (bool):
4445
Include the `include "qelib1.inc"` line in the resulting QASM2 AST that's
4546
submitted to qBraid. Defaults to `True`.
47+
4648
custom_gate (bool):
4749
Include the custom gate definitions in the resulting QASM2 AST. Defaults to `True`. If `False`, all the qasm2.gate will be inlined.
4850
51+
unroll_ifs (bool):
52+
Unrolls if statements with multiple qasm2 statements in the body in order to produce valid qasm2 output, which only allows a single
53+
operation in an if body. Defaults to `True`.
54+
4955
5056
5157
"""
@@ -58,6 +64,7 @@ def __init__(
5864
self.custom_gate = custom_gate
5965
self.allow_parallel = allow_parallel
6066
self.allow_global = allow_global
67+
self.unroll_ifs = unroll_ifs
6168

6269
if allow_parallel:
6370
self.main_target = self.main_target.add(qasm2.dialects.parallel)
@@ -87,9 +94,11 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
8794

8895
# make a cloned instance of kernel
8996
entry = entry.similar()
90-
QASM2Fold(entry.dialects, inline_gate_subroutine=not self.custom_gate).fixpoint(
91-
entry
92-
)
97+
QASM2Fold(
98+
entry.dialects,
99+
inline_gate_subroutine=not self.custom_gate,
100+
unroll_ifs=self.unroll_ifs,
101+
).fixpoint(entry)
93102

94103
if not self.allow_global:
95104
# rewrite global to parallel

src/bloqade/qasm2/groups.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from kirin.prelude import structural_no_opt
33
from kirin.dialects import scf, func, ilist, lowering
44

5+
from bloqade.noise import native
56
from bloqade.qasm2.dialects import (
67
uop,
78
core,
@@ -90,6 +91,7 @@ def run_pass(
9091
noise,
9192
parallel,
9293
core,
94+
native,
9395
]
9496
)
9597
)

src/bloqade/qasm2/parse/lowering.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ def run(
2828
col_offset: int = 0,
2929
compactify: bool = True,
3030
) -> ir.Region:
31+
32+
frame = self.get_frame(
33+
stmt,
34+
source=source,
35+
globals=globals,
36+
file=file,
37+
lineno_offset=lineno_offset,
38+
col_offset=col_offset,
39+
)
40+
41+
return frame.curr_region
42+
43+
def get_frame(
44+
self,
45+
stmt: ast.Node,
46+
source: str | None = None,
47+
globals: dict[str, Any] | None = None,
48+
file: str | None = None,
49+
lineno_offset: int = 0,
50+
col_offset: int = 0,
51+
compactify: bool = True,
52+
) -> lowering.Frame:
3153
# TODO: add source info
3254
state = lowering.State(
3355
self,
@@ -41,13 +63,13 @@ def run(
4163
finalize_next=False,
4264
) as frame:
4365
self.visit(state, stmt)
44-
region = frame.curr_region
4566

46-
if compactify:
47-
from kirin.rewrite import Walk, CFGCompactify
67+
if compactify:
68+
from kirin.rewrite import Walk, CFGCompactify
69+
70+
Walk(CFGCompactify()).rewrite(frame.curr_region)
4871

49-
Walk(CFGCompactify()).rewrite(region)
50-
return region
72+
return frame
5173

5274
def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result:
5375
name = node.__class__.__name__

src/bloqade/qasm2/passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .py2qasm import Py2QASM as Py2QASM
44
from .qasm2py import QASM2Py as QASM2Py
55
from .parallel import UOpToParallel as UOpToParallel
6+
from .unroll_if import UnrollIfs as UnrollIfs

src/bloqade/qasm2/passes/fold.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323

2424
from bloqade.qasm2.dialects import expr
2525

26+
from .unroll_if import UnrollIfs
27+
2628

2729
@dataclass
2830
class QASM2Fold(Pass):
2931
"""Fold pass for qasm2.extended"""
3032

3133
constprop: const.Propagate = field(init=False)
3234
inline_gate_subroutine: bool = True
35+
unroll_ifs: bool = True
3336

3437
def __post_init__(self):
3538
self.constprop = const.Propagate(self.dialects)
@@ -61,6 +64,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6164
.join(result)
6265
)
6366

67+
if self.unroll_ifs:
68+
UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69+
6470
# run typeinfer again after unroll etc. because we now insert
6571
# a lot of new nodes, which might have more precise types
6672
self.typeinfer.unsafe_run(mt)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from kirin import ir
2+
from kirin.passes import Pass
3+
from kirin.rewrite import (
4+
Walk,
5+
Chain,
6+
Fixpoint,
7+
ConstantFold,
8+
CommonSubexpressionElimination,
9+
)
10+
11+
from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12+
13+
14+
class UnrollIfs(Pass):
15+
"""This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
16+
17+
def unsafe_run(self, mt: ir.Method):
18+
result = Walk(LiftThenBody()).rewrite(mt.code)
19+
result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
20+
result = (
21+
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
22+
.rewrite(mt.code)
23+
.join(result)
24+
)
25+
return result
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from kirin import ir
2+
from kirin.dialects import scf, func
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6+
from ..dialects.core.stmts import Reset, Measure
7+
8+
# TODO: unify with PR #248
9+
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
10+
11+
DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
12+
13+
14+
class LiftThenBody(RewriteRule):
15+
"""Lifts anything that's not a UOP or a yield/return out of the then body"""
16+
17+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18+
if not isinstance(node, scf.IfElse):
19+
return RewriteResult()
20+
21+
then_stmts = node.then_body.stmts()
22+
23+
lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
24+
25+
if len(lift_stmts) == 0:
26+
return RewriteResult()
27+
28+
for stmt in lift_stmts:
29+
stmt.detach()
30+
stmt.insert_before(node)
31+
32+
return RewriteResult(has_done_something=True)
33+
34+
35+
class SplitIfStmts(RewriteRule):
36+
"""Splits the then body of an if-else statement into multiple if statements"""
37+
38+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
39+
if not isinstance(node, scf.IfElse):
40+
return RewriteResult()
41+
42+
*stmts, yield_or_return = node.then_body.stmts()
43+
44+
if len(stmts) == 1:
45+
return RewriteResult()
46+
47+
is_yield = isinstance(yield_or_return, scf.Yield)
48+
49+
for stmt in stmts:
50+
stmt.detach()
51+
52+
yield_or_return = scf.Yield() if is_yield else func.Return()
53+
54+
then_block = ir.Block((stmt, yield_or_return), argtypes=(node.cond.type,))
55+
then_body = ir.Region(then_block)
56+
else_body = node.else_body.clone()
57+
else_body.detach()
58+
new_if = scf.IfElse(
59+
cond=node.cond, then_body=then_body, else_body=else_body
60+
)
61+
62+
new_if.insert_before(node)
63+
64+
node.delete()
65+
66+
return RewriteResult(has_done_something=True)

0 commit comments

Comments
 (0)