Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@ class AttrConvertFastMathToLLVM {
convertArithFastMathAttrToLLVM(arithFMFAttr));
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
Attribute getPropAttr() const { return {}; }

private:
NamedAttrList convertedAttr;
Expand All @@ -82,23 +79,37 @@ template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
AttrConvertOverflowToLLVM(SourceOp srcOp) {
using IntegerOverflowFlagsAttr = LLVM::IntegerOverflowFlagsAttr;

// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
// Remove the source overflow attribute from the set that will be present
// in the target.
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName))) {
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
auto llvmFlag = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
// Create a dictionary attribute holding the overflow flags property.
// (In the LLVM dialect, the overflow flags are a property, not an
// attribute.)
MLIRContext *ctx = srcOp.getOperation()->getContext();
Builder b(ctx);
auto llvmFlagAttr = IntegerOverflowFlagsAttr::get(ctx, llvmFlag);
StringRef llvmAttrName = TargetOp::getOverflowFlagsAttrName();
SmallVector<NamedAttribute> attrs;
attrs.push_back(b.getNamedAttr(llvmAttrName, llvmFlagAttr));
// Set the properties attribute of the operation state so that the
// property can be updated when the operation is created.
propertiesAttr = b.getDictionaryAttr(attrs);
Comment on lines +100 to +104
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need a vector just to create a single-element array ref:

Suggested change
SmallVector<NamedAttribute> attrs;
attrs.push_back(b.getNamedAttr(llvmAttrName, llvmFlagAttr));
// Set the properties attribute of the operation state so that the
// property can be updated when the operation is created.
propertiesAttr = b.getDictionaryAttr(attrs);
// Set the properties attribute of the operation state so that the
// property can be updated when the operation is created.
NamedAttribute attr{llvmAttrName, llvmFlagAttr};
propertiesAttr = b.getDictionaryAttr(ArrayRef(attr));

}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
Attribute getPropAttr() const { return propertiesAttr; }

private:
NamedAttrList convertedAttr;
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
DictionaryAttr propertiesAttr;
};

