Skip to content

Commit 28636e9

Browse files
committed
✨ enhance compiler pipeline tests by adding canonicalization passes and updating measurement operations
Signed-off-by: burgholzer <[email protected]>
1 parent 6bac08f commit 28636e9

File tree

1 file changed

+39
-122
lines changed

1 file changed

+39
-122
lines changed

mlir/unittests/pipeline/test_compiler_pipeline.cpp

Lines changed: 39 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,10 @@ using namespace mlir;
9898
*/
9999
struct StageExpectations {
100100
ModuleOp quartzImport;
101-
ModuleOp initialCanon;
102101
ModuleOp fluxConversion;
103-
ModuleOp fluxCanon;
104102
ModuleOp optimization;
105-
ModuleOp optimizationCanon;
106103
ModuleOp quartzConversion;
107-
ModuleOp quartzCanon;
108104
ModuleOp qirConversion;
109-
ModuleOp qirCanon;
110105
};
111106

112107
//===----------------------------------------------------------------------===//
@@ -147,7 +142,7 @@ class CompilerPipelineTest : public testing::Test {
147142
config.convertToQIR = true;
148143
config.recordIntermediates = true;
149144
config.printIRAfterAllStages =
150-
false; /// TODO: Change back after everything is running
145+
true; /// TODO: Change back after everything is running
151146

152147
emptyQuartz = buildQuartzIR([](quartz::QuartzProgramBuilder&) {});
153148
emptyFlux = buildFluxIR([](flux::FluxProgramBuilder&) {});
@@ -222,37 +217,55 @@ class CompilerPipelineTest : public testing::Test {
222217
//===--------------------------------------------------------------------===//
223218

224219
/**
225-
* @brief Build expected Quartz IR programmatically
220+
* @brief Run canonicalization
221+
*/
222+
void runCanonicalizationPasses(ModuleOp module) const {
223+
PassManager pm(module.getContext());
224+
pm.addPass(createCanonicalizerPass());
225+
pm.addPass(createCSEPass());
226+
if (failed(pm.run(module))) {
227+
llvm::errs() << "Failed to run canonicalization passes\n";
228+
}
229+
}
230+
231+
/**
232+
* @brief Build expected Quartz IR programmatically and run canonicalization
226233
*/
227234
[[nodiscard]] OwningOpRef<ModuleOp> buildQuartzIR(
228235
const std::function<void(quartz::QuartzProgramBuilder&)>& buildFunc)
229236
const {
230237
quartz::QuartzProgramBuilder builder(context.get());
231238
builder.initialize();
232239
buildFunc(builder);
233-
return builder.finalize();
240+
auto module = builder.finalize();
241+
runCanonicalizationPasses(module.get());
242+
return module;
234243
}
235244

236245
/**
237-
* @brief Build expected Flux IR programmatically
246+
* @brief Build expected Flux IR programmatically and run canonicalization
238247
*/
239248
[[nodiscard]] OwningOpRef<ModuleOp> buildFluxIR(
240249
const std::function<void(flux::FluxProgramBuilder&)>& buildFunc) const {
241250
flux::FluxProgramBuilder builder(context.get());
242251
builder.initialize();
243252
buildFunc(builder);
244-
return builder.finalize();
253+
auto module = builder.finalize();
254+
runCanonicalizationPasses(module.get());
255+
return module;
245256
}
246257

247258
/**
248-
* @brief Build expected QIR programmatically
259+
* @brief Build expected QIR programmatically and run canonicalization
249260
*/
250261
[[nodiscard]] OwningOpRef<ModuleOp> buildQIR(
251262
const std::function<void(qir::QIRProgramBuilder&)>& buildFunc) const {
252263
qir::QIRProgramBuilder builder(context.get());
253264
builder.initialize();
254265
buildFunc(builder);
255-
return builder.finalize();
266+
auto module = builder.finalize();
267+
runCanonicalizationPasses(module.get());
268+
return module;
256269
}
257270

258271
//===--------------------------------------------------------------------===//
@@ -267,56 +280,30 @@ class CompilerPipelineTest : public testing::Test {
267280
* Stages without expectations are skipped.
268281
*/
269282
void verifyAllStages(const StageExpectations& expectations) const {
270-
if (expectations.quartzImport) {
271-
EXPECT_TRUE(verify("Quartz Import", record.afterQuartzImport,
283+
if (expectations.quartzImport != nullptr) {
284+
EXPECT_TRUE(verify("Quartz Import", record.afterInitialCanon,
272285
expectations.quartzImport));
273286
}
274287

275-
if (expectations.initialCanon) {
276-
EXPECT_TRUE(verify("Initial Canonicalization", record.afterInitialCanon,
277-
expectations.initialCanon));
278-
}
279-
280-
if (expectations.fluxConversion) {
281-
EXPECT_TRUE(verify("Flux Conversion", record.afterFluxConversion,
288+
if (expectations.fluxConversion != nullptr) {
289+
EXPECT_TRUE(verify("Flux Conversion", record.afterFluxCanon,
282290
expectations.fluxConversion));
283291
}
284292

285-
if (expectations.fluxCanon) {
286-
EXPECT_TRUE(verify("Flux Canonicalization", record.afterFluxCanon,
287-
expectations.fluxCanon));
288-
}
289-
290-
if (expectations.optimization) {
291-
EXPECT_TRUE(verify("Optimization", record.afterOptimization,
293+
if (expectations.optimization != nullptr) {
294+
EXPECT_TRUE(verify("Optimization", record.afterOptimizationCanon,
292295
expectations.optimization));
293296
}
294297

295-
if (expectations.optimizationCanon) {
296-
EXPECT_TRUE(verify("Optimization Canonicalization",
297-
record.afterOptimizationCanon,
298-
expectations.optimizationCanon));
299-
}
300-
301-
if (expectations.quartzConversion) {
302-
EXPECT_TRUE(verify("Quartz Conversion", record.afterQuartzConversion,
298+
if (expectations.quartzConversion != nullptr) {
299+
EXPECT_TRUE(verify("Quartz Conversion", record.afterQuartzCanon,
303300
expectations.quartzConversion));
304301
}
305302

306-
if (expectations.quartzCanon) {
307-
EXPECT_TRUE(verify("Quartz Canonicalization", record.afterQuartzCanon,
308-
expectations.quartzCanon));
309-
}
310-
311-
if (expectations.qirConversion) {
312-
EXPECT_TRUE(verify("QIR Conversion", record.afterQIRConversion,
303+
if (expectations.qirConversion != nullptr) {
304+
EXPECT_TRUE(verify("QIR Conversion", record.afterQIRCanon,
313305
expectations.qirConversion));
314306
}
315-
316-
if (expectations.qirCanon) {
317-
EXPECT_TRUE(verify("QIR Canonicalization", record.afterQIRCanon,
318-
expectations.qirCanon));
319-
}
320307
}
321308

322309
void TearDown() override {
@@ -376,15 +363,10 @@ TEST_F(CompilerPipelineTest, EmptyCircuit) {
376363
// Verify all stages
377364
verifyAllStages({
378365
.quartzImport = emptyQuartz.get(),
379-
.initialCanon = emptyQuartz.get(),
380366
.fluxConversion = emptyFlux.get(),
381-
.fluxCanon = emptyFlux.get(),
382367
.optimization = emptyFlux.get(),
383-
.optimizationCanon = emptyFlux.get(),
384368
.quartzConversion = emptyQuartz.get(),
385-
.quartzCanon = emptyQuartz.get(),
386369
.qirConversion = emptyQIR.get(),
387-
.qirCanon = emptyQIR.get(),
388370
});
389371
}
390372

@@ -414,15 +396,10 @@ TEST_F(CompilerPipelineTest, SingleQubitRegister) {
414396

415397
verifyAllStages({
416398
.quartzImport = quartzExpected.get(),
417-
.initialCanon = quartzExpected.get(),
418399
.fluxConversion = fluxExpected.get(),
419-
.fluxCanon = emptyFlux.get(),
420400
.optimization = emptyFlux.get(),
421-
.optimizationCanon = emptyFlux.get(),
422401
.quartzConversion = emptyQuartz.get(),
423-
.quartzCanon = emptyQuartz.get(),
424402
.qirConversion = emptyQIR.get(),
425-
.qirCanon = emptyQIR.get(),
426403
});
427404
}
428405

@@ -444,15 +421,10 @@ TEST_F(CompilerPipelineTest, MultiQubitRegister) {
444421

445422
verifyAllStages({
446423
.quartzImport = quartzExpected.get(),
447-
.initialCanon = quartzExpected.get(),
448424
.fluxConversion = fluxExpected.get(),
449-
.fluxCanon = emptyFlux.get(),
450425
.optimization = emptyFlux.get(),
451-
.optimizationCanon = emptyFlux.get(),
452426
.quartzConversion = emptyQuartz.get(),
453-
.quartzCanon = emptyQuartz.get(),
454427
.qirConversion = emptyQIR.get(),
455-
.qirCanon = emptyQIR.get(),
456428
});
457429
}
458430

@@ -480,15 +452,10 @@ TEST_F(CompilerPipelineTest, MultipleQuantumRegisters) {
480452

481453
verifyAllStages({
482454
.quartzImport = quartzExpected.get(),
483-
.initialCanon = quartzExpected.get(),
484455
.fluxConversion = fluxExpected.get(),
485-
.fluxCanon = emptyFlux.get(),
486456
.optimization = emptyFlux.get(),
487-
.optimizationCanon = emptyFlux.get(),
488457
.quartzConversion = emptyQuartz.get(),
489-
.quartzCanon = emptyQuartz.get(),
490458
.qirConversion = emptyQIR.get(),
491-
.qirCanon = emptyQIR.get(),
492459
});
493460
}
494461

@@ -528,15 +495,10 @@ TEST_F(CompilerPipelineTest, SingleClassicalBitRegister) {
528495

529496
verifyAllStages({
530497
.quartzImport = expected.get(),
531-
.initialCanon = emptyQuartz.get(),
532498
.fluxConversion = emptyFlux.get(),
533-
.fluxCanon = emptyFlux.get(),
534499
.optimization = emptyFlux.get(),
535-
.optimizationCanon = emptyFlux.get(),
536500
.quartzConversion = emptyQuartz.get(),
537-
.quartzCanon = emptyQuartz.get(),
538501
.qirConversion = emptyQIR.get(),
539-
.qirCanon = emptyQIR.get(),
540502
});
541503
}
542504

@@ -560,15 +522,10 @@ TEST_F(CompilerPipelineTest, MultiBitClassicalRegister) {
560522

561523
verifyAllStages({
562524
.quartzImport = expected.get(),
563-
.initialCanon = emptyQuartz.get(),
564525
.fluxConversion = emptyFlux.get(),
565-
.fluxCanon = emptyFlux.get(),
566526
.optimization = emptyFlux.get(),
567-
.optimizationCanon = emptyFlux.get(),
568527
.quartzConversion = emptyQuartz.get(),
569-
.quartzCanon = emptyQuartz.get(),
570528
.qirConversion = emptyQIR.get(),
571-
.qirCanon = emptyQIR.get(),
572529
});
573530
}
574531

@@ -595,15 +552,10 @@ TEST_F(CompilerPipelineTest, MultipleClassicalRegisters) {
595552

596553
verifyAllStages({
597554
.quartzImport = expected.get(),
598-
.initialCanon = emptyQuartz.get(),
599555
.fluxConversion = emptyFlux.get(),
600-
.fluxCanon = emptyFlux.get(),
601556
.optimization = emptyFlux.get(),
602-
.optimizationCanon = emptyFlux.get(),
603557
.quartzConversion = emptyQuartz.get(),
604-
.quartzCanon = emptyQuartz.get(),
605558
.qirConversion = emptyQIR.get(),
606-
.qirCanon = emptyQIR.get(),
607559
});
608560
}
609561

@@ -649,15 +601,10 @@ TEST_F(CompilerPipelineTest, SingleResetInSingleQubitCircuit) {
649601

650602
verifyAllStages({
651603
.quartzImport = expected.get(),
652-
.initialCanon = expected.get(),
653604
.fluxConversion = fluxExpected.get(),
654-
.fluxCanon = emptyFlux.get(),
655605
.optimization = emptyFlux.get(),
656-
.optimizationCanon = emptyFlux.get(),
657606
.quartzConversion = emptyQuartz.get(),
658-
.quartzCanon = emptyQuartz.get(),
659607
.qirConversion = emptyQIR.get(),
660-
.qirCanon = emptyQIR.get(),
661608
});
662609
}
663610

@@ -695,15 +642,10 @@ TEST_F(CompilerPipelineTest, ConsecutiveResetOperations) {
695642

696643
verifyAllStages({
697644
.quartzImport = expected.get(),
698-
.initialCanon = expected.get(),
699645
.fluxConversion = fluxExpected.get(),
700-
.fluxCanon = emptyFlux.get(),
701646
.optimization = emptyFlux.get(),
702-
.optimizationCanon = emptyFlux.get(),
703647
.quartzConversion = emptyQuartz.get(),
704-
.quartzCanon = emptyQuartz.get(),
705648
.qirConversion = emptyQIR.get(),
706-
.qirCanon = emptyQIR.get(),
707649
});
708650
}
709651

@@ -733,15 +675,10 @@ TEST_F(CompilerPipelineTest, SeparateResetsInTwoQubitSystem) {
733675

734676
verifyAllStages({
735677
.quartzImport = expected.get(),
736-
.initialCanon = expected.get(),
737678
.fluxConversion = fluxExpected.get(),
738-
.fluxCanon = emptyFlux.get(),
739679
.optimization = emptyFlux.get(),
740-
.optimizationCanon = emptyFlux.get(),
741680
.quartzConversion = emptyQuartz.get(),
742-
.quartzCanon = emptyQuartz.get(),
743681
.qirConversion = emptyQIR.get(),
744-
.qirCanon = emptyQIR.get(),
745682
});
746683
}
747684

@@ -779,15 +716,10 @@ TEST_F(CompilerPipelineTest, SingleMeasurementToSingleBit) {
779716

780717
verifyAllStages({
781718
.quartzImport = expected.get(),
782-
.initialCanon = expected.get(),
783719
.fluxConversion = fluxExpected.get(),
784-
.fluxCanon = fluxExpected.get(),
785720
.optimization = fluxExpected.get(),
786-
.optimizationCanon = fluxExpected.get(),
787721
.quartzConversion = expected.get(),
788-
.quartzCanon = expected.get(),
789722
.qirConversion = qirExpected.get(),
790-
.qirCanon = qirExpected.get(),
791723
});
792724
}
793725

@@ -825,15 +757,10 @@ TEST_F(CompilerPipelineTest, RepeatedMeasurementToSameBit) {
825757

826758
verifyAllStages({
827759
.quartzImport = expected.get(),
828-
.initialCanon = expected.get(),
829760
.fluxConversion = fluxExpected.get(),
830-
.fluxCanon = fluxExpected.get(),
831761
.optimization = fluxExpected.get(),
832-
.optimizationCanon = fluxExpected.get(),
833762
.quartzConversion = expected.get(),
834-
.quartzCanon = expected.get(),
835763
.qirConversion = qirExpected.get(),
836-
.qirCanon = qirExpected.get(),
837764
});
838765
}
839766

@@ -876,15 +803,10 @@ TEST_F(CompilerPipelineTest, RepeatedMeasurementOnSeparateBits) {
876803

877804
verifyAllStages({
878805
.quartzImport = expected.get(),
879-
.initialCanon = expected.get(),
880806
.fluxConversion = fluxExpected.get(),
881-
.fluxCanon = fluxExpected.get(),
882807
.optimization = fluxExpected.get(),
883-
.optimizationCanon = fluxExpected.get(),
884808
.quartzConversion = expected.get(),
885-
.quartzCanon = expected.get(),
886809
.qirConversion = qirExpected.get(),
887-
.qirCanon = qirExpected.get(),
888810
});
889811
}
890812

@@ -893,8 +815,8 @@ TEST_F(CompilerPipelineTest, RepeatedMeasurementOnSeparateBits) {
893815
*/
894816
TEST_F(CompilerPipelineTest, MultipleClassicalRegistersAndMeasurements) {
895817
qc::QuantumComputation qc(2);
896-
auto& c1 = qc.addClassicalRegister(1, "c1");
897-
auto& c2 = qc.addClassicalRegister(1, "c2");
818+
const auto& c1 = qc.addClassicalRegister(1, "c1");
819+
const auto& c2 = qc.addClassicalRegister(1, "c2");
898820
qc.measure(0, c1[0]);
899821
qc.measure(1, c2[0]);
900822

@@ -920,21 +842,16 @@ TEST_F(CompilerPipelineTest, MultipleClassicalRegistersAndMeasurements) {
920842

921843
const auto qirExpected = buildQIR([](qir::QIRProgramBuilder& b) {
922844
auto q = b.allocQubitRegister(2);
923-
b.measure(q[0], 0);
924-
b.measure(q[1], 1);
845+
b.measure(q[0], "c1", 0);
846+
b.measure(q[1], "c2", 0);
925847
});
926848

927849
verifyAllStages({
928850
.quartzImport = expected.get(),
929-
.initialCanon = expected.get(),
930851
.fluxConversion = fluxExpected.get(),
931-
.fluxCanon = fluxExpected.get(),
932852
.optimization = fluxExpected.get(),
933-
.optimizationCanon = fluxExpected.get(),
934853
.quartzConversion = expected.get(),
935-
.quartzCanon = expected.get(),
936854
.qirConversion = qirExpected.get(),
937-
.qirCanon = qirExpected.get(),
938855
});
939856
}
940857

0 commit comments

Comments
 (0)