Skip to content

Commit 517722f

Browse files
authored
Kirin upgrade: fix naming and argument conventions in native.gate dialect. (#590)
As we're introducing a large breaking change it is probably a good time to update the statements in `native` to be more consistent with `squin`'s naming/call conventions. blocked by CI at the moment.
1 parent db67b8b commit 517722f

File tree

10 files changed

+40
-41
lines changed

10 files changed

+40
-41
lines changed

src/bloqade/native/_prelude.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
from bloqade import qubit
99

10-
from .dialects import gates
10+
from .dialects import gate
1111

1212

13-
@ir.dialect_group(structural_no_opt.union([gates, qubit]))
13+
@ir.dialect_group(structural_no_opt.union([gate, qubit]))
1414
def kernel(self):
1515
"""Compile a function to a native kernel."""
1616

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import stmts as stmts
2+
from ._dialect import dialect as dialect
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("native.gate")

src/bloqade/native/dialects/gates/_interface.py renamed to src/bloqade/native/dialects/gate/_interface.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212

1313
@lowering.wraps(CZ)
1414
def cz(
15-
ctrls: ilist.IList[qubit.Qubit, Len],
16-
qargs: ilist.IList[qubit.Qubit, Len],
15+
controls: ilist.IList[qubit.Qubit, Len],
16+
targets: ilist.IList[qubit.Qubit, Len],
1717
): ...
1818

1919

2020
@lowering.wraps(R)
2121
def r(
22-
inputs: ilist.IList[qubit.Qubit, typing.Any],
2322
axis_angle: float,
2423
rotation_angle: float,
24+
qubits: ilist.IList[qubit.Qubit, typing.Any],
2525
): ...
2626

2727

2828
@lowering.wraps(Rz)
2929
def rz(
30-
inputs: ilist.IList[qubit.Qubit, typing.Any],
3130
rotation_angle: float,
31+
qubits: ilist.IList[qubit.Qubit, typing.Any],
3232
): ...

src/bloqade/native/dialects/gates/stmts.py renamed to src/bloqade/native/dialects/gate/stmts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,20 @@
1212
@statement(dialect=dialect)
1313
class CZ(ir.Statement):
1414
traits = frozenset({lowering.FromPythonCall()})
15-
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
16-
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
15+
controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
16+
targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
1717

1818

1919
@statement(dialect=dialect)
2020
class R(ir.Statement):
2121
traits = frozenset({lowering.FromPythonCall()})
22-
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
2322
axis_angle: ir.SSAValue = info.argument(types.Float)
2423
rotation_angle: ir.SSAValue = info.argument(types.Float)
24+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
2525

2626

2727
@statement(dialect=dialect)
2828
class Rz(ir.Statement):
2929
traits = frozenset({lowering.FromPythonCall()})
30-
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
3130
rotation_angle: ir.SSAValue = info.argument(types.Float)
31+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])

src/bloqade/native/dialects/gates/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/bloqade/native/dialects/gates/_dialect.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/bloqade/native/stdlib/broadcast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from bloqade import qubit
77
from bloqade.native._prelude import kernel
8-
from bloqade.native.dialects.gates import _interface as native
8+
from bloqade.native.dialects.gate import _interface as native
99

1010

1111
@kernel
@@ -29,7 +29,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
2929
angle (float): Rotation angle in radians.
3030
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
3131
"""
32-
native.r(qubits, 0.0, _radian_to_turn(angle))
32+
native.r(0.0, _radian_to_turn(angle), qubits)
3333

3434

3535
@kernel
@@ -70,7 +70,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
7070
angle (float): Rotation angle in radians.
7171
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
7272
"""
73-
native.r(qubits, 0.25, _radian_to_turn(angle))
73+
native.r(0.25, _radian_to_turn(angle), qubits)
7474

7575

