Skip to content

Commit d9209b7

Browse files
authored
✨ Add QCO If operation (#1506)
1 parent 756d3c1 commit d9209b7

File tree

12 files changed

+732
-7
lines changed

12 files changed

+732
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel
1212
### Added
1313

1414
- ✨ Add initial infrastructure for new QC and QCO MLIR dialects
15-
([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1510], [#1513], [#1521])
15+
([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521])
1616
([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**])
1717

1818
### Changed
@@ -330,6 +330,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool
330330
[#1521]: https://github.com/munich-quantum-toolkit/core/pull/1521
331331
[#1513]: https://github.com/munich-quantum-toolkit/core/pull/1513
332332
[#1510]: https://github.com/munich-quantum-toolkit/core/pull/1510
333+
[#1506]: https://github.com/munich-quantum-toolkit/core/pull/1506
333334
[#1481]: https://github.com/munich-quantum-toolkit/core/pull/1481
334335
[#1475]: https://github.com/munich-quantum-toolkit/core/pull/1475
335336
[#1474]: https://github.com/munich-quantum-toolkit/core/pull/1474

mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#pragma once
1212

1313
#include <cstdint>
14-
#include <functional>
1514
#include <llvm/ADT/DenseSet.h>
1615
#include <llvm/ADT/STLFunctionalExtras.h>
1716
#include <llvm/ADT/SmallVector.h>
@@ -1071,6 +1070,51 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder {
10711070
*/
10721071
QCOProgramBuilder& dealloc(Value qubit);
10731072

1073+
//===--------------------------------------------------------------------===//
1074+
// SCF operations
1075+
//===--------------------------------------------------------------------===//
1076+
1077+
/**
1078+
* @brief Construct an if operation for qubits with linear typing
1079+
*
1080+
* @details
1081+
* Constructs an if operation that takes a bool Value and a range of qubit
1082+
* values that are used in the then/else region of this operation. The qubit
1083+
* values are passed down as block arguments to each region.
1084+
*
1085+
* @param condition Bool condition
1086+
* @param qubits Input qubits
1087+
* @param thenBody Function that builds the then body of the if
1088+
* operation
1089+
* @param elseBody Function that builds the else body of the if
1090+
* operation
1091+
* @return ValueRange of the results (must be the same types as the input
1092+
* qubits)
1093+
*
1094+
* @par Example:
1095+
* ```c++
1096+
* auto result =
1097+
* builder.qcoIf(condition, q0,
1098+
* [&](ValueRange args) -> llvm::SmallVector<Value> {
1099+
* auto q1 = builder.h(args[0]);
1100+
* return {q1};
1101+
* });
1102+
* ```
1103+
* ```mlir
1104+
* %q2 = qco.if %condition qubits(%arg0 = %q0) {
1105+
* %q1 = qco.h %arg0 : !qco.qubit -> !qco.qubit
1106+
* qco.yield %q1
1107+
* } else qubits(%arg0 = %q0) {
1108+
* qco.yield %arg0
1109+
* } : {i1, !qco.qubit} -> {!qco.qubit}
1110+
* ```
1111+
*/
1112+
ValueRange
1113+
qcoIf(const std::variant<bool, Value>& condition, ValueRange qubits,
1114+
llvm::function_ref<llvm::SmallVector<Value>(ValueRange)> thenBody,
1115+
llvm::function_ref<llvm::SmallVector<Value>(ValueRange)> elseBody =
1116+
nullptr);
1117+
10741118
//===--------------------------------------------------------------------===//
10751119
// Finalization
10761120
//===--------------------------------------------------------------------===//

mlir/include/mlir/Dialect/QCO/IR/QCOOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Dialect/QCO/IR/QCOInterfaces.h"
2626

2727
#include <mlir/Bytecode/BytecodeOpInterface.h>
28+
#include <mlir/Interfaces/ControlFlowInterfaces.h>
2829
#include <mlir/Interfaces/SideEffectInterfaces.h>
2930
#include <variant>
3031

mlir/include/mlir/Dialect/QCO/IR/QCOOps.td

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
1414
include "mlir/IR/DialectBase.td"
1515
include "mlir/IR/EnumAttr.td"
1616
include "mlir/IR/OpBase.td"
17+
include "mlir/Interfaces/ControlFlowInterfaces.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/SideEffectInterfaces.td"
1920

@@ -1048,7 +1049,7 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> {
10481049
// Modifiers
10491050
//===----------------------------------------------------------------------===//
10501051

1051-
def YieldOp : QCOOp<"yield", traits = [Terminator]> {
1052+
def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike]> {
10521053
let summary = "Yield from a modifier region";
10531054
let description = [{
10541055
Terminates a modifier region, yielding the transformed target qubits back to the enclosing modifier operation.
@@ -1205,4 +1206,83 @@ def InvOp : QCOOp<"inv",
12051206
let hasVerifier = 1;
12061207
}
12071208

1209+
//===----------------------------------------------------------------------===//
1210+
// SCF operations
1211+
//===----------------------------------------------------------------------===//
1212+
1213+
def IfOp : QCOOp<"if", traits =
1214+
[
1215+
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
1216+
"getNumRegionInvocations", "getRegionInvocationBounds",
1217+
"getEntrySuccessorRegions"]>,
1218+
SingleBlock,
1219+
SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">,
1220+
RecursiveMemoryEffects
1221+
]> {
1222+
1223+
let summary = "If-then-else operation for linear (qubit) types";
1224+
let description = [{
1225+
The `qco.if` operation is an if-then-else construct similar to the standard scf.if operation.
1226+
In addition to the condition, the operation takes a variadic number of qubits as inputs that are
1227+
required in the bodies of both branches. These qubits are passed down to the individual regions
1228+
as block arguments. The number of results and the type of the results must be equivalent to the
1229+
number and types of the input qubits.
1230+
1231+
Example:
1232+
```mlir
1233+
%result = qco.if %condition qubits(%arg0 = %q0) {
1234+
%q1 = qco.h %arg0 : !qco.qubit -> !qco.qubit
1235+
qco.yield %q1
1236+
} else qubits(%arg0 = %q0) {
1237+
qco.yield %arg0
1238+
} : {i1, !qco.qubit} -> {!qco.qubit}
1239+
```
1240+
}];
1241+
1242+
let arguments = (ins I1:$condition, Variadic<QubitType>:$qubits);
1243+
let results = (outs Variadic<QubitType>:$results);
1244+
let regions = (region SizedRegion<1>:$thenRegion,
1245+
SizedRegion<1>:$elseRegion);
1246+
1247+
let assemblyFormat = [{
1248+
$condition
1249+
custom<IfOpAliasing>($thenRegion, $elseRegion, $qubits)
1250+
attr-dict `:`
1251+
`{` type($condition) `,` type($qubits) `}`
1252+
`->`
1253+
`{` type($results) `}`
1254+
}];
1255+
1256+
let builders = [
1257+
OpBuilder<(ins "Value":$condition, "ValueRange":$qubits), [{
1258+
build($_builder, $_state, qubits.getTypes(), condition, qubits);
1259+
}]>,
1260+
OpBuilder<(ins "Value":$condition, "ValueRange":$qubits, "llvm::function_ref<llvm::SmallVector<Value>(ValueRange)>":$thenBuilder,
1261+
CArg<"llvm::function_ref<llvm::SmallVector<Value>(ValueRange)>","nullptr">:$elseBuilder)>
1262+
];
1263+
1264+
let extraClassDeclaration = [{
1265+
Block *thenBlock() {
1266+
return &getThenRegion().back();
1267+
}
1268+
YieldOp thenYield();
1269+
Block* elseBlock() {
1270+
return &getElseRegion().back();
1271+
}
1272+
YieldOp elseYield();
1273+
}];
1274+
1275+
let extraClassDefinition = [{
1276+
YieldOp $cppClass::thenYield() {
1277+
return cast<YieldOp>(&thenBlock()->back());
1278+
}
1279+
YieldOp $cppClass::elseYield() {
1280+
return cast<YieldOp>(&elseBlock()->back());
1281+
}
1282+
}];
1283+
1284+
let hasCanonicalizer = 1;
1285+
let hasVerifier = 1;
1286+
}
1287+
12081288
#endif // QCOOPS

mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,71 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) {
678678
return *this;
679679
}
680680

681+
//===----------------------------------------------------------------------===//
682+
// SCF Operations
683+
//===----------------------------------------------------------------------===//
684+
685+
ValueRange QCOProgramBuilder::qcoIf(
686+
const std::variant<bool, Value>& condition, ValueRange qubits,
687+
llvm::function_ref<llvm::SmallVector<Value>(ValueRange)> thenBody,
688+
llvm::function_ref<llvm::SmallVector<Value>(ValueRange)> elseBody) {
689+
checkFinalized();
690+
691+
auto conditionValue = utils::variantToValue(*this, getLoc(), condition);
692+
693+
auto ifOp = IfOp::create(*this, conditionValue, qubits);
694+
// Create the then and else block
695+
auto& thenBlock = ifOp->getRegion(0).emplaceBlock();
696+
auto& elseBlock = ifOp->getRegion(1).emplaceBlock();
697+
698+
// Create the block arguments and add them as valid qubits
699+
for (auto qubitType : qubits.getTypes()) {
700+
const auto thenArg = thenBlock.addArgument(qubitType, getLoc());
701+
const auto elseArg = elseBlock.addArgument(qubitType, getLoc());
702+
validQubits.insert(thenArg);
703+
validQubits.insert(elseArg);
704+
}
705+
706+
// Construct the bodies of the regions
707+
const InsertionGuard guard(*this);
708+
setInsertionPointToStart(&thenBlock);
709+
const auto thenResult = thenBody(thenBlock.getArguments());
710+
YieldOp::create(*this, thenResult);
711+
setInsertionPointToStart(&elseBlock);
712+
llvm::SmallVector<Value> elseResult;
713+
if (elseBody) {
714+
elseResult = elseBody(elseBlock.getArguments());
715+
YieldOp::create(*this, elseResult);
716+
} else {
717+
elseResult.assign(elseBlock.getArguments().begin(),
718+
elseBlock.getArguments().end());
719+
YieldOp::create(*this, elseBlock.getArguments());
720+
}
721+
722+
if (thenResult.size() != qubits.size() ||
723+
thenResult.size() != elseResult.size()) {
724+
llvm::reportFatalUsageError(
725+
"Then and else body must return the same amount of qubits as the "
726+
"number of input qubits!");
727+
}
728+
729+
// Update qubit tracking
730+
const auto& ifResults = ifOp->getResults();
731+
for (auto [input, output] : llvm::zip_equal(qubits, ifResults)) {
732+
updateQubitTracking(input, output);
733+
}
734+
735+
// Remove the inner qubits as valid qubits
736+
for (auto thenOut : thenResult) {
737+
validQubits.erase(thenOut);
738+
}
739+
for (auto elseOut : elseResult) {
740+
validQubits.erase(elseOut);
741+
}
742+
743+
return ifResults;
744+
}
745+
681746
//===----------------------------------------------------------------------===//
682747
// Finalization
683748
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/QCO/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
file(GLOB_RECURSE MODIFIERS "${CMAKE_CURRENT_SOURCE_DIR}/Modifiers/*.cpp")
1010
file(GLOB_RECURSE OPERATIONS "${CMAKE_CURRENT_SOURCE_DIR}/Operations/*.cpp")
1111
file(GLOB_RECURSE QUBIT_MANAGEMENT "${CMAKE_CURRENT_SOURCE_DIR}/QubitManagement/*.cpp")
12+
file(GLOB_RECURSE SCF "${CMAKE_CURRENT_SOURCE_DIR}/SCF/*.cpp")
1213

1314
add_mlir_dialect_library(
1415
MLIRQCODialect
1516
QCOOps.cpp
1617
${MODIFIERS}
1718
${OPERATIONS}
1819
${QUBIT_MANAGEMENT}
20+
${SCF}
1921
ADDITIONAL_HEADER_DIRS
2022
${PROJECT_SOURCE_DIR}/mlir/include/mlir/Dialect/QCO
2123
DEPENDS

mlir/lib/Dialect/QCO/IR/QCOOps.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated
1414

15+
#include <llvm/ADT/DenseSet.h>
16+
#include <llvm/ADT/STLExtras.h>
1517
#include <mlir/IR/Block.h>
1618
#include <mlir/IR/OpImplementation.h>
1719
#include <mlir/IR/Operation.h>
@@ -113,6 +115,68 @@ static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/,
113115
printer.printRegion(region, false);
114116
}
115117

118+
static ParseResult
119+
parseIfOpAliasing(OpAsmParser& parser, Region& thenRegion, Region& elseRegion,
120+
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
121+
// Parse the qubits keyword
122+
if (parser.parseKeyword("qubits")) {
123+
return failure();
124+
}
125+
126+
// Parse the then region
127+
if (parseTargetAliasing(parser, thenRegion, operands)) {
128+
return failure();
129+
}
130+
const auto thenCount = operands.size();
131+
132+
// Parse the else keyword
133+
if (parser.parseKeyword("else")) {
134+
return failure();
135+
}
136+
137+
// Parse the qubits keyword
138+
if (parser.parseKeyword("qubits")) {
139+
return failure();
140+
}
141+
142+
// Parse the else region
143+
if (parseTargetAliasing(parser, elseRegion, operands)) {
144+
return failure();
145+
}
146+
147+
const auto elseCount = operands.size() - thenCount;
148+
149+
if (thenCount != elseCount) {
150+
return parser.emitError(
151+
parser.getCurrentLocation(),
152+
"then/else qubit aliasing lists must be the same length");
153+
}
154+
for (unsigned i = 0; i < thenCount; ++i) {
155+
if (operands[i].name != operands[thenCount + i].name) {
156+
return parser.emitError(
157+
parser.getCurrentLocation(),
158+
"then/else qubit aliasing lists must match in order");
159+
}
160+
}
161+
162+
// Remove duplicate operands
163+
llvm::DenseSet<llvm::StringRef> seen;
164+
llvm::erase_if(operands,
165+
[&](const auto& op) { return !seen.insert(op.name).second; });
166+
167+
return success();
168+
}
169+
170+
static void printIfOpAliasing(OpAsmPrinter& printer, Operation* op,
171+
Region& thenRegion, Region& elseRegion,
172+
OperandRange qubits) {
173+
printer << "qubits";
174+
printTargetAliasing(printer, op, thenRegion, qubits);
175+
printer << " else ";
176+
printer << "qubits";
177+
printTargetAliasing(printer, op, elseRegion, qubits);
178+
}
179+
116180
//===----------------------------------------------------------------------===//
117181
// Dialect
118182
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)