template <typename SourceOp, typename TargetOp>
Expand Down Expand Up @@ -129,9 +140,7 @@ class AttrConverterConstrainedFPToLLVM {
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
Attribute getPropAttr() const { return {}; }

private:
NamedAttrList convertedAttr;
Expand Down
20 changes: 9 additions & 11 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ class CallOpInterface;

namespace LLVM {
namespace detail {
/// Handle generically setting flags as native properties on LLVM operations.
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);

/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
/// specified name "intrinsic" and operands.
Expand Down Expand Up @@ -307,9 +305,9 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
adaptor.getOperands(), op->getAttrs(),
*this->getTypeConverter(), rewriter);
return LLVM::detail::oneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(),
/*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter);
}
};

Expand Down
34 changes: 19 additions & 15 deletions mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,36 +54,40 @@ LogicalResult handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter);

LogicalResult vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM

// Default attribute conversion class, which passes all source attributes
// through to the target op, unmodified.
// through to the target op, unmodified. The attribute to set properties of the
// target operation will be nullptr (i.e. any properties that exist in will have
// default values).
template <typename SourceOp, typename TargetOp>
class AttrConvertPassThrough {
public:
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}

ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
Attribute getPropAttr() const { return {}; }

private:
ArrayRef<NamedAttribute> srcAttrs;
};

/// Basic lowering implementation to rewrite Ops with just one result to the
/// LLVM Dialect. This supports higher-dimensional vector types.
/// The AttrConvert template template parameter should be a template class
/// with SourceOp and TargetOp type parameters, a constructor that takes
/// a SourceOp instance, and a getAttrs() method that returns
/// ArrayRef<NamedAttribute>.
/// The AttrConvert template template parameter should:
// - be a template class with SourceOp and TargetOp type parameters
// - have a constructor that takes a SourceOp instance
// - a getAttrs() method that returns ArrayRef<NamedAttribute> containing
// attributes that the target operation will have
// - a getPropAttr() method that returns either a NULL attribute or a
// DictionaryAttribute with properties that exist for the target operation
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
Expand Down Expand Up @@ -134,8 +138,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {

return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(),
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
attrConvert.getOverflowFlags());
attrConvert.getAttrs(), attrConvert.getPropAttr(),
*this->getTypeConverter(), rewriter);
}
};
} // namespace mlir
Expand Down
54 changes: 44 additions & 10 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
attr-dict `:` type($result) }];
}

class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
SignlessIntegerOrIndexLike:$rhs,
UnitAttr:$isExact)>,
Results<(outs SignlessIntegerOrIndexLike:$result)> {

let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
attr-dict `:` type($result) }];
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -482,7 +494,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
// DivUIOp
//===----------------------------------------------------------------------===//

def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui",
[ConditionallySpeculatable]> {
let summary = "unsigned integer division operation";
let description = [{
Unsigned integer division. Rounds towards zero. Treats the leading bit as
Expand All @@ -493,12 +506,18 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
`tensor` values, the behavior is undefined if _any_ elements are divided by
zero.

If the `exact` attribute is present, the result value is poison if `lhs` is
not a multiple of `rhs`.

Example:

```mlir
// Scalar unsigned integer division.
%a = arith.divui %b, %c : i64

// Scalar unsigned integer division where %b is known to be a multiple of %c.
%a = arith.divui %b, %c exact : i64

// SIMD vector element-wise division.
%f = arith.divui %g, %h : vector<4xi32>

Expand All @@ -519,7 +538,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
// DivSIOp
//===----------------------------------------------------------------------===//

def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
def Arith_DivSIOp : Arith_IntBinaryOpWithExactFlag<"divsi",
[ConditionallySpeculatable]> {
let summary = "signed integer division operation";
let description = [{
Signed integer division. Rounds towards zero. Treats the leading bit as
Expand All @@ -530,12 +550,18 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
behavior is undefined if _any_ of its elements are divided by zero or has a
signed division overflow.

If the `exact` attribute is present, the result value is poison if `lhs` is
not a multiple of `rhs`.

Example:

```mlir
// Scalar signed integer division.
%a = arith.divsi %b, %c : i64

// Scalar signed integer division where %b is known to be a multiple of %c.
%a = arith.divsi %b, %c exact : i64

// SIMD vector element-wise division.
%f = arith.divsi %g, %h : vector<4xi32>

Expand Down Expand Up @@ -821,7 +847,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
// ShRUIOp
//===----------------------------------------------------------------------===//

def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
def Arith_ShRUIOp : Arith_IntBinaryOpWithExactFlag<"shrui", [Pure]> {
let summary = "unsigned integer right-shift";
let description = [{
The `shrui` operation shifts an integer value of the first operand to the right
Expand All @@ -830,12 +856,17 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
filled with zeros. If the value of the second operand is greater or equal than the
bitwidth of the first operand, then the operation returns poison.

If the `exact` keyword is present, the result value of shrui is a poison
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If the `exact` keyword is present, the result value of shrui is a poison
If the `exact` attribute is present, the result value of shrui is a poison

value if any of the bits shifted out are non-zero.

Example:

```mlir
%1 = arith.constant 160 : i8 // %1 is 0b10100000
%1 = arith.constant 160 : i8 // %1 is 0b10100000
%2 = arith.constant 3 : i8
%3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
%3 = arith.constant 6 : i8
%4 = arith.shrui %1, %2 exact : i8 // %4 is 0b00010100
%5 = arith.shrui %1, %3 : i8 // %3 is 0b00000010
```
}];
let hasFolder = 1;
Expand All @@ -845,7 +876,7 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
// ShRSIOp
//===----------------------------------------------------------------------===//

def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
def Arith_ShRSIOp : Arith_IntBinaryOpWithExactFlag<"shrsi", [Pure]> {
let summary = "signed integer right-shift";
let description = [{
The `shrsi` operation shifts an integer value of the first operand to the right
Expand All @@ -856,14 +887,17 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
operand is greater or equal than bitwidth of the first operand, then the operation
returns poison.

If the `exact` keyword is present, the result value of shrsi is a poison
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

value if any of the bits shifted out are non-zero.

Example:

```mlir
%1 = arith.constant 160 : i8 // %1 is 0b10100000
%1 = arith.constant 160 : i8 // %1 is 0b10100000
%2 = arith.constant 3 : i8
%3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
%4 = arith.constant 96 : i8 // %4 is 0b01100000
%5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
%3 = arith.shrsi %1, %2 exact : i8 // %3 is 0b11110100
%4 = arith.constant 98 : i8 // %4 is 0b01100010
%5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100
```
}];
let hasFolder = 1;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
adaptor.getOperands(), op->getAttrs(),
/*propAttr=*/Attribute{},
*getTypeConverter(), rewriter);
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
op->getAttrs(), *getTypeConverter(), rewriter);
op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
rewriter);
}
};

Expand Down
21 changes: 7 additions & 14 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods
//===----------------------------------------------------------------------===//

void LLVM::detail::setNativeProperties(Operation *op,
IntegerOverflowFlags overflowFlags) {
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
iface.setOverflowFlags(overflowFlags);
}

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags) {
ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();

SmallVector<Type> resultTypes;
Expand All @@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite(
}

// Create the operation through state since we don't know its C++ type.
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
resultTypes, targetAttrs);

setNativeProperties(newOp, overflowFlags);
OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
resultTypes, targetAttrs);
state.propertiesAttr = propertiesAttr;
Operation *newOp = rewriter.create(state);

// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
Expand Down
21 changes: 10 additions & 11 deletions mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(

LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags) {
ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
assert(!operands.empty());

// Cannot convert ops if their operands are not of LLVM type.
Expand All @@ -116,15 +116,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(

auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
rewriter, overflowFlags);

auto callback = [op, targetOp, targetAttrs, overflowFlags,
return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr,
typeConverter, rewriter);
auto callback = [op, targetOp, targetAttrs, propertiesAttr,
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
operands, llvm1DVectorTy, targetAttrs);
LLVM::detail::setNativeProperties(newOp, overflowFlags);
OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp),
operands, llvm1DVectorTy, targetAttrs);
state.propertiesAttr = propertiesAttr;
Operation *newOp = rewriter.create(state);
return newOp->getResult(0);
};

Expand Down
Loading