Skip to content

Commit ed78572

Browse files
authored
Improve MergeCircuitsPass (#151)
Improves the MergeCircuitsPass so that two `CallCircuitOps` will be merged if there are non-quantum operations between them. The non-quantum operations are moved above or below the `CallCircuitOps` as required by the uses and the `CircuitOp` / `CallCircuitOp` are merged.
1 parent 12b9fdf commit ed78572

File tree

2 files changed

+79
-30
lines changed

2 files changed

+79
-30
lines changed

lib/Dialect/QUIR/Transforms/MergeCircuits.cpp

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,55 +42,53 @@ using namespace mlir::quir;
4242

4343
namespace {
4444

45-
// This pattern matches on two CallCircuitOp back to back
45+
// This pattern matches on two CallCircuitOps separated by non-quantum ops
4646
struct CircuitAndCircuitPattern : public OpRewritePattern<CallCircuitOp> {
4747
explicit CircuitAndCircuitPattern(MLIRContext *ctx)
4848
: OpRewritePattern<CallCircuitOp>(ctx) {}
4949

5050
LogicalResult matchAndRewrite(CallCircuitOp callCircuitOp,
5151
PatternRewriter &rewriter) const override {
5252

53-
Operation *nextOp = callCircuitOp->getNextNode();
54-
if (!nextOp)
53+
// get next quantum op and check if its a CallCircuitOp
54+
llvm::Optional<Operation *> secondOp = nextQuantumOpOrNull(callCircuitOp);
55+
if (!secondOp)
5556
return failure();
5657

57-
auto nextCallCircuitOp = dyn_cast<CallCircuitOp>(nextOp);
58+
auto nextCallCircuitOp = dyn_cast<CallCircuitOp>(*secondOp);
5859
if (!nextCallCircuitOp)
5960
return failure();
6061

61-
return MergeCircuitsPass::mergeCallCircuits(rewriter, callCircuitOp,
62-
nextCallCircuitOp);
63-
64-
} // matchAndRewrite
65-
}; // struct CircuitAndCircuitPattern
66-
67-
// This pattern matches on two CallCircuitOp with a CBitAssignBitOp in between
68-
struct CircuitAssignAndCircuitPattern : public OpRewritePattern<CallCircuitOp> {
69-
explicit CircuitAssignAndCircuitPattern(MLIRContext *ctx)
70-
: OpRewritePattern<CallCircuitOp>(ctx) {}
71-
72-
LogicalResult matchAndRewrite(CallCircuitOp callCircuitOp,
73-
PatternRewriter &rewriter) const override {
74-
75-
Operation *nextOp = callCircuitOp->getNextNode();
76-
if (!nextOp)
77-
return failure();
78-
79-
auto secondOp = dyn_cast<oq3::CBitAssignBitOp>(nextOp);
80-
if (!secondOp)
81-
return failure();
62+
// Move first CallCircuitOp after nodes until a user of the
63+
// CallCircuitOp or the second CallCircuitOp is reached
64+
Operation *curOp = callCircuitOp->getNextNode();
65+
while (curOp != *secondOp) {
66+
if (std::find(callCircuitOp->user_begin(), callCircuitOp->user_end(),
67+
curOp) != callCircuitOp->user_end())
68+
break;
69+
callCircuitOp->moveAfter(curOp);
70+
curOp = callCircuitOp->getNextNode();
71+
}
8272

83-
Operation *thirdOp = secondOp->getNextNode();
73+
// Move second CallCircuitOp before nodes until a definition the
74+
// second CallCircuitOp uses or the first CallCircuitOp is reached
75+
curOp = nextCallCircuitOp->getPrevNode();
76+
while (curOp != callCircuitOp) {
77+
if (std::find(curOp->user_begin(), curOp->user_end(),
78+
nextCallCircuitOp) != curOp->user_end())
79+
break;
80+
nextCallCircuitOp->moveBefore(curOp);
81+
curOp = nextCallCircuitOp->getPrevNode();
82+
}
8483

85-
auto nextCallCircuitOp = dyn_cast<CallCircuitOp>(thirdOp);
86-
if (!nextCallCircuitOp)
84+
if (callCircuitOp->getNextNode() != nextCallCircuitOp)
8785
return failure();
8886

8987
return MergeCircuitsPass::mergeCallCircuits(rewriter, callCircuitOp,
9088
nextCallCircuitOp);
9189

9290
} // matchAndRewrite
93-
}; // struct CircuitAssignAndCircuitPattern
91+
}; // struct CircuitAndCircuitPattern
9492

9593
} // end anonymous namespace
9694

