Skip to content

Commit 693a94c

Browse files
committed
Fixing tests
1 parent 966712a commit 693a94c

File tree

14 files changed

+193
-103
lines changed

14 files changed

+193
-103
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from . import impls as impls
22
from .lattice import (
3+
Bottom as Bottom,
34
Address as Address,
5+
Unknown as Unknown,
6+
AddressReg as AddressReg,
7+
UnknownReg as UnknownReg,
48
ConstResult as ConstResult,
59
AddressQubit as AddressQubit,
10+
PartialIList as PartialIList,
11+
PartialTuple as PartialTuple,
612
UnknownQubit as UnknownQubit,
713
PartialLambda as PartialLambda,
8-
StaticContainer as StaticContainer,
914
)
1015
from .analysis import AddressAnalysis as AddressAnalysis

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from kirin.interp.value import Successor
88
from kirin.analysis.forward import ForwardFrame
99

10-
from ..address import AddressAnalysis
10+
from ..address import Address, AddressAnalysis
1111

1212

1313
class FidelityAnalysis(Forward):
@@ -57,7 +57,7 @@ def main():
5757

5858
_current_atom_survival_probability: list[float] = field(init=False)
5959

60-
addr_frame: ForwardFrame = field(init=False)
60+
addr_frame: ForwardFrame[Address] = field(init=False)
6161

6262
def initialize(self):
6363
super().initialize()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from kirin import interp
2+
from kirin.analysis import const
3+
4+
from bloqade.analysis.address import (
5+
Address,
6+
AddressReg,
7+
ConstResult,
8+
AddressQubit,
9+
UnknownQubit,
10+
AddressAnalysis,
11+
)
12+
from bloqade.analysis.address.lattice import Bottom, UnknownReg
13+
14+
from .stmts import QRegGet, QRegNew
15+
from ._dialect import dialect
16+
17+
18+
@dialect.register(key="qubit.address")
19+
class AddressMethodTable(interp.MethodTable):
20+
21+
@interp.impl(QRegNew)
22+
def new(
23+
self,
24+
interp: AddressAnalysis,
25+
frame: interp.Frame[Address],
26+
stmt: QRegNew,
27+
):
28+
n_qubits = frame.get(stmt.n_qubits)
29+
match n_qubits:
30+
case ConstResult(const.Value(int() as n)):
31+
addr = AddressReg(range(interp.next_address, interp.next_address + n))
32+
interp.next_address += n
33+
return (addr,)
34+
case _:
35+
return (UnknownReg(),)
36+
37+
@interp.impl(QRegGet)
38+
def get(self, interp: AddressAnalysis, frame: interp.Frame[Address], stmt: QRegGet):
39+
addr = frame.get(stmt.reg)
40+
idx = frame.get(stmt.idx)
41+
42+
match (addr, idx):
43+
case (AddressReg(data), ConstResult(const.Value(int() as i))) if (
44+
0 <= i < len(data)
45+
):
46+
return (AddressQubit(data[i]),)
47+
case (UnknownReg(), ConstResult()):
48+
return (UnknownQubit(),)
49+
case _:
50+
return (Bottom(),)

src/bloqade/qasm2/dialects/noise/fidelity.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from kirin import interp
22
from kirin.lattice import EmptyLattice
33

4-
from bloqade.analysis.address import AddressQubit, AddressTuple
54
from bloqade.analysis.fidelity import FidelityAnalysis
65

76
from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
@@ -42,10 +41,8 @@ def atom_loss(
4241
stmt: AtomLossChannel,
4342
):
4443
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
45-
addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)
46-
44+
addresses = interp.addr_frame.get(stmt.qargs)
45+
print(addresses)
4746
# NOTE: get the corresponding index and reduce survival probability accordingly
48-
for qbit_address in addresses.data:
49-
assert isinstance(qbit_address, AddressQubit)
50-
index = qbit_address.data
47+
for index in addresses.data:
5148
interp._current_atom_survival_probability[index] *= 1 - stmt.prob

src/bloqade/qasm2/dialects/noise/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
from typing import Sequence
23
from dataclasses import field, dataclass
34

45

