Skip to content

Commit e558aa7

Browse files
authored
Add adjoint to assembly format (#1695)
**Context:** MLIR can represent properties in an attribute dictionary when they are not yet specified in the assembly format. We can definitely keep using that. But we could also move them to the assembly format. **Description of the Change:** Just a small proposal of changing the syntax here for having adjoint. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** Fixes #1692
1 parent 22e3629 commit e558aa7

File tree

6 files changed

+51
-44
lines changed

6 files changed

+51
-44
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@
8080
* The utility function `EnsureFunctionDeclaration` is refactored into the `Utils` of the `Catalyst` dialect, instead of being duplicated in each individual dialect.
8181
[(#1683)](https://github.com/PennyLaneAI/catalyst/pull/1683)
8282

83+
* The assembly format for some MLIR operations now includes adjoint.
84+
[(#1695)](https://github.com/PennyLaneAI/catalyst/pull/1695)
85+
8386
* Improved the definition of `YieldOp` in the quantum dialect by removing `AnyTypeOf`
8487
[(#1696)](https://github.com/PennyLaneAI/catalyst/pull/1696)
8588

mlir/include/Quantum/IR/QuantumOps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,14 @@ class UnitaryGate_Op<string mnemonic, list<Trait> traits = []> :
250250
}
251251

252252
bool getAdjointFlag() {
253-
return getAdjoint().has_value() ? getAdjoint().value() : false;
253+
return getAdjoint().value_or(false);
254254
}
255255
void setAdjointFlag(bool adjoint) {
256-
setAdjoint(adjoint);
256+
if (adjoint) {
257+
(*this)->setAttr("adjoint", mlir::UnitAttr::get(this->getContext()));
258+
} else {
259+
(*this)->removeAttr("adjoint");
260+
}
257261
};
258262

259263
mlir::ValueRange getCtrlValueOperands() {
@@ -446,7 +450,7 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect,
446450
];
447451

448452
let assemblyFormat = [{
449-
$gate_name `(` $params `)` $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
453+
$gate_name `(` $params `)` $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
450454
}];
451455

452456
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -541,7 +545,7 @@ def MultiRZOp : UnitaryGate_Op<"multirz", [DifferentiableGate, NoMemoryEffect,
541545
);
542546

543547
let assemblyFormat = [{
544-
`(` $theta `)` $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
548+
`(` $theta `)` $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
545549
}];
546550

547551
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -578,7 +582,7 @@ def QubitUnitaryOp : UnitaryGate_Op<"unitary", [ParametrizedGate, NoMemoryEffect
578582
);
579583

580584
let assemblyFormat = [{
581-
`(` $matrix `:` type($matrix) `)` $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
585+
`(` $matrix `:` type($matrix) `)` $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
582586
}];
583587

584588
let extraClassDeclaration = extraBaseClassDeclaration # [{

mlir/test/Mitigation/ZneFoldingAllFullTest.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,37 @@
2323
// CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit
2424
// CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) {
2525
// CHECK: [[q0_loop:%.+]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit
26-
// CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] {adjoint} : !quantum.bit
26+
// CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] adj : !quantum.bit
2727
// CHECK: scf.yield [[q0_loop2]] : !quantum.bit
2828
// CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit
2929
// CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit
3030
// CHECK: [[q01_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in1:%.+]] = [[q0_out2]], [[q01_in2:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) {
3131
// CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in1]], [[q01_in2]] : !quantum.bit, !quantum.bit
32-
// CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit
32+
// CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 adj : !quantum.bit, !quantum.bit
3333
// CHECK: scf.yield [[q01_loop2]]#0, [[q01_loop2]]#1 : !quantum.bit, !quantum.bit
3434
// CHECK: [[q01_out2:%.+]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit
3535
// CHECK: [[q2:%.+]] = quantum.extract [[qReg]][ 2] : !quantum.reg -> !quantum.bit
3636
// CHECK: [[q12_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q12_in1:%.+]] = [[q01_out2]]#1, [[q12_in2:%.+]] = [[q2]]) -> (!quantum.bit, !quantum.bit) {
3737
// CHECK: [[q12_loop:%.+]]:2 = quantum.custom "CNOT"() [[q12_in1]], [[q12_in2]] : !quantum.bit, !quantum.bit
38-
// CHECK: [[q12_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q12_loop]]#0, [[q12_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit
38+
// CHECK: [[q12_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q12_loop]]#0, [[q12_loop]]#1 adj : !quantum.bit, !quantum.bit
3939
// CHECK: scf.yield [[q12_loop2]]#0, [[q12_loop2]]#1 : !quantum.bit, !quantum.bit
4040
// CHECK: [[q12_out2:%.+]]:2 = quantum.custom "CNOT"() [[q12_out]]#0, [[q12_out]]#1 : !quantum.bit, !quantum.bit
4141
// CHECK: [[q1_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q1_in:%.+]] = [[q12_out2]]#0) -> (!quantum.bit) {
4242
// CHECK: [[q1_loop:%.+]] = quantum.custom "T"() [[q1_in]] : !quantum.bit
43-
// CHECK: [[q1_loop2:%.+]] = quantum.custom "T"() [[q1_loop]] {adjoint} : !quantum.bit
43+
// CHECK: [[q1_loop2:%.+]] = quantum.custom "T"() [[q1_loop]] adj : !quantum.bit
4444
// CHECK: scf.yield [[q1_loop2]] : !quantum.bit
4545
// CHECK: [[q1_out2:%.+]] = quantum.custom "T"() [[q1_out]] : !quantum.bit
4646
// CHECK: [[q3:%.+]] = quantum.extract [[qReg]][ 3] : !quantum.reg -> !quantum.bit
4747
// CHECK: [[q23_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q23_in1:%.+]] = [[q12_out2]]#1, [[q23_in2:%.+]] = [[q3]]) -> (!quantum.bit, !quantum.bit) {
4848
// CHECK: [[q23_loop:%.+]]:2 = quantum.custom "CNOT"() [[q23_in1]], [[q23_in2]] : !quantum.bit, !quantum.bit
49-
// CHECK: [[q23_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q23_loop]]#0, [[q23_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit
49+
// CHECK: [[q23_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q23_loop]]#0, [[q23_loop]]#1 adj : !quantum.bit, !quantum.bit
5050
// CHECK: scf.yield [[q23_loop2]]#0, [[q23_loop2]]#1 : !quantum.bit, !quantum.bit
5151
// CHECK: [[q23_out2:%.+]]:2 = quantum.custom "CNOT"() [[q23_out]]#0, [[q23_out]]#1 : !quantum.bit, !quantum.bit
5252
// CHECK: [[q3_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q3_in:%.+]] = [[q23_out2]]#1) -> (!quantum.bit) {
53-
// CHECK: [[q3_loop:%.+]] = quantum.custom "T"() [[q3_in]] {adjoint} : !quantum.bit
53+
// CHECK: [[q3_loop:%.+]] = quantum.custom "T"() [[q3_in]] adj : !quantum.bit
5454
// CHECK: [[q3_loop2:%.+]] = quantum.custom "T"() [[q3_loop]] : !quantum.bit
5555
// CHECK: scf.yield [[q3_loop2]] : !quantum.bit
56-
// CHECK: [[q3_out2:%.+]] = quantum.custom "T"() [[q3_out]] {adjoint} : !quantum.bit
56+
// CHECK: [[q3_out2:%.+]] = quantum.custom "T"() [[q3_out]] adj : !quantum.bit
5757

5858

5959
//CHECK-LABEL: func.func @circuit() -> tensor<f64> attributes {qnode} {
@@ -70,7 +70,7 @@ func.func @circuit() -> tensor<f64> attributes {qnode} {
7070
%out_qubits_2 = quantum.custom "T"() %out_qubits_1#0 : !quantum.bit
7171
%4 = quantum.extract %0[ 3] : !quantum.reg -> !quantum.bit
7272
%out_qubits_3:2 = quantum.custom "CNOT"() %out_qubits_1#1, %4 : !quantum.bit, !quantum.bit
73-
%out_qubits_4 = quantum.custom "T"() %out_qubits_3#1 {adjoint} : !quantum.bit
73+
%out_qubits_4 = quantum.custom "T"() %out_qubits_3#1 adj : !quantum.bit
7474
%5 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs
7575
%6 = quantum.expval %5 {shots = 5 : i64} : f64
7676
%from_elements = tensor.from_elements %6 : tensor<f64>

mlir/test/Mitigation/ZneFoldingAllMinimalTest.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
// CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit
2424
// CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) {
2525
// CHECK: [[q0_loop:%.+]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit
26-
// CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] {adjoint} : !quantum.bit
26+
// CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] adj : !quantum.bit
2727
// CHECK: scf.yield [[q0_loop2]] : !quantum.bit
2828
// CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit
2929
// CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit
3030
// CHECK: [[q01_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in1:%.+]] = [[q0_out2]], [[q01_in2:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) {
3131
// CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in1]], [[q01_in2]] : !quantum.bit, !quantum.bit
32-
// CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit
32+
// CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 adj : !quantum.bit, !quantum.bit
3333
// CHECK: scf.yield [[q01_loop2]]#0, [[q01_loop2]]#1 : !quantum.bit, !quantum.bit
3434
// CHECK: [[q01_out2:%.+]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit
3535
// CHECK: [[q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs

mlir/test/Quantum/AdjointTest.mlir

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ func.func private @workflow_plain() -> tensor<4xcomplex<f64>> attributes {} {
2727
%3 = quantum.insert %0[%c0_i64], %2 : !quantum.reg, !quantum.bit
2828
%4 = quantum.adjoint(%3) : !quantum.reg {
2929
// CHECK: PauliZ
30-
// CHECK-SAME: adjoint
30+
// CHECK-SAME: adj
3131
// CHECK: PauliY
32-
// CHECK-SAME: adjoint
32+
// CHECK-SAME: adj
3333
// CHECK: PauliX
34-
// CHECK-SAME: adjoint
34+
// CHECK-SAME: adj
3535
^bb0(%arg0: !quantum.reg):
3636
%10 = quantum.extract %arg0[%c0_i64] : !quantum.reg -> !quantum.bit
3737
%11 = quantum.custom "PauliX"() %10 : !quantum.bit
@@ -57,13 +57,13 @@ func.func private @workflow_plain() -> tensor<4xcomplex<f64>> attributes {} {
5757
// CHECK: OpC
5858
// CHECK: OpD
5959
// CHECK: OpF
60-
// CHECK-SAME: adjoint
60+
// CHECK-SAME: adj
6161
// CHECK: OpE
62-
// CHECK-SAME: adjoint
62+
// CHECK-SAME: adj
6363
// CHECK: OpB
64-
// CHECK-SAME: adjoint
64+
// CHECK-SAME: adj
6565
// CHECK: OpA
66-
// CHECK-SAME: adjoint
66+
// CHECK-SAME: adj
6767
func.func private @workflow_nested() -> tensor<4xcomplex<f64>> attributes {} {
6868
%c1_i64 = arith.constant 1 : i64
6969
%c0_i64 = arith.constant 0 : i64
@@ -182,17 +182,17 @@ func.func private @circuit(%arg0: f64, %arg1: !quantum.reg) -> !quantum.reg {
182182
}
183183

184184
// CHECK: func.func private @circuit.adjoint(%arg0: f64, %arg1: !quantum.reg) -> !quantum.reg {
185-
// CHECK: quantum.custom "PauliZ"() {{%.+}} {adjoint} : !quantum.bit
186-
// CHECK: quantum.custom "RX"({{%.+}}) {{%.+}} {adjoint} : !quantum.bit
187-
// CHECK: quantum.custom "PauliX"() {{%.+}} {adjoint} : !quantum.bit
185+
// CHECK: quantum.custom "PauliZ"() {{%.+}} adj : !quantum.bit
186+
// CHECK: quantum.custom "RX"({{%.+}}) {{%.+}} adj : !quantum.bit
187+
// CHECK: quantum.custom "PauliX"() {{%.+}} adj : !quantum.bit
188188

189189
// CHECK: func.func private @workflow_adjoint(%arg0: f64) -> tensor<4xcomplex<f64>> {
190190
// CHECK: quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
191-
// CHECK: quantum.custom "RY"({{%.+}}) {{%.+}} {adjoint} : !quantum.bit
191+
// CHECK: quantum.custom "RY"({{%.+}}) {{%.+}} adj : !quantum.bit
192192
// CHECK: call @circuit.adjoint(%arg0, {{%.+}}) : (f64, !quantum.reg) -> !quantum.reg
193-
// CHECK: quantum.custom "PauliZ"() {{%.+}} {adjoint} : !quantum.bit
194-
// CHECK: quantum.custom "RX"({{%.+}}) {{%.+}} {adjoint} : !quantum.bit
195-
// CHECK: quantum.custom "PauliX"() {{%.+}} {adjoint} : !quantum.bit
193+
// CHECK: quantum.custom "PauliZ"() {{%.+}} adj : !quantum.bit
194+
// CHECK: quantum.custom "RX"({{%.+}}) {{%.+}} adj : !quantum.bit
195+
// CHECK: quantum.custom "PauliX"() {{%.+}} adj : !quantum.bit
196196
// CHECK: quantum.custom "RY"({{%.+}}) {{%.+}} : !quantum.bit
197197

198198
func.func private @workflow_adjoint(%arg0: f64) -> tensor<4xcomplex<f64>> attributes {} {

0 commit comments

Comments
 (0)