Skip to content

Commit 3fefb4e

Browse files
committed
Add canonicalization pattern for merging nested CtrlOps
1 parent e007675 commit 3fefb4e

File tree

5 files changed

+93
-12
lines changed

5 files changed

+93
-12
lines changed

mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ class FluxProgramBuilder {
380380
Location loc;
381381

382382
private:
383-
bool inRegion = false;
383+
int inRegion = 0;
384384

385385
//===--------------------------------------------------------------------===//
386386
// Linear Type Tracking Helpers

mlir/include/mlir/Dialect/Flux/IR/FluxOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,13 @@ def CtrlOp : FluxOp<"ctrl", traits = [UnitaryOpInterface, AttrSizedOperandSegmen
433433
let builders = [
434434
OpBuilder<(ins "ValueRange":$controls_in, "ValueRange":$targets_in), [{
435435
auto qubit_type = QubitType::get($_builder.getContext());
436-
build($_builder, $_state, qubit_type, qubit_type, controls_in, targets_in);
436+
SmallVector<Type> control_types(controls_in.size(), qubit_type);
437+
SmallVector<Type> target_types(targets_in.size(), qubit_type);
438+
build($_builder, $_state, TypeRange(control_types), TypeRange(target_types), controls_in, targets_in);
437439
}]>,
438440
];
441+
442+
let hasCanonicalizer = 1;
439443
}
440444

441445
#endif // FluxOPS

mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ Value FluxProgramBuilder::x(Value qubit) {
178178
const auto& qubitOut = xOp.getQubitOut();
179179

180180
// Update tracking
181-
if (!inRegion) {
181+
if (inRegion == 0) {
182182
updateQubitTracking(qubit, qubitOut);
183183
}
184184

@@ -232,21 +232,23 @@ std::pair<SmallVector<Value>, SmallVector<Value>> FluxProgramBuilder::ctrl(
232232
const mlir::OpBuilder::InsertionGuard guard(builder);
233233
builder.setInsertionPointToStart(&ctrlOp.getBody().emplaceBlock());
234234

235-
inRegion = true;
235+
inRegion++;
236236
auto targetsYield = body(*this);
237-
inRegion = false;
237+
inRegion--;
238238

239239
builder.create<YieldOp>(loc, targetsYield);
240240

241-
// Update tracking
242241
const auto& controlsOut = ctrlOp.getControlsOut();
243-
for (const auto& [control, controlOut] : llvm::zip(controls, controlsOut)) {
244-
updateQubitTracking(control, controlOut);
245-
}
246-
247242
const auto& targetsOut = ctrlOp.getTargetsOut();
248-
for (const auto& [target, targetOut] : llvm::zip(targets, targetsOut)) {
249-
updateQubitTracking(target, targetOut);
243+
244+
// Update tracking
245+
if (inRegion == 0) {
246+
for (const auto& [control, controlOut] : llvm::zip(controls, controlsOut)) {
247+
updateQubitTracking(control, controlOut);
248+
}
249+
for (const auto& [target, targetOut] : llvm::zip(targets, targetsOut)) {
250+
updateQubitTracking(target, targetOut);
251+
}
250252
}
251253

252254
return {controlsOut, targetsOut};

mlir/lib/Dialect/Flux/IR/FluxOps.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,46 @@ struct MergeSubsequentRX final : OpRewritePattern<RXOp> {
556556
}
557557
};
558558

559+
struct MergeNestedCtrl final : OpRewritePattern<CtrlOp> {
560+
using OpRewritePattern::OpRewritePattern;
561+
562+
LogicalResult matchAndRewrite(CtrlOp ctrlOp,
563+
PatternRewriter& rewriter) const override {
564+
auto bodyUnitary = ctrlOp.getBodyUnitary();
565+
auto bodyCtrlOp = llvm::dyn_cast<CtrlOp>(bodyUnitary.getOperation());
566+
567+
// Check if the body unitary is a CtrlOp
568+
if (!bodyCtrlOp) {
569+
return failure();
570+
}
571+
572+
// Merge controls
573+
SmallVector<Value> newControls;
574+
newControls.append(ctrlOp.getControlsIn().begin(),
575+
ctrlOp.getControlsIn().end());
576+
for (auto control : bodyCtrlOp.getControlsIn()) {
577+
if (llvm::is_contained(newControls, control)) {
578+
continue;
579+
}
580+
newControls.push_back(control);
581+
}
582+
583+
// Create new CtrlOp
584+
auto newCtrlOp = rewriter.create<CtrlOp>(ctrlOp.getLoc(), newControls,
585+
bodyCtrlOp.getTargetsIn());
586+
587+
// Clone block
588+
rewriter.cloneRegionBefore(bodyCtrlOp.getBody(), newCtrlOp.getBody(),
589+
newCtrlOp.getBody().end());
590+
591+
// Replace CtrlOps
592+
rewriter.eraseOp(bodyCtrlOp);
593+
rewriter.replaceOp(ctrlOp, newCtrlOp.getResults());
594+
595+
return success();
596+
}
597+
};
598+
559599
} // namespace
560600

561601
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet& results,
@@ -577,3 +617,8 @@ void RXOp::getCanonicalizationPatterns(RewritePatternSet& results,
577617
MLIRContext* context) {
578618
results.add<MergeSubsequentRX>(context);
579619
}
620+
621+
void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results,
622+
MLIRContext* context) {
623+
results.add<MergeNestedCtrl>(context);
624+
}

mlir/unittests/pipeline/test_compiler_pipeline.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,4 +1269,34 @@ TEST_F(SimpleConversionTest, CX) {
12691269
EXPECT_TRUE(verify("Flux to Quartz", quartzIRConv, quartzExpected.get()));
12701270
}
12711271

1272+
TEST_F(SimpleConversionTest, CXMerge) {
1273+
const auto fluxInit = buildFluxIR([](flux::FluxProgramBuilder& b) {
1274+
auto reg = b.allocQubitRegister(3, "q");
1275+
auto q0a = reg[0];
1276+
auto q1a = reg[1];
1277+
auto q2a = reg[2];
1278+
b.ctrl({q0a}, {q1a, q2a}, [&](auto& b) {
1279+
auto q12b = b.ctrl({q1a}, {q2a}, [&](auto& b) {
1280+
auto q2b = b.x(q2a);
1281+
return SmallVector<Value>{q2b};
1282+
});
1283+
return SmallVector<Value>{q12b.first[0], q12b.second[0]};
1284+
});
1285+
});
1286+
const auto fluxInitIR = captureIR(fluxInit.get());
1287+
1288+
const auto fluxOpt = buildFluxIR([](flux::FluxProgramBuilder& b) {
1289+
auto reg = b.allocQubitRegister(3, "q");
1290+
auto q0a = reg[0];
1291+
auto q1a = reg[1];
1292+
auto q2a = reg[2];
1293+
b.ctrl({q0a, q1a}, {q2a}, [&](auto& b) {
1294+
auto q2b = b.x(q2a);
1295+
return SmallVector<Value>{q2b};
1296+
});
1297+
});
1298+
1299+
EXPECT_TRUE(verify("Flux Canonicalization", fluxInitIR, fluxOpt.get()));
1300+
}
1301+
12721302
} // namespace

0 commit comments

Comments
 (0)