Skip to content

Commit 5447a77

Browse files
authored
Merge branch 'main' into roger/upgrade-kirin-0-17
2 parents dafd4e0 + 1cd01ec commit 5447a77

File tree

26 files changed

+390
-93
lines changed

26 files changed

+390
-93
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,10 @@ def unwrap(
192192

193193
origin_qubit = frame.get(stmt.qubit)
194194

195-
return (AddressWire(origin_qubit=origin_qubit),)
195+
if isinstance(origin_qubit, AddressQubit):
196+
return (AddressWire(origin_qubit=origin_qubit),)
197+
else:
198+
return (Address.top(),)
196199

197200
@interp.impl(squin.wire.Apply)
198201
def apply(
@@ -201,14 +204,7 @@ def apply(
201204
frame: ForwardFrame[Address],
202205
stmt: squin.wire.Apply,
203206
):
204-
205-
origin_qubits = tuple(
206-
[frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
207-
)
208-
new_address_wires = tuple(
209-
[AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
210-
)
211-
return new_address_wires
207+
return frame.get_values(stmt.inputs)
212208

213209

214210
@squin.qubit.dialect.register(key="qubit.address")

src/bloqade/analysis/address/lattice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,5 @@ class AddressWire(Address):
8181

8282
def is_subseteq(self, other: Address) -> bool:
8383
if isinstance(other, AddressWire):
84-
return self.origin_qubit == self.origin_qubit
84+
return self.origin_qubit == other.origin_qubit
8585
return False

src/bloqade/noise/native/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC):
102102
params: MoveNoiseParams = field(default_factory=MoveNoiseParams)
103103
"""Parameters for calculating move noise."""
104104

105-
@classmethod
106105
@abc.abstractmethod
107106
def parallel_cz_errors(
108-
cls, ctrls: List[int], qargs: List[int], rest: List[int]
107+
self, ctrls: List[int], qargs: List[int], rest: List[int]
109108
) -> Dict[Tuple[float, float, float, float], List[int]]:
110109
"""Takes a set of ctrls and qargs and returns a noise model for all qubits."""
111110
pass

src/bloqade/pyqrack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
# NOTE: The following import is for registering the method tables
1616
from .noise import native as native
17-
from .qasm2 import uop as uop, core as core, parallel as parallel
17+
from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
1818
from .target import PyQrack as PyQrack

src/bloqade/pyqrack/noise/native.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ def atom_loss_channel(
9393

9494
for qarg in active_qubits:
9595
if interp.rng_state.uniform() <= stmt.prob:
96-
sim_reg = qarg.ref.sim_reg
97-
sim_reg.force_m(qarg.addr, 0)
96+
qarg.ref.sim_reg.m(qarg.addr)
9897
qarg.drop()
9998

10099
return ()

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
@core.dialect.register(key="pyqrack")
1616
class PyQrackMethods(interp.MethodTable):
17-
1817
@interp.impl(core.QRegNew)
1918
def qreg_new(
2019
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew

src/bloqade/pyqrack/qasm2/glob.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Any
2+
3+
from kirin import interp
4+
from kirin.dialects import ilist
5+
6+
from bloqade.pyqrack.reg import PyQrackReg
7+
from bloqade.pyqrack.base import PyQrackInterpreter
8+
from bloqade.qasm2.dialects import glob
9+
10+
11+
@glob.dialect.register(key="pyqrack")
12+
class PyQrackMethods(interp.MethodTable):
13+
@interp.impl(glob.UGate)
14+
def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
15+
registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers)
16+
theta, phi, lam = (
17+
frame.get(stmt.theta),
18+
frame.get(stmt.phi),
19+
frame.get(stmt.lam),
20+
)
21+
22+
for qreg in registers:
23+
for qarg in qreg:
24+
if qarg.is_active():
25+
interp.memory.sim_reg.u(qarg.addr, theta, phi, lam)
26+
return ()

src/bloqade/pyqrack/reg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class CBitRef:
3333
pos: int
3434
"""The position of this bit in the classical register."""
3535

36-
def set_value(self, value: bool):
36+
def set_value(self, value: Measurement):
3737
self.ref[self.pos] = value
3838

3939
def get_value(self):
@@ -46,7 +46,7 @@ class QubitState(enum.Enum):
4646

4747

4848
@dataclass(frozen=True)
49-
class PyQrackReg(QReg):
49+
class PyQrackReg(QReg): # TODO: clean up implementation with list base class
5050
"""Simulation runtime value of a quantum register."""
5151

5252
size: int
@@ -72,6 +72,8 @@ def drop(self, pos: int):
7272
self.qubit_state[pos] = QubitState.Lost
7373

7474
def __getitem__(self, pos: int):
75+
if not 0 <= pos < self.size:
76+
raise IndexError("Qubit index out of bounds of register.")
7577
return PyQrackQubit(self, pos)
7678

7779

src/bloqade/qasm2/glob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
@wraps(glob.UGate)
1313
def u(
14-
theta: float, phi: float, lam: float, registers: ilist.IList[QReg, Any] | list
14+
registers: ilist.IList[QReg, Any] | list, theta: float, phi: float, lam: float
1515
) -> None:
1616
"""Apply a U gate to all qubits in the input registers.
1717

src/bloqade/qasm2/parse/lowering.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import pathlib
13
from typing import Any
24
from dataclasses import field, dataclass
35

@@ -17,6 +19,118 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1719
hint_show_lineno: bool = field(default=True, kw_only=True)
1820
stacktrace: bool = field(default=True, kw_only=True)
1921

22+
def loads(
23+
self,
24+
source: str,
25+
kernel_name: str,
26+
*,
27+
returns: str | None = None,
28+
globals: dict[str, Any] | None = None,
29+
file: str | None = None,
30+
lineno_offset: int = 0,
31+
col_offset: int = 0,
32+
compactify: bool = True,
33+
) -> ir.Method:
34+
from ..parse import loads
35+
36+
# TODO: add source info
37+
stmt = loads(source)
38+
39+
state = lowering.State(
40+
self,
41+
file=file,
42+
lineno_offset=lineno_offset,
43+
col_offset=col_offset,
44+
)
45+
with state.frame(
46+
[stmt],
47+
globals=globals,
48+
) as frame:
49+
try:
50+
self.visit(state, stmt)
51+
# append return statement with the return values
52+
if returns is not None:
53+
return_value = frame.get(returns)
54+
if return_value is None:
55+
raise lowering.BuildError(f"Cannot find return value {returns}")
56+
else:
57+
return_value = func.ConstantNone()
58+
59+
return_node = frame.push(func.Return(value_or_stmt=return_value))
60+
61+
except lowering.BuildError as e:
62+
hint = state.error_hint(
63+
e,
64+
max_lines=self.max_lines,
65+
indent=self.hint_indent,
66+
show_lineno=self.hint_show_lineno,
67+
)
68+
if self.stacktrace:
69+
raise Exception(
70+
f"{e.args[0]}\n\n{hint}",
71+
*e.args[1:],
72+
) from e
73+
else:
74+
e.args = (hint,)
75+
raise e
76+
77+
region = frame.curr_region
78+
79+
if compactify:
80+
from kirin.rewrite import Walk, CFGCompactify
81+
82+
Walk(CFGCompactify()).rewrite(region)
83+
84+
code = func.Function(
85+
sym_name=kernel_name,
86+
signature=func.Signature((), return_node.value.type),
87+
body=region,
88+
)
89+
90+
return ir.Method(
91+
mod=None,
92+
py_func=None,
93+
sym_name=kernel_name,
94+
arg_names=[],
95+
dialects=self.dialects,
96+
code=code,
97+
)
98+
99+
def loadfile(
100+
self,
101+
file: str | pathlib.Path,
102+
*,
103+
kernel_name: str | None = None,
104+
returns: str | None = None,
105+
globals: dict[str, Any] | None = None,
106+
lineno_offset: int = 0,
107+
col_offset: int = 0,
108+
compactify: bool = True,
109+
) -> ir.Method:
110+
if isinstance(file, str):
111+
file = pathlib.Path(*os.path.split(file))
112+
113+
if not file.is_file() or not file.name.endswith(".qasm"):
114+
raise ValueError("File must be a .qasm file")
115+
116+
kernel_name = (
117+
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
118+
)
119+
120+
with file.open("r") as f:
121+
source = f.read()
122+
123+
return self.loads(
124+
source,
125+
kernel_name,
126+
returns=returns,
127+
globals=globals,
128+
file=str(file),
129+
lineno_offset=lineno_offset,
130+
col_offset=col_offset,
131+
compactify=compactify,
132+
)
133+
20134
def run(
21135
self,
22136
stmt: ast.Node,
@@ -85,6 +199,10 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue:
85199
stmt = expr.ConstInt(value=value)
86200
elif isinstance(value, float):
87201
stmt = expr.ConstFloat(value=value)
202+
else:
203+
raise lowering.BuildError(
204+
f"Expected value of type float or int, got {type(value)}."
205+
)
88206
state.current_frame.push(stmt)
89207
return stmt.result
90208

@@ -99,6 +217,8 @@ def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgr
99217
dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"]
100218
elif isinstance(node.header, ast.Kirin):
101219
dialects = node.header.dialects
220+
else:
221+
raise lowering.BuildError(f"Unexpected node header {node.header}")
102222

103223
for dialect in dialects:
104224
if dialect not in allowed:
@@ -295,6 +415,8 @@ def visit_Bit(self, state: lowering.State[ast.Node], node: ast.Bit):
295415
stmt = core.QRegGet(reg, addr.result)
296416
elif reg.type.is_subseteq(CRegType):
297417
stmt = core.CRegGet(reg, addr.result)
418+
else:
419+
raise lowering.BuildError(f"Unexpected register type {reg.type}")
298420
return state.current_frame.push(stmt).result
299421

300422
def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call):

0 commit comments

Comments
 (0)