Skip to content

Commit 1383377

Browse files
committed
[CIR] Upstream SelectOp and ShiftOp
Since SelectOp will only generated by a future pass that transforms a TernaryOp this only includes the lowering bits. This patchs also improves the testing of the existing binary operators.
1 parent 10a1502 commit 1383377

File tree

7 files changed

+579
-18
lines changed

7 files changed

+579
-18
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,79 @@ def BinOp : CIR_Op<"binop", [Pure,
11761176
let hasVerifier = 1;
11771177
}
11781178

1179+
//===----------------------------------------------------------------------===//
1180+
// ShiftOp
1181+
//===----------------------------------------------------------------------===//
1182+
1183+
def ShiftOp : CIR_Op<"shift", [Pure]> {
1184+
let summary = "Shift";
1185+
let description = [{
1186+
Shift `left` or `right`, according to the first operand. Second operand is
1187+
the shift target and the third the amount. Second and the thrid operand are
1188+
integers.
1189+
1190+
```mlir
1191+
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
1192+
```
1193+
}];
1194+
1195+
// TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
1196+
// update the description above.
1197+
let results = (outs CIR_IntType:$result);
1198+
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
1199+
UnitAttr:$isShiftleft);
1200+
1201+
let assemblyFormat = [{
1202+
`(`
1203+
(`left` $isShiftleft^) : (```right`)?
1204+
`,` $value `:` type($value)
1205+
`,` $amount `:` type($amount)
1206+
`)` `->` type($result) attr-dict
1207+
}];
1208+
1209+
let hasVerifier = 1;
1210+
}
1211+
1212+
//===----------------------------------------------------------------------===//
1213+
// SelectOp
1214+
//===----------------------------------------------------------------------===//
1215+
1216+
def SelectOp : CIR_Op<"select", [Pure,
1217+
AllTypesMatch<["true_value", "false_value", "result"]>]> {
1218+
let summary = "Yield one of two values based on a boolean value";
1219+
let description = [{
1220+
The `cir.select` operation takes three operands. The first operand
1221+
`condition` is a boolean value of type `!cir.bool`. The second and the third
1222+
operand can be of any CIR types, but their types must be the same. If the
1223+
first operand is `true`, the operation yields its second operand. Otherwise,
1224+
the operation yields its third operand.
1225+
1226+
Example:
1227+
1228+
```mlir
1229+
%0 = cir.const #cir.bool<true> : !cir.bool
1230+
%1 = cir.const #cir.int<42> : !s32i
1231+
%2 = cir.const #cir.int<72> : !s32i
1232+
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
1233+
```
1234+
}];
1235+
1236+
let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
1237+
CIR_AnyType:$false_value);
1238+
let results = (outs CIR_AnyType:$result);
1239+
1240+
let assemblyFormat = [{
1241+
`if` $condition `then` $true_value `else` $false_value
1242+
`:` `(`
1243+
qualified(type($condition)) `,`
1244+
qualified(type($true_value)) `,`
1245+
qualified(type($false_value))
1246+
`)` `->` qualified(type($result)) attr-dict
1247+
}];
1248+
1249+
let hasFolder = 1;
1250+
}
1251+
11791252
//===----------------------------------------------------------------------===//
11801253
// GlobalOp
11811254
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,8 +1308,9 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &ops) {
13081308
mlir::isa<cir::IntType>(ops.lhs.getType()))
13091309
cgf.cgm.errorNYI("sanitizers");
13101310

1311-
cgf.cgm.errorNYI("shift ops");
1312-
return {};
1311+
return builder.create<cir::ShiftOp>(cgf.getLoc(ops.loc),
1312+
cgf.convertType(ops.fullType), ops.lhs,
1313+
ops.rhs, cgf.getBuilder().getUnitAttr());
13131314
}
13141315

13151316
mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
@@ -1333,8 +1334,8 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
13331334

13341335
// Note that we don't need to distinguish unsigned treatment at this
13351336
// point since it will be handled later by LLVM lowering.
1336-
cgf.cgm.errorNYI("shift ops");
1337-
return {};
1337+
return builder.create<cir::ShiftOp>(
1338+
cgf.getLoc(ops.loc), cgf.convertType(ops.fullType), ops.lhs, ops.rhs);
13381339
}
13391340

13401341
mlir::Value ScalarExprEmitter::emitAnd(const BinOpInfo &ops) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,9 @@ mlir::OpTrait::impl::verifySameFirstOperandAndResultType(Operation *op) {
997997
// been implemented yet.
998998
mlir::LogicalResult cir::FuncOp::verify() { return success(); }
999999

1000+
//===----------------------------------------------------------------------===//
1001+
// BinOp
1002+
//===----------------------------------------------------------------------===//
10001003
LogicalResult cir::BinOp::verify() {
10011004
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
10021005
bool saturated = getSaturated();
@@ -1028,6 +1031,46 @@ LogicalResult cir::BinOp::verify() {
10281031
return mlir::success();
10291032
}
10301033

1034+
//===----------------------------------------------------------------------===//
1035+
// ShiftOp
1036+
//===----------------------------------------------------------------------===//
1037+
LogicalResult cir::ShiftOp::verify() {
1038+
mlir::Operation *op = getOperation();
1039+
mlir::Type resType = getResult().getType();
1040+
assert(!cir::MissingFeatures::vectorType());
1041+
bool isOp0Vec = false;
1042+
bool isOp1Vec = false;
1043+
if (isOp0Vec != isOp1Vec)
1044+
return emitOpError() << "input types cannot be one vector and one scalar";
1045+
if (isOp1Vec && op->getOperand(1).getType() != resType) {
1046+
return emitOpError() << "shift amount must have the type of the result "
1047+
<< "if it is vector shift";
1048+
}
1049+
return mlir::success();
1050+
}
1051+
1052+
//===----------------------------------------------------------------------===//
1053+
// SelectOp
1054+
//===----------------------------------------------------------------------===//
1055+
1056+
OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1057+
mlir::Attribute condition = adaptor.getCondition();
1058+
if (condition) {
1059+
bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1060+
return conditionValue ? getTrueValue() : getFalseValue();
1061+
}
1062+
1063+
// cir.select if %0 then x else x -> x
1064+
mlir::Attribute trueValue = adaptor.getTrueValue();
1065+
mlir::Attribute falseValue = adaptor.getFalseValue();
1066+
if (trueValue && trueValue == falseValue)
1067+
return trueValue;
1068+
if (getTrueValue() == getFalseValue())
1069+
return getTrueValue();
1070+
1071+
return {};
1072+
}
1073+
10311074
//===----------------------------------------------------------------------===//
10321075
// UnaryOp
10331076
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,91 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
12941294
return mlir::success();
12951295
}
12961296