7676
@kernel
@@ -111,7 +111,7 @@ def rz(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
111111
angle (float): Rotation angle in radians.
112112
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
113113
"""
114-
native.rz(qubits, _radian_to_turn(angle))
114+
native.rz(_radian_to_turn(angle), qubits)
115115

116116

117117
@kernel

src/bloqade/pyqrack/native.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,42 @@
77
from pyqrack import Pauli
88
from bloqade.pyqrack import PyQrackQubit
99
from bloqade.pyqrack.base import PyQrackInterpreter
10-
from bloqade.native.dialects import gates
10+
from bloqade.native.dialects.gate import stmts
1111

1212

13-
@gates.dialect.register(key="pyqrack")
13+
@stmts.dialect.register(key="pyqrack")
1414
class NativeMethods(interp.MethodTable):
1515

16-
@interp.impl(gates.CZ)
17-
def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.CZ):
18-
ctrls = frame.get_casted(stmt.ctrls, ilist.IList[PyQrackQubit, Any])
19-
qargs = frame.get_casted(stmt.qargs, ilist.IList[PyQrackQubit, Any])
16+
@interp.impl(stmts.CZ)
17+
def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.CZ):
18+
controls = frame.get_casted(stmt.controls, ilist.IList[PyQrackQubit, Any])
19+
targets = frame.get_casted(stmt.targets, ilist.IList[PyQrackQubit, Any])
2020

21-
for ctrl, qarg in zip(ctrls, qargs):
22-
if ctrl.is_active() and qarg.is_active():
23-
ctrl.sim_reg.mcz([ctrl.addr], qarg.addr)
21+
for ctrl, trgt in zip(controls, targets):
22+
if ctrl.is_active() and trgt.is_active():
23+
ctrl.sim_reg.mcz([ctrl.addr], trgt.addr)
2424

2525
return ()
2626

27-
@interp.impl(gates.R)
28-
def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.R):
29-
inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any])
27+
@interp.impl(stmts.R)
28+
def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.R):
29+
qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
3030
rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
3131
axis_angle = 2 * math.pi * frame.get_casted(stmt.axis_angle, float)
32-
for qubit in inputs:
32+
for qubit in qubits:
3333
if qubit.is_active():
3434
qubit.sim_reg.r(Pauli.PauliZ, axis_angle, qubit.addr)
3535
qubit.sim_reg.r(Pauli.PauliX, rotation_angle, qubit.addr)
3636
qubit.sim_reg.r(Pauli.PauliZ, -axis_angle, qubit.addr)
3737

3838
return ()
3939

40-
@interp.impl(gates.Rz)
41-
def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.Rz):
42-
inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any])
40+
@interp.impl(stmts.Rz)
41+
def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.Rz):
42+
qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
4343
rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
4444

45-
for qubit in inputs:
45+
for qubit in qubits:
4646
if qubit.is_active():
4747
qubit.sim_reg.r(Pauli.PauliZ, rotation_angle, qubit.addr)
4848

test/native/upstream/test_squin2native.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from bloqade import squin
66
from bloqade.squin import gate
77
from bloqade.pyqrack import StackMemorySimulator
8-
from bloqade.native.dialects import gates
8+
from bloqade.native.dialects import gate as native_gate
99
from bloqade.native.upstream import GateRule, SquinToNative
1010

1111

@@ -33,14 +33,14 @@ def main():
3333
new_main = SquinToNative().emit(main, no_raise=True)
3434

3535
new_callgraph = callgraph.CallGraph(new_main)
36-
# make sure all kernels have been converted to native gates
36+
# make sure all kernels have been converted to native gate
3737
all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers)
3838
for ker in all_kernels:
3939
assert gate.dialect not in ker.dialects
40-
assert gates.dialect in ker.dialects
40+
assert native_gate.dialect in ker.dialects
4141

4242
# test to make sure the statevectors are the same
43-
# before and after conversion to native gates
43+
# before and after conversion to native gate
4444
old_sv = np.asarray(StackMemorySimulator(min_qubits=n).state_vector(main))
4545
old_sv /= old_sv[imax := np.abs(old_sv).argmax()] / np.abs(old_sv[imax])
4646

0 commit comments

Comments
 (0)