Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,44 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
bool isShiftLeft) {
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);
}

mlir::Value createShift(mlir::Location loc, mlir::Value lhs,
const llvm::APInt &rhs, bool isShiftLeft) {
return createShift(loc, lhs, getConstAPInt(loc, lhs.getType(), rhs),
isShiftLeft);
}

mlir::Value createShift(mlir::Location loc, mlir::Value lhs, unsigned bits,
bool isShiftLeft) {
auto width = mlir::dyn_cast<cir::IntType>(lhs.getType()).getWidth();
auto shift = llvm::APInt(width, bits);
return createShift(loc, lhs, shift, isShiftLeft);
}

mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
unsigned bits) {
return createShift(loc, lhs, bits, true);
}

mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
unsigned bits) {
return createShift(loc, lhs, bits, false);
}

mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
mlir::Value rhs) {
return createShift(loc, lhs, rhs, true);
}

mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
mlir::Value rhs) {
return createShift(loc, lhs, rhs, false);
}

//
// Block handling helpers
// ----------------------
Expand Down
73 changes: 73 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,79 @@ def BinOp : CIR_Op<"binop", [Pure,
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//

def ShiftOp : CIR_Op<"shift", [Pure]> {
let summary = "Shift";
let description = [{
The `cir.shift` operation performs a bitwise shift, either to the left or to
the right, based on the first operand. The second operand specifies the
value to be shifted, and the third operand determines the number of
positions by which the shift is applied. Both the second and third operands
are required to be integers.

```mlir
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
```
}];

// TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
// update the description above.
let results = (outs CIR_IntType:$result);
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
UnitAttr:$isShiftleft);

let assemblyFormat = [{
`(`
(`left` $isShiftleft^) : (```right`)?
`,` $value `:` type($value)
`,` $amount `:` type($amount)
`)` `->` type($result) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

def SelectOp : CIR_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>]> {
let summary = "Yield one of two values based on a boolean value";
let description = [{
The `cir.select` operation takes three operands. The first operand
`condition` is a boolean value of type `!cir.bool`. The second and the third
operand can be of any CIR types, but their types must be the same. If the
first operand is `true`, the operation yields its second operand. Otherwise,
the operation yields its third operand.

Example:

```mlir
%0 = cir.const #cir.bool<true> : !cir.bool
%1 = cir.const #cir.int<42> : !s32i
%2 = cir.const #cir.int<72> : !s32i
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
```
}];

let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
CIR_AnyType:$false_value);
let results = (outs CIR_AnyType:$result);

let assemblyFormat = [{
`if` $condition `then` $true_value `else` $false_value
`:` `(`
qualified(type($condition)) `,`
qualified(type($true_value)) `,`
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 2 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,8 +1308,7 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &ops) {
mlir::isa<cir::IntType>(ops.lhs.getType()))
cgf.cgm.errorNYI("sanitizers");

cgf.cgm.errorNYI("shift ops");
return {};
return builder.createShiftLeft(cgf.getLoc(ops.loc), ops.lhs, ops.rhs);
}

mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
Expand All @@ -1333,8 +1332,7 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {

// Note that we don't need to distinguish unsigned treatment at this
// point since it will be handled later by LLVM lowering.
cgf.cgm.errorNYI("shift ops");
return {};
return builder.createShiftRight(cgf.getLoc(ops.loc), ops.lhs, ops.rhs);
}

mlir::Value ScalarExprEmitter::emitAnd(const BinOpInfo &ops) {
Expand Down
21 changes: 21 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,9 @@ mlir::OpTrait::impl::verifySameFirstOperandAndResultType(Operation *op) {
// been implemented yet.
mlir::LogicalResult cir::FuncOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// BinOp
//===----------------------------------------------------------------------===//
LogicalResult cir::BinOp::verify() {
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
bool saturated = getSaturated();
Expand Down Expand Up @@ -1028,6 +1031,24 @@ LogicalResult cir::BinOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
LogicalResult cir::ShiftOp::verify() {
mlir::Operation *op = getOperation();
mlir::Type resType = getResult().getType();
assert(!cir::MissingFeatures::vectorType());
bool isOp0Vec = false;
bool isOp1Vec = false;
if (isOp0Vec != isOp1Vec)
return emitOpError() << "input types cannot be one vector and one scalar";
if (isOp1Vec && op->getOperand(1).getType() != resType) {
return emitOpError() << "shift amount must have the type of the result "
<< "if it is vector shift";
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// UnaryOp
//===----------------------------------------------------------------------===//
Expand Down
88 changes: 88 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Types.h"
Expand All @@ -28,6 +29,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/LoweringHelpers.h"
Expand Down Expand Up @@ -1294,6 +1296,90 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
cir::ShiftOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());

// Operands could also be vector type
assert(!cir::MissingFeatures::vectorType());
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
mlir::Value amt = adaptor.getAmount();
mlir::Value val = adaptor.getValue();

// TODO(cir): Assert for vector types
assert((cirValTy && cirAmtTy) &&
"shift input type must be integer or vector type, otherwise NYI");

assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");

// Ensure shift amount is the same type as the value. Some undefined
// behavior might occur in the casts below as per [C99 6.5.7.3].
// Vector type shift amount needs no cast as type consistency is expected to
// be already be enforced at CIRGen.
if (cirAmtTy)
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
true, cirAmtTy.getWidth(), cirValTy.getWidth());

// Lower to the proper LLVM shift operation.
if (op.getIsShiftleft()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
} else {
assert(!cir::MissingFeatures::vectorType());
bool isUnsigned = !cirValTy.isSigned();
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
}

return mlir::success();
}

mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
cir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
auto definingOp =
mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
if (!definingOp)
return {};

auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
if (!constValue)
return {};

return constValue;
};

// Two special cases in the LLVMIR codegen of select op:
// - select %0, %1, false => and %0, %1
// - select %0, true, %1 => or %0, %1
if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
if (falseValue && !falseValue.getValue()) {
// select %0, %1, false => and %0, %1
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
adaptor.getTrueValue());
return mlir::success();
}
if (trueValue && trueValue.getValue()) {
// select %0, true, %1 => or %0, %1
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
adaptor.getFalseValue());
return mlir::success();
}
}

mlir::Value llvmCondition = adaptor.getCondition();
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());

return mlir::success();
}

static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
mlir::DataLayout &dataLayout) {
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
Expand Down Expand Up @@ -1439,6 +1525,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMConstantOpLowering,
CIRToLLVMFuncOpLowering,
CIRToLLVMGetGlobalOpLowering,
CIRToLLVMSelectOpLowering,
CIRToLLVMShiftOpLowering,
CIRToLLVMTrapOpLowering,
CIRToLLVMUnaryOpLowering
// clang-format on
Expand Down
20 changes: 20 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,26 @@ class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMShiftOpLowering
: public mlir::OpConversionPattern<cir::ShiftOp> {
public:
using mlir::OpConversionPattern<cir::ShiftOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::ShiftOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMSelectOpLowering
: public mlir::OpConversionPattern<cir::SelectOp> {
public:
using mlir::OpConversionPattern<cir::SelectOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::SelectOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
public:
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;
Expand Down
Loading
Loading