Skip to content

Commit 26fe16e

Browse files
authored
MergeCircuit multiple barriers (#166)
Improves the MergeCircuitPass by enabling the pass to search past multiple barriers for a call_circuit to merge with in the barrier qubits does not overlap with the first call_circuit.
1 parent e7cefb0 commit 26fe16e

File tree

2 files changed

+60
-18
lines changed

2 files changed

+60
-18
lines changed

lib/Dialect/QUIR/Transforms/MergeCircuits.cpp

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,19 @@
2020
//===----------------------------------------------------------------------===//
2121

2222
#include "Dialect/QUIR/Transforms/MergeCircuits.h"
23-
24-
#include "Dialect/OQ3/IR/OQ3Ops.h"
2523
#include "Dialect/QUIR/IR/QUIROps.h"
2624
#include "Dialect/QUIR/Utils/Utils.h"
2725

2826
#include "mlir/IR/BlockAndValueMapping.h"
29-
#include "mlir/IR/BuiltinOps.h"
3027
#include "mlir/IR/Operation.h"
3128
#include "mlir/Support/LogicalResult.h"
32-
#include "mlir/Transforms/DialectConversion.h"
3329
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3430

31+
#include "llvm/ADT/None.h"
3532
#include "llvm/ADT/SmallVector.h"
3633
#include "llvm/ADT/StringRef.h"
3734

38-
#include <llvm/ADT/None.h>
39-
#include <mlir/Dialect/Complex/IR/Complex.h>
35+
#include <algorithm>
4036
#include <vector>
4137

4238
using namespace mlir;
@@ -61,14 +57,33 @@ struct CircuitAndCircuitPattern : public OpRewritePattern<CallCircuitOp> {
6157
LogicalResult matchAndRewrite(CallCircuitOp callCircuitOp,
6258
PatternRewriter &rewriter) const override {
6359

64-
// get next quantum op and check if its a CallCircuitOp
65-
llvm::Optional<Operation *> secondOp = nextQuantumOpOrNull(callCircuitOp);
66-
if (!secondOp)
67-
return failure();
60+
// find next CallCircuitOp or fail
61+
Operation *searchOp = callCircuitOp.getOperation();
62+
llvm::Optional<Operation *> secondOp;
63+
CallCircuitOp nextCallCircuitOp;
64+
while (true) {
65+
secondOp = nextQuantumOpOrNull(searchOp);
66+
if (!secondOp)
67+
return failure();
68+
69+
nextCallCircuitOp = dyn_cast<CallCircuitOp>(*secondOp);
70+
if (nextCallCircuitOp)
71+
break;
72+
73+
// check for overlapping BarrierOp and fail if found
74+
auto barrierOp = dyn_cast<BarrierOp>(*secondOp);
75+
if (barrierOp) {
76+
std::set<uint> firstQubits =
77+
QubitOpInterface::getOperatedQubits(callCircuitOp);
78+
std::set<uint> secondQubits =
79+
QubitOpInterface::getOperatedQubits(barrierOp);
80+
81+
if (QubitOpInterface::qubitSetsOverlap(firstQubits, secondQubits))
82+
return failure();
83+
}
6884

69-
auto nextCallCircuitOp = dyn_cast<CallCircuitOp>(*secondOp);
70-
if (!nextCallCircuitOp)
71-
return failure();
85+
searchOp = *secondOp;
86+
}
7287

7388
Operation *insertOp = *secondOp;
7489

@@ -126,7 +141,6 @@ struct CircuitAndCircuitPattern : public OpRewritePattern<CallCircuitOp> {
126141

127142
return MergeCircuitsPass::mergeCallCircuits(rewriter, callCircuitOp,
128143
nextCallCircuitOp);
129-
130144
} // matchAndRewrite
131145
}; // struct CircuitAndCircuitPattern
132146

test/Dialect/QUIR/Transforms/merge-circuits.mlir

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ module {
4747
%1 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
4848
quir.return %0, %1: i1, i1
4949
}
50+
quir.circuit @circuit_7(%arg0: !quir.qubit<1>) -> (i1, i1) {
51+
%0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
52+
%1 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
53+
quir.return %0, %1: i1, i1
54+
}
55+
quir.circuit @circuit_8(%arg0: !quir.qubit<1>) -> (i1, i1) {
56+
%0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
57+
%1 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
58+
quir.return %0, %1: i1, i1
59+
}
5060
// CHECK: @circuit_0_q0_circuit_1_q1(%arg0: !quir.qubit<1>
5161
// CHECK: %0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
5262
// CHECK: %1 = quir.measure(%arg1) : (!quir.qubit<1>) -> i1
@@ -55,14 +65,16 @@ module {
5565
func @main() -> i32 {
5666
%0 = quir.declare_qubit {id = 0 : i32} : !quir.qubit<1>
5767
%1 = quir.declare_qubit {id = 1 : i32} : !quir.qubit<1>
58-
%200 = quir.declare_qubit {id = 1 : i32} : !quir.qubit<1>
68+
%200 = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1>
69+
%201 = quir.declare_qubit {id = 3 : i32} : !quir.qubit<1>
70+
%202 = quir.declare_qubit {id = 4 : i32} : !quir.qubit<1>
5971
%2 = quir.call_circuit @circuit_0(%0) : (!quir.qubit<1>) -> i1
6072
%3 = quir.call_circuit @circuit_1(%1) : (!quir.qubit<1>) -> i1
6173
// CHECK-NOT: {{.*}} = quir.call_circuit @circuit_0(%0) : (!quir.qubit<1>) -> i1
6274
// CHECK-NOT: {{.*}} = quir.call_circuit @circuit_1(%1) : (!quir.qubit<1>) -> i1
6375
// CHECK: {{.*}}:2 = quir.call_circuit @circuit_0_q0_circuit_1_q1(%0, %1) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
6476

65-
quir.barrier %0, %1, %200 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
77+
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
6678
%4 = quir.call_circuit @circuit_2(%0) : (!quir.qubit<1>) -> i1
6779
oq3.cbit_assign_bit @qc0_meas<2> [0] : i1 = %4
6880
%5 = quir.call_circuit @circuit_3(%1) : (!quir.qubit<1>) -> i1
@@ -75,12 +87,14 @@ module {
7587
// CHECK-NOT: oq3.cbit_assign_bit @qc0_meas<2> [0] : i1 = %[[MEAS2]]:0
7688
// CHECK-NOT: oq3.cbit_assign_bit @qc0_meas<2> [1] : i1 = %[[MEAS2]]:1
7789

78-
quir.barrier %0, %1, %200 : (!quir.qubit<1>,!quir.qubit<1>, !quir.qubit<1>) -> ()
90+
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
7991
%6 = quir.call_circuit @circuit_4(%0) : (!quir.qubit<1>) -> i1
8092
%7 = oq3.variable_load @p1 : f64
8193
%8 = "oq3.cast"(%7) : (f64) -> !quir.angle<64>
94+
quir.barrier %200 : (!quir.qubit<1>) -> ()
8295
%9 = oq3.variable_load @p2 : f64
8396
%10 = "oq3.cast"(%9) : (f64) -> !quir.angle<64>
97+
quir.barrier %201 : (!quir.qubit<1>) -> ()
8498
%11 = quir.call_circuit @circuit_5(%1, %8) : (!quir.qubit<1>, !quir.angle<64>) -> i1
8599
// CHECK-NOT: %6 = quir.call_circuit @circuit_4(%0) : (!quir.qubit<1>) -> i1
86100
// CHECK-NOT: %7 = oq3.variable_load @p1 : f64
@@ -92,12 +106,26 @@ module {
92106
// CHECK: {{.*}} = oq3.variable_load @p2 : f64
93107
// CHECK: %{{.*}}:2 = quir.call_circuit @circuit_4_q0_circuit_5_q1(%0, %1, %[[CAST]]) : (!quir.qubit<1>, !quir.qubit<1>, !quir.angle<64>) -> (i1, i1)
94108

95-
quir.barrier %0, %1, %200 : (!quir.qubit<1>,!quir.qubit<1>, !quir.qubit<1>) -> ()
109+
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
96110
%12:2 = quir.call_circuit @circuit_6(%0) : (!quir.qubit<1>) -> (i1, i1)
97111
quir.barrier %200 : (!quir.qubit<1>) -> ()
98112
%13:2 = quir.call_circuit @circuit_6(%0) : (!quir.qubit<1>) -> (i1, i1)
99113
// CHECK: %{{.*}}:4 = quir.call_circuit @circuit_6_q0_circuit_6_q0(%0, %0) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1, i1, i1)
100114

115+
116+
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
117+
%14:2 = quir.call_circuit @circuit_7(%0) : (!quir.qubit<1>) -> (i1, i1)
118+
quir.barrier %200, %201 : (!quir.qubit<1>, !quir.qubit<1>) -> ()
119+
quir.barrier %200, %202 : (!quir.qubit<1>, !quir.qubit<1>) -> ()
120+
%15:2 = quir.call_circuit @circuit_7(%0) : (!quir.qubit<1>) -> (i1, i1)
121+
// CHECK: %{{.*}}:4 = quir.call_circuit @circuit_7_q0_circuit_7_q0(%0, %0) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1, i1, i1)
122+
123+
quir.barrier %0, %1, %200, %201, %202 : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> ()
124+
%16:2 = quir.call_circuit @circuit_8(%0) : (!quir.qubit<1>) -> (i1, i1)
125+
quir.barrier %0 : (!quir.qubit<1>) -> ()
126+
%17:2 = quir.call_circuit @circuit_8(%0) : (!quir.qubit<1>) -> (i1, i1)
127+
// CHECK-NOT: %{{.*}}:4 = quir.call_circuit @circuit_8_q0_circuit_8_q0(%0, %0) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1, i1, i1)
128+
101129
%c0_i32 = arith.constant 0 : i32
102130
return %c0_i32 : i32
103131
}

0 commit comments

Comments
 (0)