Skip to content

Commit 9f45e44

Browse files
committed
merging main
2 parents 51f27b5 + 4dbe384 commit 9f45e44

File tree

8 files changed

+94
-50
lines changed

8 files changed

+94
-50
lines changed

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dependencies = [
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",
20+
"pyqrack>=1.38.2 ; sys_platform == 'darwin'",
21+
"pyqrack-cpu>=1.38.2 ; sys_platform != 'darwin'",
2022
]
2123

2224
[project.optional-dependencies]
@@ -36,6 +38,12 @@ cirq = [
3638
"cirq-core>=1.4.1",
3739
"cirq-core[contrib]>=1.4.1",
3840
]
41+
pyqrack-opencl = [
42+
"pyqrack>=1.38.2 ; sys_platform != 'darwin'",
43+
]
44+
pyqrack-cuda = [
45+
"pyqrack-cuda>=1.38.2",
46+
]
3947

4048
[build-system]
4149
requires = ["hatchling"]

src/bloqade/noise/native/stmts.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
from kirin import ir, types, lowering
24
from kirin.decl import info, statement
35
from kirin.dialects import ilist
@@ -7,25 +9,44 @@
79
from ._dialect import dialect
810

911

10-
@statement(dialect=dialect)
11-
class PauliChannel(ir.Statement):
12-
12+
@statement
13+
class NativeNoiseStmt(ir.Statement):
1314
traits = frozenset({lowering.FromPythonCall()})
1415

16+
@property
17+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
18+
raise NotImplementedError(f"Override the method in {type(self).__name__}")
19+
20+
def check(self):
21+
for probs in self.probabilities:
22+
self.check_probability(sum(probs))
23+
for p in probs:
24+
self.check_probability(p)
25+
26+
def check_probability(self, p: float):
27+
if not 0 <= p <= 1:
28+
raise ValueError(
29+
f"Invalid noise probability encountered in {type(self).__name__}: {p}"
30+
)
31+
32+
33+
@statement(dialect=dialect)
34+
class PauliChannel(NativeNoiseStmt):
1535
px: float = info.attribute(types.Float)
1636
py: float = info.attribute(types.Float)
1737
pz: float = info.attribute(types.Float)
1838
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
1939

40+
@property
41+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
42+
return ((self.px, self.py, self.pz),)
43+
2044

2145
NumQubits = types.TypeVar("NumQubits")
2246

2347

2448
@statement(dialect=dialect)
25-
class CZPauliChannel(ir.Statement):
26-
27-
traits = frozenset({lowering.FromPythonCall()})
28-
49+
class CZPauliChannel(NativeNoiseStmt):
2950
paired: bool = info.attribute(types.Bool)
3051
px_ctrl: float = info.attribute(types.Float)
3152
py_ctrl: float = info.attribute(types.Float)
@@ -36,11 +57,19 @@ class CZPauliChannel(ir.Statement):
3657
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
3758
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
3859

60+
@property
61+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
62+
return (
63+
(self.px_ctrl, self.py_ctrl, self.pz_ctrl),
64+
(self.px_qarg, self.py_qarg, self.pz_qarg),
65+
)
3966

40-
@statement(dialect=dialect)
41-
class AtomLossChannel(ir.Statement):
42-
43-
traits = frozenset({lowering.FromPythonCall()})
4467

68+
@statement(dialect=dialect)
69+
class AtomLossChannel(NativeNoiseStmt):
4570
prob: float = info.attribute(types.Float)
4671
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
72+
73+
@property
74+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
75+
return ((self.prob,),)

src/bloqade/qasm2/dialects/expr/_emit.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Literal
22

33
from kirin import interp
4-
from kirin.emit.exceptions import EmitError
54

65
from bloqade.qasm2.parse import ast
76
from bloqade.qasm2.types import QubitType
@@ -19,17 +18,18 @@ def emit_func(
1918
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
2019
):
2120