1297+
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
1298+
cir::ShiftOp op, OpAdaptor adaptor,
1299+
mlir::ConversionPatternRewriter &rewriter) const {
1300+
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1301+
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
1302+
1303+
// Operands could also be vector type
1304+
assert(!cir::MissingFeatures::vectorType());
1305+
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1306+
mlir::Value amt = adaptor.getAmount();
1307+
mlir::Value val = adaptor.getValue();
1308+
1309+
// TODO(cir): Assert for vector types
1310+
assert((cirValTy && cirAmtTy) &&
1311+
"shift input type must be integer or vector type, otherwise NYI");
1312+
1313+
assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
1314+
1315+
// Ensure shift amount is the same type as the value. Some undefined
1316+
// behavior might occur in the casts below as per [C99 6.5.7.3].
1317+
// Vector type shift amount needs no cast as type consistency is expected to
1318+
// be already be enforced at CIRGen.
1319+
if (cirAmtTy)
1320+
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
1321+
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
1322+
cirValTy.getWidth());
1323+
1324+
// Lower to the proper LLVM shift operation.
1325+
if (op.getIsShiftleft()) {
1326+
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
1327+
} else {
1328+
assert(!cir::MissingFeatures::vectorType());
1329+
bool isUnsigned = !cirValTy.isSigned();
1330+
if (isUnsigned)
1331+
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1332+
else
1333+
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
1334+
}
1335+
1336+
return mlir::success();
1337+
}
1338+
1339+
mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
1340+
cir::SelectOp op, OpAdaptor adaptor,
1341+
mlir::ConversionPatternRewriter &rewriter) const {
1342+
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
1343+
auto definingOp =
1344+
mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
1345+
if (!definingOp)
1346+
return std::nullopt;
1347+
1348+
auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
1349+
if (!constValue)
1350+
return std::nullopt;
1351+
1352+
return constValue.getValue();
1353+
};
1354+
1355+
// Two special cases in the LLVMIR codegen of select op:
1356+
// - select %0, %1, false => and %0, %1
1357+
// - select %0, true, %1 => or %0, %1
1358+
if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
1359+
std::optional<bool> trueValue = getConstantBool(op.getTrueValue());
1360+
std::optional<bool> falseValue = getConstantBool(op.getFalseValue());
1361+
if (falseValue.has_value() && !*falseValue) {
1362+
// select %0, %1, false => and %0, %1
1363+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
1364+
adaptor.getTrueValue());
1365+
return mlir::success();
1366+
}
1367+
if (trueValue.has_value() && *trueValue) {
1368+
// select %0, true, %1 => or %0, %1
1369+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
1370+
adaptor.getFalseValue());
1371+
return mlir::success();
1372+
}
1373+
}
1374+
1375+
mlir::Value llvmCondition = adaptor.getCondition();
1376+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1377+
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
1378+
1379+
return mlir::success();
1380+
}
1381+
12971382
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
12981383
mlir::DataLayout &dataLayout) {
12991384
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1439,6 +1524,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
14391524
CIRToLLVMConstantOpLowering,
14401525
CIRToLLVMFuncOpLowering,
14411526
CIRToLLVMGetGlobalOpLowering,
1527+
CIRToLLVMSelectOpLowering,
1528+
CIRToLLVMShiftOpLowering,
14421529
CIRToLLVMTrapOpLowering,
14431530
CIRToLLVMUnaryOpLowering
14441531
// clang-format on

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,26 @@ class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
209209
mlir::ConversionPatternRewriter &) const override;
210210
};
211211

212+
class CIRToLLVMShiftOpLowering
213+
: public mlir::OpConversionPattern<cir::ShiftOp> {
214+
public:
215+
using mlir::OpConversionPattern<cir::ShiftOp>::OpConversionPattern;
216+
217+
mlir::LogicalResult
218+
matchAndRewrite(cir::ShiftOp op, OpAdaptor,
219+
mlir::ConversionPatternRewriter &) const override;
220+
};
221+
222+
class CIRToLLVMSelectOpLowering
223+
: public mlir::OpConversionPattern<cir::SelectOp> {
224+
public:
225+
using mlir::OpConversionPattern<cir::SelectOp>::OpConversionPattern;
226+
227+
mlir::LogicalResult
228+
matchAndRewrite(cir::SelectOp op, OpAdaptor,
229+
mlir::ConversionPatternRewriter &) const override;
230+
};
231+
212232
class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
213233
public:
214234
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;

0 commit comments

Comments
 (0)