Skip to content

Commit 8199367

Browse files
committed
Address review feedback
- add createShift functions to CIRBuilder - Rephrase ShiftOpn description comment - Remove folding for now - Always zero-extend instead of potentially sign-extend - Tests with diffferently sized integral types
1 parent 1383377 commit 8199367

File tree

6 files changed

+276
-42
lines changed

6 files changed

+276
-42
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,44 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
353353
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
354354
}
355355

356+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
357+
bool isShiftLeft) {
358+
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);
359+
}
360+
361+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs,
362+
const llvm::APInt &rhs, bool isShiftLeft) {
363+
return createShift(loc, lhs, getConstAPInt(loc, lhs.getType(), rhs),
364+
isShiftLeft);
365+
}
366+
367+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, unsigned bits,
368+
bool isShiftLeft) {
369+
auto width = mlir::dyn_cast<cir::IntType>(lhs.getType()).getWidth();
370+
auto shift = llvm::APInt(width, bits);
371+
return createShift(loc, lhs, shift, isShiftLeft);
372+
}
373+
374+
mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
375+
unsigned bits) {
376+
return createShift(loc, lhs, bits, true);
377+
}
378+
379+
mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
380+
unsigned bits) {
381+
return createShift(loc, lhs, bits, false);
382+
}
383+
384+
mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
385+
mlir::Value rhs) {
386+
return createShift(loc, lhs, rhs, true);
387+
}
388+
389+
mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
390+
mlir::Value rhs) {
391+
return createShift(loc, lhs, rhs, false);
392+
}
393+
356394
//
357395
// Block handling helpers
358396
// ----------------------

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,9 +1183,11 @@ def BinOp : CIR_Op<"binop", [Pure,
11831183
def ShiftOp : CIR_Op<"shift", [Pure]> {
11841184
let summary = "Shift";
11851185
let description = [{
1186-
Shift `left` or `right`, according to the first operand. Second operand is
1187-
the shift target and the third the amount. Second and the thrid operand are
1188-
integers.
1186+
The `cir.shift` operation performs a bitwise shift, either to the left or to
1187+
the right, based on the first operand. The second operand specifies the
1188+
value to be shifted, and the third operand determines the number of
1189+
positions by which the shift is applied. Both the second and third operands
1190+
are required to be integers.
11891191

11901192
```mlir
11911193
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
@@ -1245,8 +1247,6 @@ def SelectOp : CIR_Op<"select", [Pure,
12451247
qualified(type($false_value))
12461248
`)` `->` qualified(type($result)) attr-dict
12471249
}];
1248-
1249-
let hasFolder = 1;
12501250
}
12511251

12521252
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,9 +1308,7 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &ops) {
13081308
mlir::isa<cir::IntType>(ops.lhs.getType()))
13091309
cgf.cgm.errorNYI("sanitizers");
13101310

1311-
return builder.create<cir::ShiftOp>(cgf.getLoc(ops.loc),
1312-
cgf.convertType(ops.fullType), ops.lhs,
1313-
ops.rhs, cgf.getBuilder().getUnitAttr());
1311+
return builder.createShiftLeft(cgf.getLoc(ops.loc), ops.lhs, ops.rhs);
13141312
}
13151313

13161314
mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
@@ -1334,8 +1332,7 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
13341332

13351333
// Note that we don't need to distinguish unsigned treatment at this
13361334
// point since it will be handled later by LLVM lowering.
1337-
return builder.create<cir::ShiftOp>(
1338-
cgf.getLoc(ops.loc), cgf.convertType(ops.fullType), ops.lhs, ops.rhs);
1335+
return builder.createShiftRight(cgf.getLoc(ops.loc), ops.lhs, ops.rhs);
13391336
}
13401337

13411338
mlir::Value ScalarExprEmitter::emitAnd(const BinOpInfo &ops) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,28 +1049,6 @@ LogicalResult cir::ShiftOp::verify() {
10491049
return mlir::success();
10501050
}
10511051

1052-
//===----------------------------------------------------------------------===//
1053-
// SelectOp
1054-
//===----------------------------------------------------------------------===//
1055-
1056-
OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1057-
mlir::Attribute condition = adaptor.getCondition();
1058-
if (condition) {
1059-
bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1060-
return conditionValue ? getTrueValue() : getFalseValue();
1061-
}
1062-
1063-
// cir.select if %0 then x else x -> x
1064-
mlir::Attribute trueValue = adaptor.getTrueValue();
1065-
mlir::Attribute falseValue = adaptor.getFalseValue();
1066-
if (trueValue && trueValue == falseValue)
1067-
return trueValue;
1068-
if (getTrueValue() == getFalseValue())
1069-
return getTrueValue();
1070-
1071-
return {};
1072-
}
1073-
10741052
//===----------------------------------------------------------------------===//
10751053
// UnaryOp
10761054
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/DLTI/DLTI.h"
2020
#include "mlir/Dialect/Func/IR/FuncOps.h"
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/IR/BuiltinAttributes.h"
2223
#include "mlir/IR/BuiltinDialect.h"
2324
#include "mlir/IR/BuiltinOps.h"
2425
#include "mlir/IR/Types.h"
@@ -28,6 +29,7 @@
2829
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2930
#include "mlir/Target/LLVMIR/Export.h"
3031
#include "mlir/Transforms/DialectConversion.h"
32+
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
3133
#include "clang/CIR/Dialect/IR/CIRDialect.h"
3234
#include "clang/CIR/Dialect/Passes.h"
3335
#include "clang/CIR/LoweringHelpers.h"
@@ -1318,8 +1320,7 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
13181320
// be already be enforced at CIRGen.
13191321
if (cirAmtTy)
13201322
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
1321-
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
1322-
cirValTy.getWidth());
1323+
true, cirAmtTy.getWidth(), cirValTy.getWidth());
13231324

13241325
// Lower to the proper LLVM shift operation.
13251326
if (op.getIsShiftleft()) {
@@ -1339,32 +1340,32 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
13391340
mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
13401341
cir::SelectOp op, OpAdaptor adaptor,
13411342
mlir::ConversionPatternRewriter &rewriter) const {
1342-
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
1343+
auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
13431344
auto definingOp =
13441345
mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
13451346
if (!definingOp)
1346-
return std::nullopt;
1347+
return {};
13471348

13481349
auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
13491350
if (!constValue)
1350-
return std::nullopt;
1351+
return {};
13511352

1352-
return constValue.getValue();
1353+
return constValue;
13531354
};
13541355

13551356
// Two special cases in the LLVMIR codegen of select op:
13561357
// - select %0, %1, false => and %0, %1
13571358
// - select %0, true, %1 => or %0, %1
13581359
if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
1359-
std::optional<bool> trueValue = getConstantBool(op.getTrueValue());
1360-
std::optional<bool> falseValue = getConstantBool(op.getFalseValue());
1361-
if (falseValue.has_value() && !*falseValue) {
1360+
cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
1361+
cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
1362+
if (falseValue && !falseValue.getValue()) {
13621363
// select %0, %1, false => and %0, %1
13631364
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
13641365
adaptor.getTrueValue());
13651366
return mlir::success();
13661367
}
1367-
if (trueValue.has_value() && *trueValue) {
1368+
if (trueValue && trueValue.getValue()) {
13681369
// select %0, true, %1 => or %0, %1
13691370
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
13701371
adaptor.getFalseValue());

0 commit comments

Comments
 (0)