Skip to content

Commit fc54913

Browse files
committed
[CIR] Upstream SelectOp and ShiftOp
Since SelectOp will only generated by a future pass that transforms a TernaryOp this only includes the lowering bits. This patchs also improves the testing of the existing binary operators.
1 parent 25f4f0a commit fc54913

File tree

7 files changed

+579
-18
lines changed

7 files changed

+579
-18
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,79 @@ def BinOp : CIR_Op<"binop", [Pure,
889889
let hasVerifier = 1;
890890
}
891891

892+
//===----------------------------------------------------------------------===//
893+
// ShiftOp
894+
//===----------------------------------------------------------------------===//
895+
896+
def ShiftOp : CIR_Op<"shift", [Pure]> {
897+
let summary = "Shift";
898+
let description = [{
899+
Shift `left` or `right`, according to the first operand. Second operand is
900+
the shift target and the third the amount. Second and the thrid operand are
901+
integers.
902+
903+
```mlir
904+
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
905+
```
906+
}];
907+
908+
// TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
909+
// update the description above.
910+
let results = (outs CIR_IntType:$result);
911+
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
912+
UnitAttr:$isShiftleft);
913+
914+
let assemblyFormat = [{
915+
`(`
916+
(`left` $isShiftleft^) : (```right`)?
917+
`,` $value `:` type($value)
918+
`,` $amount `:` type($amount)
919+
`)` `->` type($result) attr-dict
920+
}];
921+
922+
let hasVerifier = 1;
923+
}
924+
925+
//===----------------------------------------------------------------------===//
926+
// SelectOp
927+
//===----------------------------------------------------------------------===//
928+
929+
def SelectOp : CIR_Op<"select", [Pure,
930+
AllTypesMatch<["true_value", "false_value", "result"]>]> {
931+
let summary = "Yield one of two values based on a boolean value";
932+
let description = [{
933+
The `cir.select` operation takes three operands. The first operand
934+
`condition` is a boolean value of type `!cir.bool`. The second and the third
935+
operand can be of any CIR types, but their types must be the same. If the
936+
first operand is `true`, the operation yields its second operand. Otherwise,
937+
the operation yields its third operand.
938+
939+
Example:
940+
941+
```mlir
942+
%0 = cir.const #cir.bool<true> : !cir.bool
943+
%1 = cir.const #cir.int<42> : !s32i
944+
%2 = cir.const #cir.int<72> : !s32i
945+
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
946+
```
947+
}];
948+
949+
let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
950+
CIR_AnyType:$false_value);
951+
let results = (outs CIR_AnyType:$result);
952+
953+
let assemblyFormat = [{
954+
`if` $condition `then` $true_value `else` $false_value
955+
`:` `(`
956+
qualified(type($condition)) `,`
957+
qualified(type($true_value)) `,`
958+
qualified(type($false_value))
959+
`)` `->` qualified(type($result)) attr-dict
960+
}];
961+
962+
let hasFolder = 1;
963+
}
964+
892965
//===----------------------------------------------------------------------===//
893966
// GlobalOp
894967
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,8 +1138,9 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &ops) {
11381138
mlir::isa<cir::IntType>(ops.lhs.getType()))
11391139
cgf.cgm.errorNYI("sanitizers");
11401140

1141-
cgf.cgm.errorNYI("shift ops");
1142-
return {};
1141+
return builder.create<cir::ShiftOp>(cgf.getLoc(ops.loc),
1142+
cgf.convertType(ops.fullType), ops.lhs,
1143+
ops.rhs, cgf.getBuilder().getUnitAttr());
11431144
}
11441145

11451146
mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
@@ -1163,8 +1164,8 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
11631164

11641165
// Note that we don't need to distinguish unsigned treatment at this
11651166
// point since it will be handled later by LLVM lowering.
1166-
cgf.cgm.errorNYI("shift ops");
1167-
return {};
1167+
return builder.create<cir::ShiftOp>(
1168+
cgf.getLoc(ops.loc), cgf.convertType(ops.fullType), ops.lhs, ops.rhs);
11681169
}
11691170

11701171
mlir::Value ScalarExprEmitter::emitAnd(const BinOpInfo &ops) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,9 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
728728
// been implemented yet.
729729
mlir::LogicalResult cir::FuncOp::verify() { return success(); }
730730

731+
//===----------------------------------------------------------------------===//
732+
// BinOp
733+
//===----------------------------------------------------------------------===//
731734
LogicalResult cir::BinOp::verify() {
732735
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
733736
bool saturated = getSaturated();
@@ -759,6 +762,46 @@ LogicalResult cir::BinOp::verify() {
759762
return mlir::success();
760763
}
761764

765+
//===----------------------------------------------------------------------===//
766+
// ShiftOp
767+
//===----------------------------------------------------------------------===//
768+
LogicalResult cir::ShiftOp::verify() {
769+
mlir::Operation *op = getOperation();
770+
mlir::Type resType = getResult().getType();
771+
assert(!cir::MissingFeatures::vectorType());
772+
bool isOp0Vec = false;
773+
bool isOp1Vec = false;
774+
if (isOp0Vec != isOp1Vec)
775+
return emitOpError() << "input types cannot be one vector and one scalar";
776+
if (isOp1Vec && op->getOperand(1).getType() != resType) {
777+
return emitOpError() << "shift amount must have the type of the result "
778+
<< "if it is vector shift";
779+
}
780+
return mlir::success();
781+
}
782+
783+
//===----------------------------------------------------------------------===//
784+
// SelectOp
785+
//===----------------------------------------------------------------------===//
786+
787+
OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
788+
mlir::Attribute condition = adaptor.getCondition();
789+
if (condition) {
790+
bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
791+
return conditionValue ? getTrueValue() : getFalseValue();
792+
}
793+
794+
// cir.select if %0 then x else x -> x
795+
mlir::Attribute trueValue = adaptor.getTrueValue();
796+
mlir::Attribute falseValue = adaptor.getFalseValue();
797+
if (trueValue && trueValue == falseValue)
798+
return trueValue;
799+
if (getTrueValue() == getFalseValue())
800+
return getTrueValue();
801+
802+
return {};
803+
}
804+
762805
//===----------------------------------------------------------------------===//
763806
// UnaryOp
764807
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,91 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
11171117
return mlir::LogicalResult::success();
11181118
}
11191119

