Skip to content

Commit 803abd7

Browse files
Add support for FlattenCFG switch and introduce SwitchFlatOp
1 parent a4eb0db commit 803abd7

File tree

6 files changed

+734
-6
lines changed

6 files changed

+734
-6
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
971971
}];
972972
}
973973

974+
//===----------------------------------------------------------------------===//
975+
// SwitchFlatOp
976+
//===----------------------------------------------------------------------===//
977+
978+
def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
979+
Terminator]> {
980+
981+
let description = [{
982+
The `cir.switch.flat` operation is a region-less and simplified
983+
version of the `cir.switch`.
984+
It's representation is closer to LLVM IR dialect
985+
than the C/C++ language feature.
986+
}];
987+
988+
let arguments = (ins
989+
CIR_IntType:$condition,
990+
Variadic<AnyType>:$defaultOperands,
991+
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
992+
ArrayAttr:$case_values,
993+
DenseI32ArrayAttr:$case_operand_segments
994+
);
995+
996+
let successors = (successor
997+
AnySuccessor:$defaultDestination,
998+
VariadicSuccessor<AnySuccessor>:$caseDestinations
999+
);
1000+
1001+
let assemblyFormat = [{
1002+
$condition `:` type($condition) `,`
1003+
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
1004+
custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
1005+
$caseDestinations, $caseOperands,
1006+
type($caseOperands))
1007+
attr-dict
1008+
}];
1009+
1010+
let builders = [
1011+
OpBuilder<(ins "mlir::Value":$condition,
1012+
"mlir::Block *":$defaultDestination,
1013+
"mlir::ValueRange":$defaultOperands,
1014+
CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
1015+
CArg<"mlir::BlockRange", "{}">:$caseDestinations,
1016+
CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
1017+
];
1018+
}
1019+
9741020
//===----------------------------------------------------------------------===//
9751021
// BrOp
9761022
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
2323
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
2424
#include "clang/CIR/MissingFeatures.h"
25+
#include <numeric>
2526

2627
using namespace mlir;
2728
using namespace cir;
@@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
962963
});
963964
}
964965

966+
//===----------------------------------------------------------------------===//
967+
// SwitchFlatOp
968+
//===----------------------------------------------------------------------===//
969+
970+
void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
971+
Value value, Block *defaultDestination,
972+
ValueRange defaultOperands,
973+
ArrayRef<APInt> caseValues,
974+
BlockRange caseDestinations,
975+
ArrayRef<ValueRange> caseOperands) {
976+
977+
std::vector<mlir::Attribute> caseValuesAttrs;
978+
for (auto &val : caseValues) {
979+
caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
980+
}
981+
mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
982+
983+
build(builder, result, value, defaultOperands, caseOperands, attrs,
984+
defaultDestination, caseDestinations);
985+
}
986+
987+
/// <cases> ::= `[` (case (`,` case )* )? `]`
988+
/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
989+
static ParseResult parseSwitchFlatOpCases(
990+
OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
991+
SmallVectorImpl<Block *> &caseDestinations,
992+
SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
993+
&caseOperands,
994+
SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
995+
if (failed(parser.parseLSquare()))
996+
return failure();
997+
if (succeeded(parser.parseOptionalRSquare()))
998+
return success();
999+
llvm::SmallVector<mlir::Attribute> values;
1000+
1001+
auto parseCase = [&]() {
1002+
int64_t value = 0;
1003+
if (failed(parser.parseInteger(value)))
1004+
return failure();
1005+
1006+
values.push_back(cir::IntAttr::get(flagType, value));
1007+
1008+
Block *destination;
1009+
llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
1010+
llvm::SmallVector<Type> operandTypes;
1011+
if (parser.parseColon() || parser.parseSuccessor(destination))
1012+
return failure();
1013+
if (!parser.parseOptionalLParen()) {
1014+
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1015+
/*allowResultNumber=*/false) ||
1016+
parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1017+
return failure();
1018+
}
1019+
caseDestinations.push_back(destination);
1020+
caseOperands.emplace_back(operands);
1021+
caseOperandTypes.emplace_back(operandTypes);
1022+
return success();
1023+
};
1024+
if (failed(parser.parseCommaSeparatedList(parseCase)))
1025+
return failure();
1026+
1027+
caseValues = ArrayAttr::get(flagType.getContext(), values);
1028+
1029+
return parser.parseRSquare();
1030+
}
1031+
1032+
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1033+
Type flagType, mlir::ArrayAttr caseValues,
1034+
SuccessorRange caseDestinations,
1035+
OperandRangeRange caseOperands,
1036+
const TypeRangeRange &caseOperandTypes) {
1037+
p << '[';
1038+
p.printNewline();
1039+
if (!caseValues) {
1040+
p << ']';
1041+
return;
1042+
}
1043+
1044+
size_t index = 0;
1045+
llvm::interleave(
1046+
llvm::zip(caseValues, caseDestinations),
1047+
[&](auto i) {
1048+
p << " ";
1049+
mlir::Attribute a = std::get<0>(i);
1050+
p << mlir::cast<cir::IntAttr>(a).getValue();
1051+
p << ": ";
1052+
p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1053+
},
1054+
[&] {
1055+
p << ',';
1056+
p.printNewline();
1057+
});
1058+
p.printNewline();
1059+
p << ']';
1060+
}
1061+
9651062
//===----------------------------------------------------------------------===//
9661063
// GlobalOp
9671064
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
8484
}
8585
};
8686

87+
struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
88+
using OpRewritePattern<SwitchOp>::OpRewritePattern;
89+
90+
LogicalResult matchAndRewrite(SwitchOp op,
91+
PatternRewriter &rewriter) const final {
92+
if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
93+
return failure();
94+
95+
rewriter.eraseOp(op);
96+
return success();
97+
}
98+
};
99+
87100
//===----------------------------------------------------------------------===//
88101
// CIRCanonicalizePass
89102
//===----------------------------------------------------------------------===//
@@ -127,8 +140,7 @@ void CIRCanonicalizePass::runOnOperation() {
127140
assert(!cir::MissingFeatures::callOp());
128141
// CastOp, UnaryOp and VecExtractOp are here to perform a manual `fold` in
129142
// applyOpPatternsGreedily.
130-
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp, VecExtractOp>(
131-
op))
143+
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, VecExtractOp>(op))
132144
ops.push_back(op);
133145
});
134146

0 commit comments

Comments
 (0)