22-
args, cparams, qparams = [], [], []
23-
for arg in stmt.body.blocks[0].args[1:]:
24-
name = frame.get_typed(arg, ast.Name)
25-
args.append(name)
26-
if not isinstance(name, ast.Name):
27-
raise EmitError("expected ast.Name")
21+
args: list[ast.Node] = []
22+
cparams, qparams = [], []
23+
for arg in stmt.body.blocks[0].args:
24+
assert arg.name is not None
25+
26+
args.append(ast.Name(id=arg.name))
2827
if arg.type.is_subseteq(QubitType):
29-
qparams.append(name.id)
28+
qparams.append(arg.name)
3029
else:
31-
cparams.append(name.id)
32-
emit.run_ssacfg_region(frame, stmt.body, args)
30+
cparams.append(arg.name)
31+
32+
emit.run_ssacfg_region(frame, stmt.body, tuple(args))
3333
emit.output = ast.Gate(
3434
name=stmt.sym_name,
3535
cparams=cparams,

src/bloqade/qasm2/emit/gate.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,3 @@ def emit_err(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
8686
@interp.impl(func.ConstantNone)
8787
def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
8888
return ()
89-
90-
@interp.impl(func.Function)
91-
def emit_func(
92-
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Function
93-
):
94-
args_ssa = stmt.args
95-
print(stmt.args)
96-
emit.run_ssacfg_region(frame, stmt.body, frame.get_values(args_ssa))
97-
98-
cparams, qparams = [], []
99-
for arg in args_ssa:
100-
if arg.type.is_subseteq(QubitType):
101-
qparams.append(frame.get(arg))
102-
else:
103-
cparams.append(frame.get(arg))
104-
105-
emit.output = ast.Gate(stmt.sym_name, cparams, qparams, frame.body)
106-
return ()

src/bloqade/qasm2/emit/target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
101101

102102
Py2QASM(entry.dialects)(entry)
103103
target_main = EmitQASM2Main(self.main_target)
104-
target_main.run(entry, tuple(ast.Name(name) for name in entry.arg_names[1:]))
104+
target_main.run(entry, ())
105105

106106
main_program = target_main.output
107107
assert main_program is not None, f"failed to emit {entry.sym_name}"

src/bloqade/squin/qubit.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def measure(qubit: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ...
117117

118118
@wraps(MeasureAny)
119119
def measure(qubit: Any) -> Any:
120-
"""Measure a qubit or qubits in the list."
120+
"""Measure a qubit or qubits in the list.
121121
122122
Args:
123123
qubits: The list of qubits to measure.
@@ -128,6 +128,22 @@ def measure(qubit: Any) -> Any:
128128
...
129129

130130

131+
@wraps(Broadcast)
132+
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
133+
"""Broadcast and apply an operator to a list of qubits. For example, an operator
134+
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
135+
136+
Args:
137+
operator: The operator to broadcast and apply.
138+
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
139+
must be inferable and match the number of qubits expected by the operator.
140+
141+
Returns:
142+
None
143+
"""
144+
...
145+
146+
131147
@wraps(MeasureAndReset)
132148
def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> int:
133149
"""Measure the qubits in the list and reset them."

test/pyqrack/runtime/noise/native/test_pauli.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import Mock, call
22

3+
import pytest
34
from kirin import ir
45

56
from bloqade import qasm2
@@ -41,6 +42,23 @@ def test_atom_loss():
4142
sim_reg.assert_has_calls([call.y(0)])
4243

4344

45+
@pytest.mark.xfail
46+
def test_pauli_probs_check():
47+
@simulation
48+
def test_atom_loss():
49+
q = qasm2.qreg(2)
50+
native.pauli_channel(
51+
[q[0]],
52+
px=0.1,
53+
py=0.4,
54+
pz=1.3,
55+
)
56+
return q
57+
58+
with pytest.raises(ir.ValidationError):
59+
test_atom_loss.verify()
60+
61+
4462
def test_cz_pauli_channel_false():
4563
@simulation
4664
def test_atom_loss():
@@ -122,9 +140,3 @@ def test_atom_loss():
122140
sim_reg = run_mock(test_atom_loss, rng_state)
123141

124142
sim_reg.assert_has_calls([call.y(0), call.x(1), call.mcz([0], 1)])
125-
126-
127-
if __name__ == "__main__":
128-
test_pauli_channel()
129-
test_cz_pauli_channel_false()
130-
test_cz_pauli_channel_true()

test/qasm2/emit/test_qasm2.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import pytest
2-
31
from bloqade import qasm2
42

53

6-
@pytest.mark.skip(reason="broken gate emit!")
74
def test_qasm2_custom_gate():
85
@qasm2.gate
96
def custom_gate(a: qasm2.Qubit, b: qasm2.Qubit):

0 commit comments

Comments
 (0)