Skip to content

Commit bb80136

Browse files
authored
Add a graph_state_prep operation to the xDSL MBQC dialect (#8059)
**Context:** We recently added the `mbqc.graph_state_prep` operation to Catalyst in PennyLaneAI/catalyst#1965. We now need the equivalent operation defined in the MBQC dialect of the unified Python compiler. **Description of the Change:** Adds the `mbqc.graph_state_prep` operation to the xDSL MBQC dialect. It's implementation is equivalent to the one already defined in Catalyst. [sc-97308]
1 parent 3b457d1 commit bb80136

File tree

3 files changed

+88
-7
lines changed

3 files changed

+88
-7
lines changed

doc/releases/changelog-dev.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@
192192
* The `mbqc` xDSL dialect has been added to the Python compiler, which is used to represent
193193
measurement-based quantum-computing instructions in the xDSL framework.
194194
[(#7815)](https://github.com/PennyLaneAI/pennylane/pull/7815)
195+
[(#8059)](https://github.com/PennyLaneAI/pennylane/pull/8059)
195196

196197
* The `AllocQubitOp` and `DeallocQubitOp` operations have been added to the `Quantum` dialect in the
197198
Python compiler.

pennylane/compiler/python_compiler/dialects/mbqc.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,24 @@
1919
2020
It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the
2121
catalyst/mlir/include/MBQC/IR/MBQCDialect.td file in the catalyst repository.
22+
23+
For detailed documentation on the operations contained in this dialect, please refer to the MBQC
24+
dialect documentation in Catalyst.
2225
"""
2326

2427
from typing import TypeAlias
2528

26-
from xdsl.dialects.builtin import I32, Float64Type, IntegerAttr, IntegerType
29+
from xdsl.dialects.builtin import (
30+
I32,
31+
AnyAttr,
32+
Float64Type,
33+
IntegerAttr,
34+
IntegerType,
35+
MemRefType,
36+
StringAttr,
37+
TensorType,
38+
i1,
39+
)
2740
from xdsl.ir import Dialect, EnumAttribute, Operation, SpacedOpaqueSyntaxAttribute, SSAValue
2841
from xdsl.irdl import (
2942
IRDLOperation,
@@ -37,7 +50,8 @@
3750
from xdsl.utils.exceptions import VerifyException
3851
from xdsl.utils.str_enum import StrEnum # StrEnum is standard in Python>=3.11
3952

40-
from .quantum import QubitType
53+
from ..xdsl_extras import MemRefRankConstraint, TensorRankConstraint
54+
from .quantum import QubitType, QuregType
4155

4256
QubitSSAValue: TypeAlias = SSAValue[QubitType]
4357

@@ -69,6 +83,10 @@ class MeasureInBasisOp(IRDLOperation):
6983

7084
name = "mbqc.measure_in_basis"
7185

86+
assembly_format = """
87+
`[` $plane `,` $angle `]` $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results)
88+
"""
89+
7290
in_qubit = operand_def(QubitType)
7391

7492
plane = prop_def(MeasurementPlaneAttr)
@@ -81,10 +99,6 @@ class MeasureInBasisOp(IRDLOperation):
8199

82100
out_qubit = result_def(QubitType)
83101

84-
assembly_format = """
85-
`[` $plane `,` $angle `]` $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results)
86-
"""
87-
88102
def __init__(
89103
self,
90104
in_qubit: QubitSSAValue | Operation,
@@ -115,10 +129,54 @@ def verify_(self):
115129
raise VerifyException("'postselect' must be 0 or 1.")
116130

117131

132+
@irdl_op_definition
133+
class GraphStatePrepOp(IRDLOperation):
134+
"""Allocate resources for a new graph state."""
135+
136+
# pylint: disable=too-few-public-methods
137+
138+
name = "mbqc.graph_state_prep"
139+
140+
assembly_format = """
141+
`(` $adj_matrix `:` type($adj_matrix) `)` `[` `init` $init_op `,` `entangle` $entangle_op `]` attr-dict `:` type(results)
142+
"""
143+
144+
adj_matrix = operand_def(
145+
(TensorType.constr(i1) & TensorRankConstraint(1))
146+
| (MemRefType.constr(i1) & MemRefRankConstraint(1))
147+
)
148+
149+
init_op = prop_def(StringAttr)
150+
151+
entangle_op = prop_def(StringAttr)
152+
153+
qreg = result_def(QuregType)
154+
155+
def __init__(
156+
self, adj_matrix: AnyAttr, init_op: str | StringAttr, entangle_op: str | StringAttr
157+
):
158+
if isinstance(init_op, str):
159+
init_op = StringAttr(data=init_op)
160+
161+
if isinstance(entangle_op, str):
162+
entangle_op = StringAttr(data=entangle_op)
163+
164+
properties = {"init_op": init_op, "entangle_op": entangle_op}
165+
166+
qreg = QuregType()
167+
168+
super().__init__(
169+
operands=(adj_matrix,),
170+
result_types=(qreg,),
171+
properties=properties,
172+
)
173+
174+
118175
MBQC = Dialect(
119176
"mbqc",
120177
[
121178
MeasureInBasisOp,
179+
GraphStatePrepOp,
122180
],
123181
[
124182
MeasurementPlaneAttr,

tests/python_compiler/dialects/test_mbqc_dialect.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
pytestmark = pytest.mark.external
2525

26-
from xdsl.dialects import builtin, test
26+
from xdsl.dialects import arith, builtin, test
2727
from xdsl.utils.exceptions import VerifyException
2828

2929
from pennylane.compiler.python_compiler.dialects import Quantum, mbqc
@@ -33,6 +33,7 @@
3333

3434
expected_ops_names = {
3535
"MeasureInBasisOp": "mbqc.measure_in_basis",
36+
"GraphStatePrepOp": "mbqc.graph_state_prep",
3637
}
3738

3839
expected_attrs_names = {
@@ -90,6 +91,11 @@ def test_assembly_format(run_filecheck):
9091
// COM: Check generic format
9192
// CHECK: {{%.+}}, {{%.+}} = mbqc.measure_in_basis[XY, [[angle]]] [[qubit]] postselect 0 : i1, !quantum.bit
9293
%res:2 = "mbqc.measure_in_basis"(%qubit, %angle) <{plane = #mbqc<measurement_plane XY>, postselect = 0 : i32}> : (!quantum.bit, f64) -> (i1, !quantum.bit)
94+
95+
// CHECK: [[adj_matrix:%.+]] = arith.constant {{.*}} : tensor<6xi1>
96+
// CHECK: [[graph_reg:%.+]] = mbqc.graph_state_prep{{\s*}}([[adj_matrix]] : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg
97+
%adj_matrix = arith.constant dense<[1, 0, 1, 0, 0, 1]> : tensor<6xi1>
98+
%graph_reg = mbqc.graph_state_prep (%adj_matrix : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg
9399
"""
94100

95101
run_filecheck(program)
@@ -154,3 +160,19 @@ def test_invalid_postselect_raises_on_verify(self, postselect):
154160

155161
with pytest.raises(VerifyException, match="'postselect' must be 0 or 1"):
156162
measure_in_basis_op.verify_()
163+
164+
@pytest.mark.parametrize("init_op", ["Hadamard", builtin.StringAttr(data="Hadamard")])
165+
@pytest.mark.parametrize("entangle_op", ["CZ", builtin.StringAttr(data="CZ")])
166+
def test_graph_state_prep_instantiation(self, init_op, entangle_op):
167+
"""Test the instantiation of a mbqc.graph_state_prep op."""
168+
adj_matrix = [1, 0, 1, 0, 0, 1]
169+
adj_matrix_op = arith.ConstantOp(
170+
builtin.DenseIntOrFPElementsAttr.from_list(
171+
type=builtin.TensorType(builtin.IntegerType(1), shape=(6,)), data=adj_matrix
172+
)
173+
)
174+
graph_state_prep_op = mbqc.GraphStatePrepOp(adj_matrix_op.result, init_op, entangle_op)
175+
176+
assert graph_state_prep_op.adj_matrix == adj_matrix_op.result
177+
assert graph_state_prep_op.init_op == builtin.StringAttr(data="Hadamard")
178+
assert graph_state_prep_op.entangle_op == builtin.StringAttr(data="CZ")

0 commit comments

Comments
 (0)