Skip to content

Commit a065746

Browse files
committed
Merge branch 'main' into john/fix-incorrect-squin-address-lattice
2 parents dacecbb + 4ebc574 commit a065746

File tree

12 files changed

+289
-15
lines changed

12 files changed

+289
-15
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
# NOTE: The following import is for registering the method tables
1616
from .noise import native as native
17-
from .qasm2 import uop as uop, core as core, parallel as parallel
17+
from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
1818
from .target import PyQrack as PyQrack

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
@core.dialect.register(key="pyqrack")
1616
class PyQrackMethods(interp.MethodTable):
17-
1817
@interp.impl(core.QRegNew)
1918
def qreg_new(
2019
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew

src/bloqade/pyqrack/qasm2/glob.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Any
2+
3+
from kirin import interp
4+
from kirin.dialects import ilist
5+
6+
from bloqade.pyqrack.reg import PyQrackReg
7+
from bloqade.pyqrack.base import PyQrackInterpreter
8+
from bloqade.qasm2.dialects import glob
9+
10+
11+
@glob.dialect.register(key="pyqrack")
12+
class PyQrackMethods(interp.MethodTable):
13+
@interp.impl(glob.UGate)
14+
def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
15+
registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers)
16+
theta, phi, lam = (
17+
frame.get(stmt.theta),
18+
frame.get(stmt.phi),
19+
frame.get(stmt.lam),
20+
)
21+
22+
for qreg in registers:
23+
for qarg in qreg:
24+
if qarg.is_active():
25+
interp.memory.sim_reg.u(qarg.addr, theta, phi, lam)
26+
return ()

src/bloqade/pyqrack/reg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class CBitRef:
3333
pos: int
3434
"""The position of this bit in the classical register."""
3535

36-
def set_value(self, value: bool):
36+
def set_value(self, value: Measurement):
3737
self.ref[self.pos] = value
3838

3939
def get_value(self):
@@ -46,7 +46,7 @@ class QubitState(enum.Enum):
4646

4747

4848
@dataclass(frozen=True)
49-
class PyQrackReg(QReg):
49+
class PyQrackReg(QReg): # TODO: clean up implementation with list base class
5050
"""Simulation runtime value of a quantum register."""
5151

5252
size: int
@@ -72,6 +72,8 @@ def drop(self, pos: int):
7272
self.qubit_state[pos] = QubitState.Lost
7373

7474
def __getitem__(self, pos: int):
75+
if not 0 <= pos < self.size:
76+
raise IndexError("Qubit index out of bounds of register.")
7577
return PyQrackQubit(self, pos)
7678

7779