1120+
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
1121+
cir::ShiftOp op, OpAdaptor adaptor,
1122+
mlir::ConversionPatternRewriter &rewriter) const {
1123+
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1124+
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
1125+
1126+
// Operands could also be vector type
1127+
assert(!cir::MissingFeatures::vectorType());
1128+
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1129+
mlir::Value amt = adaptor.getAmount();
1130+
mlir::Value val = adaptor.getValue();
1131+
1132+
// TODO(cir): Assert for vector types
1133+
assert((cirValTy && cirAmtTy) &&
1134+
"shift input type must be integer or vector type, otherwise NYI");
1135+
1136+
assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
1137+
1138+
// Ensure shift amount is the same type as the value. Some undefined
1139+
// behavior might occur in the casts below as per [C99 6.5.7.3].
1140+
// Vector type shift amount needs no cast as type consistency is expected to
1141+
// be already be enforced at CIRGen.
1142+
if (cirAmtTy)
1143+
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
1144+
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
1145+
cirValTy.getWidth());
1146+
1147+
// Lower to the proper LLVM shift operation.
1148+
if (op.getIsShiftleft()) {
1149+
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
1150+
} else {
1151+
assert(!cir::MissingFeatures::vectorType());
1152+
bool isUnsigned = !cirValTy.isSigned();
1153+
if (isUnsigned)
1154+
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1155+
else
1156+
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
1157+
}
1158+
1159+
return mlir::success();
1160+
}
1161+
1162+
mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
1163+
cir::SelectOp op, OpAdaptor adaptor,
1164+
mlir::ConversionPatternRewriter &rewriter) const {
1165+
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
1166+
auto definingOp =
1167+
mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
1168+
if (!definingOp)
1169+
return std::nullopt;
1170+
1171+
auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
1172+
if (!constValue)
1173+
return std::nullopt;
1174+
1175+
return constValue.getValue();
1176+
};
1177+
1178+
// Two special cases in the LLVMIR codegen of select op:
1179+
// - select %0, %1, false => and %0, %1
1180+
// - select %0, true, %1 => or %0, %1
1181+
if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
1182+
std::optional<bool> trueValue = getConstantBool(op.getTrueValue());
1183+
std::optional<bool> falseValue = getConstantBool(op.getFalseValue());
1184+
if (falseValue.has_value() && !*falseValue) {
1185+
// select %0, %1, false => and %0, %1
1186+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
1187+
adaptor.getTrueValue());
1188+
return mlir::success();
1189+
}
1190+
if (trueValue.has_value() && *trueValue) {
1191+
// select %0, true, %1 => or %0, %1
1192+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
1193+
adaptor.getFalseValue());
1194+
return mlir::success();
1195+
}
1196+
}
1197+
1198+
mlir::Value llvmCondition = adaptor.getCondition();
1199+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1200+
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
1201+
1202+
return mlir::success();
1203+
}
1204+
11201205
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
11211206
mlir::DataLayout &dataLayout) {
11221207
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1259,6 +1344,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
12591344
CIRToLLVMBrCondOpLowering,
12601345
CIRToLLVMBrOpLowering,
12611346
CIRToLLVMFuncOpLowering,
1347+
CIRToLLVMSelectOpLowering,
1348+
CIRToLLVMShiftOpLowering,
12621349
CIRToLLVMTrapOpLowering,
12631350
CIRToLLVMUnaryOpLowering
12641351
// clang-format on

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,26 @@ class CIRToLLVMBinOpLowering : public mlir::OpConversionPattern<cir::BinOp> {
189189
mlir::ConversionPatternRewriter &) const override;
190190
};
191191

192+
class CIRToLLVMShiftOpLowering
193+
: public mlir::OpConversionPattern<cir::ShiftOp> {
194+
public:
195+
using mlir::OpConversionPattern<cir::ShiftOp>::OpConversionPattern;
196+
197+
mlir::LogicalResult
198+
matchAndRewrite(cir::ShiftOp op, OpAdaptor,
199+
mlir::ConversionPatternRewriter &) const override;
200+
};
201+
202+
class CIRToLLVMSelectOpLowering
203+
: public mlir::OpConversionPattern<cir::SelectOp> {
204+
public:
205+
using mlir::OpConversionPattern<cir::SelectOp>::OpConversionPattern;
206+
207+
mlir::LogicalResult
208+
matchAndRewrite(cir::SelectOp op, OpAdaptor,
209+
mlir::ConversionPatternRewriter &) const override;
210+
};
211+
192212
class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
193213
public:
194214
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;

0 commit comments

Comments
 (0)