@@ -254,7 +252,6 @@ void MergeCircuitsPass::runOnOperation() {
254252

255253
RewritePatternSet patterns(&getContext());
256254
patterns.add<CircuitAndCircuitPattern>(&getContext());
257-
patterns.add<CircuitAssignAndCircuitPattern>(&getContext());
258255

259256
if (failed(
260257
applyPatternsAndFoldGreedily(moduleOperation, std::move(patterns))))

test/Dialect/QUIR/Transforms/merge-circuits.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
// that they have been altered from the originals.
1515

1616
module {
17+
oq3.declare_variable @qc0_meas : !quir.cbit<2>
18+
oq3.declare_variable {input} @p1 : f64
19+
oq3.declare_variable {input} @p2 : f64
1720
quir.circuit @circuit_0(%arg0: !quir.qubit<1>) -> i1 {
1821
%0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
1922
quir.return %0: i1
@@ -22,6 +25,23 @@ module {
2225
%0 = quir.measure(%arg1) : (!quir.qubit<1>) -> i1
2326
quir.return %0: i1
2427
}
28+
quir.circuit @circuit_2(%arg0: !quir.qubit<1>) -> i1 {
29+
%0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
30+
quir.return %0: i1
31+
}
32+
quir.circuit @circuit_3(%arg1: !quir.qubit<1>) -> i1 {
33+
%0 = quir.measure(%arg1) : (!quir.qubit<1>) -> i1
34+
quir.return %0: i1
35+
}
36+
quir.circuit @circuit_4(%arg0: !quir.qubit<1>) -> i1 {
37+
%0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
38+
quir.return %0: i1
39+
}
40+
quir.circuit @circuit_5(%arg0: !quir.qubit<1>, %arg2: !quir.angle<64>) -> i1 {
41+
quir.call_gate @rz(%arg0, %arg2) : (!quir.qubit<1>, !quir.angle<64>) -> ()
42+
%0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
43+
quir.return %0 : i1
44+
}
2545
// CHECK: @circuit_0_q0_circuit_1_q1(%arg0: !quir.qubit<1>
2646
// CHECK: %0 = quir.measure(%arg0) : (!quir.qubit<1>) -> i1
2747
// CHECK: %1 = quir.measure(%arg1) : (!quir.qubit<1>) -> i1
@@ -35,6 +55,38 @@ module {
3555
// CHECK-NOT: {{.*}} = quir.call_circuit @circuit_0(%0) : (!quir.qubit<1>) -> i1
3656
// CHECK-NOT: {{.*}} = quir.call_circuit @circuit_1(%1) : (!quir.qubit<1>) -> i1
3757
// CHECK: {{.*}}:2 = quir.call_circuit @circuit_0_q0_circuit_1_q1(%0, %1) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
58+
59+
quir.barrier %0 : (!quir.qubit<1>) -> ()
60+
%4 = quir.call_circuit @circuit_2(%0) : (!quir.qubit<1>) -> i1
61+
oq3.cbit_assign_bit @qc0_meas<2> [0] : i1 = %4
62+
%5 = quir.call_circuit @circuit_3(%1) : (!quir.qubit<1>) -> i1
63+
oq3.cbit_assign_bit @qc0_meas<2> [1] : i1 = %5
64+
// CHECK-NOT: %4 = quir.call_circuit @circuit_2(%0) : (!quir.qubit<1>) -> i1
65+
// CHECK-NOT: oq3.cbit_assign_bit @qc0_meas<2> [0] : i1 = %4
66+
// CHECK-NOT: %5 = quir.call_circuit @circuit_3(%1) : (!quir.qubit<1>) -> i1
67+
// CHECK-NOT: oq3.cbit_assign_bit @qc0_meas<2> [1] : i1 = %5
68+
// CHECK: %[[MEAS2:.*]]:2 = quir.call_circuit @circuit_2_q0_circuit_3_q1(%0, %1) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
69+
// CHECK-NOT: oq3.cbit_assign_bit @qc0_meas<2> [0] : i1 = %[[MEAS2]]:0
70+
// CHECK-NOT: oq3.cbit_assign_bit @qc0_meas<2> [1] : i1 = %[[MEAS2]]:1
71+
72+
quir.barrier %0 : (!quir.qubit<1>) -> ()
73+
%6 = quir.call_circuit @circuit_4(%0) : (!quir.qubit<1>) -> i1
74+
%7 = oq3.variable_load @p1 : f64
75+
%8 = "oq3.cast"(%7) : (f64) -> !quir.angle<64>
76+
%9 = oq3.variable_load @p2 : f64
77+
%10 = "oq3.cast"(%9) : (f64) -> !quir.angle<64>
78+
%11 = quir.call_circuit @circuit_5(%1, %8) : (!quir.qubit<1>, !quir.angle<64>) -> i1
79+
// CHECK-NOT: %6 = quir.call_circuit @circuit_4(%0) : (!quir.qubit<1>) -> i1
80+
// CHECK-NOT: %7 = oq3.variable_load @p1 : f64
81+
// CHECK-NOT: %8 = "oq3.cast"(%7) : (f64) -> !quir.angle<64>
82+
// CHECK-NOT: %9 = oq3.variable_load @p2 : f64
83+
// CHECK-NOT: %10 = "oq3.cast"(%9) : (f64) -> !quir.angle<64>
84+
// CHECK: %[[LOAD:.*]] = oq3.variable_load @p1 : f64
85+
// CHECK: %[[CAST:.*]] = "oq3.cast"(%[[LOAD]]) : (f64) -> !quir.angle<64>
86+
// CHECK: {{.*}} = oq3.variable_load @p2 : f64
87+
// CHECK: %{{.*}}:2 = quir.call_circuit @circuit_4_q0_circuit_5_q1(%0, %1, %[[CAST]]) : (!quir.qubit<1>, !quir.qubit<1>, !quir.angle<64>) -> (i1, i1)
88+
89+
3890
%c0_i32 = arith.constant 0 : i32
3991
return %c0_i32 : i32
4092
}

0 commit comments

Comments
 (0)