src/bloqade/qasm2/glob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
@wraps(glob.UGate)
1313
def u(
14-
theta: float, phi: float, lam: float, registers: ilist.IList[QReg, Any] | list
14+
registers: ilist.IList[QReg, Any] | list, theta: float, phi: float, lam: float
1515
) -> None:
1616
"""Apply a U gate to all qubits in the input registers.
1717

src/bloqade/qasm2/parse/lowering.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import pathlib
13
from typing import Any
24
from dataclasses import field, dataclass
35

@@ -17,6 +19,118 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1719
hint_show_lineno: bool = field(default=True, kw_only=True)
1820
stacktrace: bool = field(default=True, kw_only=True)
1921

22+
def loads(
23+
self,
24+
source: str,
25+
kernel_name: str,
26+
*,
27+
returns: str | None = None,
28+
globals: dict[str, Any] | None = None,
29+
file: str | None = None,
30+
lineno_offset: int = 0,
31+
col_offset: int = 0,
32+
compactify: bool = True,
33+
) -> ir.Method:
34+
from ..parse import loads
35+
36+
# TODO: add source info
37+
stmt = loads(source)
38+
39+
state = lowering.State(
40+
self,
41+
file=file,
42+
lineno_offset=lineno_offset,
43+
col_offset=col_offset,
44+
)
45+
with state.frame(
46+
[stmt],
47+
globals=globals,
48+
) as frame:
49+
try:
50+
self.visit(state, stmt)
51+
# append return statement with the return values
52+
if returns is not None:
53+
return_value = frame.get(returns)
54+
if return_value is None:
55+
raise lowering.BuildError(f"Cannot find return value {returns}")
56+
else:
57+
return_value = func.ConstantNone()
58+
59+
return_node = frame.push(func.Return(value_or_stmt=return_value))
60+
61+
except lowering.BuildError as e:
62+
hint = state.error_hint(
63+
e,
64+
max_lines=self.max_lines,
65+
indent=self.hint_indent,
66+
show_lineno=self.hint_show_lineno,
67+
)
68+
if self.stacktrace:
69+
raise Exception(
70+
f"{e.args[0]}\n\n{hint}",
71+
*e.args[1:],
72+
) from e
73+
else:
74+
e.args = (hint,)
75+
raise e
76+
77+
region = frame.curr_region
78+
79+
if compactify:
80+
from kirin.rewrite import Walk, CFGCompactify
81+
82+
Walk(CFGCompactify()).rewrite(region)
83+
84+
code = func.Function(
85+
sym_name=kernel_name,
86+
signature=func.Signature((), return_node.value.type),
87+
body=region,
88+
)
89+
90+
return ir.Method(
91+
mod=None,
92+
py_func=None,
93+
sym_name=kernel_name,
94+
arg_names=[],
95+
dialects=self.dialects,
96+
code=code,
97+
)
98+
99+
def loadfile(
100+
self,
101+
file: str | pathlib.Path,
102+
*,
103+
kernel_name: str | None = None,
104+
returns: str | None = None,
105+
globals: dict[str, Any] | None = None,
106+
lineno_offset: int = 0,
107+
col_offset: int = 0,
108+
compactify: bool = True,
109+
) -> ir.Method:
110+
if isinstance(file, str):
111+
file = pathlib.Path(*os.path.split(file))
112+
113+
if not file.is_file() or not file.name.endswith(".qasm"):
114+
raise ValueError("File must be a .qasm file")
115+
116+
kernel_name = (
117+
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
118+
)
119+
120+
with file.open("r") as f:
121+
source = f.read()
122+
123+
return self.loads(
124+
source,
125+
kernel_name,
126+
returns=returns,
127+
globals=globals,
128+
file=str(file),
129+
lineno_offset=lineno_offset,
130+
col_offset=col_offset,
131+
compactify=compactify,
132+
)
133+
20134
def run(
21135
self,
22136
stmt: ast.Node,
@@ -85,6 +199,9 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue:
85199
stmt = expr.ConstInt(value=value)
86200
elif isinstance(value, float):
87201
stmt = expr.ConstFloat(value=value)
202+
else:
203+
raise lowering.BuildError(f"Unsupported literal type {type(value)}")
204+
88205
state.current_frame.push(stmt)
89206
return stmt.result
90207

src/bloqade/qasm2/rewrite/native_gates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def rewrite_cu3(self, node: uop.CU3) -> result.RewriteResult:
279279
lam = self._get_const_value(node.lam)
280280
phi = self._get_const_value(node.phi)
281281

282-
if not all((theta, phi, lam)):
282+
if theta is None or lam is None or phi is None:
283283
return result.RewriteResult()
284284

285285
# cirq.ControlledGate(u3(theta, lambda phi))

src/bloqade/squin/analysis/nsites/analysis.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ class NSitesAnalysis(Forward[Sites]):
1515
keys = ["op.nsites"]
1616
lattice = Sites
1717

18-
# Take a page from const prop in Kirin,
19-
# I can get the data I want from the SizedTrait
20-
# and go from there
18+
# Take a page from how constprop works in Kirin
2119

2220
## This gets called before the registry look up
2321
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
@@ -37,7 +35,7 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
3735
# For when no implementation is found for the statement
3836
def eval_stmt_fallback(
3937
self, frame: ForwardFrame[Sites], stmt: ir.Statement
40-
) -> tuple[Sites, ...]: # some form of Shape will go back into the frame
38+
) -> tuple[Sites, ...]: # some form of Sites will go back into the frame
4139
return tuple(
4240
(
4341
self.lattice.top()

test/pyqrack/test_target.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,77 @@ def ghz():
3838
assert math.isclose(out[0].real, val, abs_tol=abs_tol)
3939
assert math.isclose(out[-1].real, val, abs_tol=abs_tol)
4040
assert all(math.isclose(ele.real, 0.0, abs_tol=abs_tol) for ele in out[1:-1])
41+
42+
43+
def test_target_glob():
44+
@qasm2.extended
45+
def global_h():
46+
q = qasm2.qreg(3)
47+
48+
# rotate around Y by pi/2, i.e. perform a hadamard
49+
qasm2.glob.u([q], math.pi / 2.0, 0, 0)
50+
51+
return q
52+
53+
target = PyQrack(3)
54+
q = target.run(global_h)
55+
56+
assert isinstance(q, reg.PyQrackReg)
57+
58+
out = q.sim_reg.out_ket()
59+
60+
# remove global phase introduced by pyqrack
61+
phase = out[0] / abs(out[0])
62+
out = [ele / phase for ele in out]
63+
64+
for element in out:
65+
assert math.isclose(element.real, 1 / math.sqrt(8), abs_tol=2.2e-7)
66+
assert math.isclose(element.imag, 0, abs_tol=2.2e-7)
67+
68+
@qasm2.extended
69+
def multiple_registers():
70+
q1 = qasm2.qreg(2)
71+
q2 = qasm2.qreg(2)
72+
q3 = qasm2.qreg(2)
73+
74+
# hadamard on first register
75+
qasm2.glob.u(
76+
[q1],
77+
math.pi / 2.0,
78+
0,
79+
0,
80+
)
81+
82+
# apply hadamard to the other two
83+
qasm2.glob.u(
84+
[q2, q3],
85+
math.pi / 2.0,
86+
0,
87+
0,
88+
)
89+
90+
# rotate all of them back down
91+
qasm2.glob.u(
92+
[q1, q2, q3],
93+
-math.pi / 2.0,
94+
0,
95+
0,
96+
)
97+
98+
return q1
99+
100+
target = PyQrack(6)
101+
q1 = target.run(multiple_registers)
102+
103+
assert isinstance(q1, reg.PyQrackReg)
104+
105+
out = q1.sim_reg.out_ket()
106+
107+
assert out[0] == 1
108+
for i in range(1, len(out)):
109+
assert out[i] == 0
110+
111+
assert True
112+
113+
114+
test_target_glob()

test/qasm2/passes/test_heuristic_noise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_global_noise():
265265
def test_method():
266266
q0 = qasm2.qreg(1)
267267
q1 = qasm2.qreg(1)
268-
glob.UGate([q0, q1], 0.1, 0.2, 0.3)
268+
qasm2.glob.u([q0, q1], 0.1, 0.2, 0.3)
269269

270270
px = 0.01
271271
py = 0.01

0 commit comments

Comments
 (0)