@@ -161,7 +162,7 @@ def sitter_errors(
161162

162163
@abc.abstractmethod
163164
def parallel_cz_errors(
164-
self, ctrls: list[int], qargs: list[int], rest: list[int]
165+
self, ctrls: Sequence[int], qargs: Sequence[int], rest: Sequence[int]
165166
) -> dict[tuple[float, float, float, float], list[int]]:
166167
"""Takes a set of ctrls and qargs and returns a noise model for all qubits."""
167168
pass

src/bloqade/qasm2/rewrite/noise/heuristic_noise.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple, cast
1+
from typing import Dict, List, Tuple
22
from dataclasses import field, dataclass
33

44
from kirin import ir
@@ -55,7 +55,7 @@ def rewrite_single_qubit_gate(self, node: uop.SingleQubitGate):
5555

5656
def rewrite_global_single_qubit_gate(self, node: glob.UGate):
5757
addrs = self.address_analysis[node.registers]
58-
if not isinstance(addrs, address.AddressTuple):
58+
if not isinstance(addrs, address.PartialIList):
5959
return rewrite_abc.RewriteResult()
6060

6161
qargs = []
@@ -74,10 +74,7 @@ def rewrite_global_single_qubit_gate(self, node: glob.UGate):
7474

7575
def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
7676
addrs = self.address_analysis[node.qargs]
77-
if not isinstance(addrs, address.AddressTuple):
78-
return rewrite_abc.RewriteResult()
79-
80-
if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
77+
if not isinstance(addrs, address.AddressReg):
8178
return rewrite_abc.RewriteResult()
8279

8380
assert isinstance(node.qargs, ir.ResultValue)
@@ -178,18 +175,11 @@ def rewrite_parallel_cz_gate(self, node: parallel.CZ):
178175
qargs = self.address_analysis[node.qargs]
179176

180177
has_done_something = False
181-
if (
182-
isinstance(ctrls, address.AddressTuple)
183-
and all(isinstance(addr, address.AddressQubit) for addr in ctrls.data)
184-
and isinstance(qargs, address.AddressTuple)
185-
and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
178+
if isinstance(ctrls, address.AddressReg) and isinstance(
179+
qargs, address.AddressReg
186180
):
187-
ctrl_qubits = list(
188-
map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
189-
)
190-
qarg_qubits = list(
191-
map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
192-
)
181+
ctrl_qubits = tuple(ctrls.data)
182+
qarg_qubits = tuple(qargs.data)
193183
rest = sorted(
194184
set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
195185
)

src/bloqade/qasm2/rewrite/parallel_to_glob.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from kirin import ir
55
from kirin.rewrite import abc
6-
from kirin.analysis import const
76
from kirin.dialects import ilist
87

98
from bloqade.analysis import address
@@ -20,28 +19,24 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
2019
return abc.RewriteResult()
2120

2221
qargs = node.qargs
23-
qarg_addresses = self.address_analysis.get(qargs, None)
22+
qargs_address = self.address_analysis.get(qargs, address.Unknown())
2423

25-
if isinstance(qarg_addresses, address.AddressReg):
26-
# NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
27-
return self._rewrite_parallel_to_glob(node)
28-
29-
if not isinstance(qarg_addresses, address.AddressTuple):
24+
if not isinstance(qargs_address, address.AddressReg):
3025
return abc.RewriteResult()
3126

32-
idxs, qreg = self._find_qreg(qargs.owner, set())
27+
qregs = self._get_all_qreg(qargs.owner)
3328

34-
if qreg is None:
35-
# NOTE: no unique register found
29+
if len(qregs) != 1:
3630
return abc.RewriteResult()
3731

38-
if not isinstance(hint := qreg.n_qubits.hints.get("const"), const.Value):
39-
# NOTE: non-constant number of qubits
32+
qreg = next(iter(qregs))
33+
34+
qreg_address = self.address_analysis.get(qreg, address.Unknown())
35+
36+
if not isinstance(qreg_address, address.AddressReg):
4037
return abc.RewriteResult()
4138

42-
n = hint.data
43-
if len(idxs) != n:
44-
# NOTE: not all qubits of the register are there
39+
if set(qargs_address.data) != set(qreg_address.data):
4540
return abc.RewriteResult()
4641

4742
return self._rewrite_parallel_to_glob(node)
@@ -53,6 +48,24 @@ def _rewrite_parallel_to_glob(node: parallel.UGate) -> abc.RewriteResult:
5348
node.replace_by(global_u)
5449
return abc.RewriteResult(has_done_something=True)
5550

51+
@staticmethod
52+
def _get_all_qreg(owner: ir.Statement | ir.Block):
53+
stack = [owner]
54+
qregs: set[ir.SSAValue] = set()
55+
while stack:
56+
current = stack.pop()
57+
58+
if isinstance(current, core.stmts.QRegGet):
59+
stack.append(current.reg.owner)
60+
elif isinstance(current, ilist.New):
61+
for val in current.values:
62+
stack.append(val.owner)
63+
64+
elif isinstance(current, core.QRegNew):
65+
qregs.add(current.result)
66+
67+
return qregs
68+
5669
@staticmethod
5770
def _find_qreg(
5871
qargs_owner: ir.Statement | ir.Block, idxs: set

src/bloqade/qasm2/rewrite/parallel_to_uop.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,10 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
2121

2222
def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
2323
addr = self.address_analysis.get(ilist_ref)
24-
if not isinstance(addr, address.AddressTuple):
24+
if not isinstance(addr, address.AddressReg):
2525
return None
2626

27-
ids = []
28-
for ele in addr.data:
29-
if not isinstance(ele, address.AddressQubit):
30-
return None
31-
32-
ids.append(ele.data)
33-
27+
ids = addr.data
3428
return [self.id_map[ele] for ele in ids]
3529

3630
def rewrite_cz(self, node: ir.Statement):

src/bloqade/stim/rewrite/squin_noise.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from bloqade.squin import noise as squin_noise
1010
from bloqade.stim.dialects import noise as stim_noise
1111
from bloqade.stim.rewrite.util import insert_qubit_idx_from_address
12-
from bloqade.analysis.address.lattice import AddressTuple
12+
from bloqade.analysis.address.lattice import AddressReg, AddressQubit
1313
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
1414

1515

@@ -42,19 +42,24 @@ def rewrite_NoiseChannel(
4242
if not isinstance(qubit_address_attr, AddressAttribute):
4343
return RewriteResult()
4444

45-
address_tuple = qubit_address_attr.address
45+
address_reg = qubit_address_attr.address
4646

47-
if not isinstance(address_tuple, AddressTuple):
47+
if not isinstance(address_reg, AddressReg):
4848
return RewriteResult()
4949

5050
qubit_idx_ssas_list = [
51-
insert_qubit_idx_from_address(AddressAttribute(address=address), stmt)
52-
for address in address_tuple.data
51+
insert_qubit_idx_from_address(
52+
AddressAttribute(address=AddressQubit(address)), stmt
53+
)
54+
for address in address_reg.data
5355
]
5456
if None in qubit_idx_ssas_list:
5557
return RewriteResult()
5658

5759
for qubit_idx_ssas in qubit_idx_ssas_list:
60+
assert (
61+
qubit_idx_ssas is not None
62+
), "qubit_idx_ssas should not be None here"
5863
stim_stmt = rewrite_method(stmt, tuple(qubit_idx_ssas))
5964
stim_stmt.insert_before(stmt)
6065
stmt.delete()

src/bloqade/stim/rewrite/util.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.dialects import py
33

44
from bloqade.squin.rewrite import AddressAttribute
5-
from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple
5+
from bloqade.analysis.address import AddressReg, AddressQubit
66

77

88
def create_and_insert_qubit_idx_stmt(
@@ -22,14 +22,7 @@ def insert_qubit_idx_from_address(
2222
address_data = address.address
2323
qubit_idx_ssas = []
2424

25-
if isinstance(address_data, AddressTuple):
26-
for address_qubit in address_data.data:
27-
if not isinstance(address_qubit, AddressQubit):
28-
return
29-
create_and_insert_qubit_idx_stmt(
30-
address_qubit.data, stmt_to_insert_before, qubit_idx_ssas
31-
)
32-
elif isinstance(address_data, AddressReg):
25+
if isinstance(address_data, AddressReg):
3326
for qubit_idx in address_data.data:
3427
create_and_insert_qubit_idx_stmt(
3528
qubit_idx, stmt_to_insert_before, qubit_idx_ssas
@@ -38,11 +31,6 @@ def insert_qubit_idx_from_address(
3831
create_and_insert_qubit_idx_stmt(
3932
address_data.data, stmt_to_insert_before, qubit_idx_ssas
4033
)
41-
elif isinstance(address_data, AddressWire):
42-
address_qubit = address_data.origin_qubit
43-
create_and_insert_qubit_idx_stmt(
44-
address_qubit.data, stmt_to_insert_before, qubit_idx_ssas
45-
)
4634
else:
4735
return
4836

0 commit comments

Comments
 (0)