Skip to content

Commit 4ddfbb5

Browse files
jfurtekgithub-actions[bot]
authored andcommitted
Automerge: Add 'exact' flag to arith.shrui/shrsi/divsi/divui operations (#165923)
This MR adds support for the `exact` flag to the `arith.shrui/shrsi/divsi/divui` operations. The semantics are identical to those of the LLVM dialect and the LLVM language reference. This MR also modifies the mechanism for converting `arith` dialect **attributes** to corresponding **properties** in the `LLVM` dialect. (As a specific example, the integer overflow flags `nsw/nuw` are **properties** in the `LLVM` dialect, as opposed to attributes.) Previously, attribute converter classes were required to have a specific method to support integer overflow flags: ```C++ template <typename SourceOp, typename TargetOp> class AttrConvertPassThrough { public: ... LLVM::IntegerOverflowFlags getOverflowFlags() const { return LLVM::IntegerOverflowFlags::none; } }; ``` This method was required, even for `arith` source operations that did not use integer overflow flags (e.g. `AttrConvertFastMathToLLVM`). This MR modifies the interface required by `arith` dialect attribute converters to instead provide a (possibly NULL) properties attribute: ```C++ template <typename SourceOp, typename TargetOp> class AttrConvertPassThrough { public: ... Attribute getPropAttr() const { return {}; } }; ``` For `arith` operations with attributes that map to `LLVM` dialect **properties**, the attribute converter can create a `DictionaryAttr` containing target properties and return that attribute from the attribute converter's `getPropAttr()` method. The `arith` attribute conversion framework will set the `propertiesAttr` of an `OperationState`, and the target operation's `setPropertiesFromAttr()` method will be invoked to set the properties when the target operation is created. The `AttrConvertOverflowToLLVM` class in this MR uses the new approach.
2 parents 013fc2b + a770d2b commit 4ddfbb5

File tree

12 files changed

+167
-76
lines changed

12 files changed

+167
-76
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,8 @@ class AttrConvertFastMathToLLVM {
6565
convertArithFastMathAttrToLLVM(arithFMFAttr));
6666
}
6767
}
68-
6968
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
70-
LLVM::IntegerOverflowFlags getOverflowFlags() const {
71-
return LLVM::IntegerOverflowFlags::none;
72-
}
69+
Attribute getPropAttr() const { return {}; }
7370

7471
private:
7572
NamedAttrList convertedAttr;
@@ -82,23 +79,36 @@ template <typename SourceOp, typename TargetOp>
8279
class AttrConvertOverflowToLLVM {
8380
public:
8481
AttrConvertOverflowToLLVM(SourceOp srcOp) {
82+
using IntegerOverflowFlagsAttr = LLVM::IntegerOverflowFlagsAttr;
83+
8584
// Copy the source attributes.
8685
convertedAttr = NamedAttrList{srcOp->getAttrs()};
8786
// Get the name of the arith overflow attribute.
8887
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
89-
// Remove the source overflow attribute.
88+
// Remove the source overflow attribute from the set that will be present
89+
// in the target.
9090
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
9191
convertedAttr.erase(arithAttrName))) {
92-
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
92+
auto llvmFlag = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
93+
// Create a dictionary attribute holding the overflow flags property.
94+
// (In the LLVM dialect, the overflow flags are a property, not an
95+
// attribute.)
96+
MLIRContext *ctx = srcOp.getOperation()->getContext();
97+
Builder b(ctx);
98+
auto llvmFlagAttr = IntegerOverflowFlagsAttr::get(ctx, llvmFlag);
99+
StringRef llvmAttrName = TargetOp::getOverflowFlagsAttrName();
100+
NamedAttribute attr{llvmAttrName, llvmFlagAttr};
101+
// Set the properties attribute of the operation state so that the
102+
// property can be updated when the operation is created.
103+
propertiesAttr = b.getDictionaryAttr(ArrayRef(attr));
93104
}
94105
}
95-
96106
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
97-
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
107+
Attribute getPropAttr() const { return propertiesAttr; }
98108

99109
private:
100110
NamedAttrList convertedAttr;
101-
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
111+
DictionaryAttr propertiesAttr;
102112
};
103113

104114
template <typename SourceOp, typename TargetOp>
@@ -129,9 +139,7 @@ class AttrConverterConstrainedFPToLLVM {
129139
}
130140

131141
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
132-
LLVM::IntegerOverflowFlags getOverflowFlags() const {
133-
return LLVM::IntegerOverflowFlags::none;
134-
}
142+
Attribute getPropAttr() const { return {}; }
135143

136144
private:
137145
NamedAttrList convertedAttr;

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,14 @@ class CallOpInterface;
1919

2020
namespace LLVM {
2121
namespace detail {
22-
/// Handle generically setting flags as native properties on LLVM operations.
23-
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24-
2522
/// Replaces the given operation "op" with a new operation of type "targetOp"
2623
/// and given operands.
27-
LogicalResult oneToOneRewrite(
28-
Operation *op, StringRef targetOp, ValueRange operands,
29-
ArrayRef<NamedAttribute> targetAttrs,
30-
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31-
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
24+
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
25+
ValueRange operands,
26+
ArrayRef<NamedAttribute> targetAttrs,
27+
Attribute propertiesAttr,
28+
const LLVMTypeConverter &typeConverter,
29+
ConversionPatternRewriter &rewriter);
3230

3331
/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
3432
/// specified name "intrinsic" and operands.
@@ -307,9 +305,9 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
307305
LogicalResult
308306
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
309307
ConversionPatternRewriter &rewriter) const override {
310-
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
311-
adaptor.getOperands(), op->getAttrs(),
312-
*this->getTypeConverter(), rewriter);
308+
return LLVM::detail::oneToOneRewrite(
309+
op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(),
310+
/*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter);
313311
}
314312
};
315313

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,36 +54,40 @@ LogicalResult handleMultidimensionalVectors(
5454
std::function<Value(Type, ValueRange)> createOperand,
5555
ConversionPatternRewriter &rewriter);
5656

57-
LogicalResult vectorOneToOneRewrite(
58-
Operation *op, StringRef targetOp, ValueRange operands,
59-
ArrayRef<NamedAttribute> targetAttrs,
60-
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
61-
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
57+
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
58+
ValueRange operands,
59+
ArrayRef<NamedAttribute> targetAttrs,
60+
Attribute propertiesAttr,
61+
const LLVMTypeConverter &typeConverter,
62+
ConversionPatternRewriter &rewriter);
6263
} // namespace detail
6364
} // namespace LLVM
6465

6566
// Default attribute conversion class, which passes all source attributes
66-
// through to the target op, unmodified.
67+
// through to the target op, unmodified. The attribute to set properties of the
68+
// target operation will be nullptr (i.e. any properties that exist in will have
69+
// default values).
6770
template <typename SourceOp, typename TargetOp>
6871
class AttrConvertPassThrough {
6972
public:
7073
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
7174

7275
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
73-
LLVM::IntegerOverflowFlags getOverflowFlags() const {
74-
return LLVM::IntegerOverflowFlags::none;
75-
}
76+
Attribute getPropAttr() const { return {}; }
7677

7778
private:
7879
ArrayRef<NamedAttribute> srcAttrs;
7980
};
8081

8182
/// Basic lowering implementation to rewrite Ops with just one result to the
8283
/// LLVM Dialect. This supports higher-dimensional vector types.
83-
/// The AttrConvert template template parameter should be a template class
84-
/// with SourceOp and TargetOp type parameters, a constructor that takes
85-
/// a SourceOp instance, and a getAttrs() method that returns
86-
/// ArrayRef<NamedAttribute>.
84+
/// The AttrConvert template template parameter should:
85+
// - be a template class with SourceOp and TargetOp type parameters
86+
// - have a constructor that takes a SourceOp instance
87+
// - a getAttrs() method that returns ArrayRef<NamedAttribute> containing
88+
// attributes that the target operation will have
89+
// - a getPropAttr() method that returns either a NULL attribute or a
90+
// DictionaryAttribute with properties that exist for the target operation
8791
template <typename SourceOp, typename TargetOp,
8892
template <typename, typename> typename AttrConvert =
8993
AttrConvertPassThrough,
@@ -137,8 +141,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
137141

138142
return LLVM::detail::vectorOneToOneRewrite(
139143
op, TargetOp::getOperationName(), adaptor.getOperands(),
140-
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
141-
attrConvert.getOverflowFlags());
144+
attrConvert.getAttrs(), attrConvert.getPropAttr(),
145+
*this->getTypeConverter(), rewriter);
142146
}
143147
};
144148
} // namespace mlir

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,18 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
158158
attr-dict `:` type($result) }];
159159
}
160160

161+
class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
162+
Arith_BinaryOp<mnemonic, traits #
163+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
164+
Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
165+
SignlessIntegerOrIndexLike:$rhs,
166+
UnitAttr:$isExact)>,
167+
Results<(outs SignlessIntegerOrIndexLike:$result)> {
168+
169+
let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
170+
attr-dict `:` type($result) }];
171+
}
172+
161173
//===----------------------------------------------------------------------===//
162174
// ConstantOp
163175
//===----------------------------------------------------------------------===//
@@ -482,7 +494,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
482494
// DivUIOp
483495
//===----------------------------------------------------------------------===//
484496

485-
def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
497+
def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui",
498+
[ConditionallySpeculatable]> {
486499
let summary = "unsigned integer division operation";
487500
let description = [{
488501
Unsigned integer division. Rounds towards zero. Treats the leading bit as
@@ -493,12 +506,18 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
493506
`tensor` values, the behavior is undefined if _any_ elements are divided by
494507
zero.
495508

509+
If the `exact` attribute is present, the result value is poison if `lhs` is
510+
not a multiple of `rhs`.
511+
496512
Example:
497513

498514
```mlir
499515
// Scalar unsigned integer division.
500516
%a = arith.divui %b, %c : i64
501517

518+
// Scalar unsigned integer division where %b is known to be a multiple of %c.
519+
%a = arith.divui %b, %c exact : i64
520+
502521
// SIMD vector element-wise division.
503522
%f = arith.divui %g, %h : vector<4xi32>
504523

@@ -519,7 +538,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
519538
// DivSIOp
520539
//===----------------------------------------------------------------------===//
521540

522-
def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
541+
def Arith_DivSIOp : Arith_IntBinaryOpWithExactFlag<"divsi",
542+
[ConditionallySpeculatable]> {
523543
let summary = "signed integer division operation";
524544
let description = [{
525545
Signed integer division. Rounds towards zero. Treats the leading bit as
@@ -530,12 +550,18 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
530550
behavior is undefined if _any_ of its elements are divided by zero or has a
531551
signed division overflow.
532552

553+
If the `exact` attribute is present, the result value is poison if `lhs` is
554+
not a multiple of `rhs`.
555+
533556
Example:
534557

535558
```mlir
536559
// Scalar signed integer division.
537560
%a = arith.divsi %b, %c : i64
538561

562+
// Scalar signed integer division where %b is known to be a multiple of %c.
563+
%a = arith.divsi %b, %c exact : i64
564+
539565
// SIMD vector element-wise division.
540566
%f = arith.divsi %g, %h : vector<4xi32>
541567

@@ -821,7 +847,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
821847
// ShRUIOp
822848
//===----------------------------------------------------------------------===//
823849

824-
def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
850+
def Arith_ShRUIOp : Arith_IntBinaryOpWithExactFlag<"shrui", [Pure]> {
825851
let summary = "unsigned integer right-shift";
826852
let description = [{
827853
The `shrui` operation shifts an integer value of the first operand to the right
@@ -830,12 +856,17 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
830856
filled with zeros. If the value of the second operand is greater or equal than the
831857
bitwidth of the first operand, then the operation returns poison.
832858

859+
If the `exact` attribute is present, the result value of shrui is a poison
860+
value if any of the bits shifted out are non-zero.
861+
833862
Example:
834863

835864
```mlir
836-
%1 = arith.constant 160 : i8 // %1 is 0b10100000
865+
%1 = arith.constant 160 : i8 // %1 is 0b10100000
837866
%2 = arith.constant 3 : i8
838-
%3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
867+
%3 = arith.constant 6 : i8
868+
%4 = arith.shrui %1, %2 exact : i8 // %4 is 0b00010100
869+
%5 = arith.shrui %1, %3 : i8 // %3 is 0b00000010
839870
```
840871
}];
841872
let hasFolder = 1;
@@ -845,7 +876,7 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
845876
// ShRSIOp
846877
//===----------------------------------------------------------------------===//
847878

848-
def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
879+
def Arith_ShRSIOp : Arith_IntBinaryOpWithExactFlag<"shrsi", [Pure]> {
849880
let summary = "signed integer right-shift";
850881
let description = [{
851882
The `shrsi` operation shifts an integer value of the first operand to the right
@@ -856,14 +887,17 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
856887
operand is greater or equal than bitwidth of the first operand, then the operation
857888
returns poison.
858889

890+
If the `exact` attribute is present, the result value of shrsi is a poison
891+
value if any of the bits shifted out are non-zero.
892+
859893
Example:
860894

861895
```mlir
862-
%1 = arith.constant 160 : i8 // %1 is 0b10100000
896+
%1 = arith.constant 160 : i8 // %1 is 0b10100000
863897
%2 = arith.constant 3 : i8
864-
%3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
865-
%4 = arith.constant 96 : i8 // %4 is 0b01100000
866-
%5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
898+
%3 = arith.shrsi %1, %2 exact : i8 // %3 is 0b11110100
899+
%4 = arith.constant 98 : i8 // %4 is 0b01100010
900+
%5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100
867901
```
868902
}];
869903
let hasFolder = 1;

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
281281
ConversionPatternRewriter &rewriter) const {
282282
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
283283
adaptor.getOperands(), op->getAttrs(),
284+
/*propAttr=*/Attribute{},
284285
*getTypeConverter(), rewriter);
285286
}
286287

mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
9696
ConversionPatternRewriter &rewriter) const override {
9797
return LLVM::detail::oneToOneRewrite(
9898
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99-
op->getAttrs(), *getTypeConverter(), rewriter);
99+
op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
100+
rewriter);
100101
}
101102
};
102103

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
296296
// Detail methods
297297
//===----------------------------------------------------------------------===//
298298

299-
void LLVM::detail::setNativeProperties(Operation *op,
300-
IntegerOverflowFlags overflowFlags) {
301-
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
302-
iface.setOverflowFlags(overflowFlags);
303-
}
304-
305299
/// Replaces the given operation "op" with a new operation of type "targetOp"
306300
/// and given operands.
307301
LogicalResult LLVM::detail::oneToOneRewrite(
308302
Operation *op, StringRef targetOp, ValueRange operands,
309-
ArrayRef<NamedAttribute> targetAttrs,
310-
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
311-
IntegerOverflowFlags overflowFlags) {
303+
ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
304+
const LLVMTypeConverter &typeConverter,
305+
ConversionPatternRewriter &rewriter) {
312306
unsigned numResults = op->getNumResults();
313307

314308
SmallVector<Type> resultTypes;
@@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite(
320314
}
321315

322316
// Create the operation through state since we don't know its C++ type.
323-
Operation *newOp =
324-
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
325-
resultTypes, targetAttrs);
326-
327-
setNativeProperties(newOp, overflowFlags);
317+
OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
318+
resultTypes, targetAttrs);
319+
state.propertiesAttr = propertiesAttr;
320+
Operation *newOp = rewriter.create(state);
328321

329322
// If the operation produced 0 or 1 result, return them immediately.
330323
if (numResults == 0)

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
105105

106106
LogicalResult LLVM::detail::vectorOneToOneRewrite(
107107
Operation *op, StringRef targetOp, ValueRange operands,
108-
ArrayRef<NamedAttribute> targetAttrs,
109-
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
110-
IntegerOverflowFlags overflowFlags) {
108+
ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
109+
const LLVMTypeConverter &typeConverter,
110+
ConversionPatternRewriter &rewriter) {
111111
assert(!operands.empty());
112112

113113
// Cannot convert ops if their operands are not of LLVM type.
@@ -116,15 +116,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
116116

117117
auto llvmNDVectorTy = operands[0].getType();
118118
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
119-
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
120-
rewriter, overflowFlags);
121-
122-
auto callback = [op, targetOp, targetAttrs, overflowFlags,
119+
return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr,
120+
typeConverter, rewriter);
121+
auto callback = [op, targetOp, targetAttrs, propertiesAttr,
123122
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
124-
Operation *newOp =
125-
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
126-
operands, llvm1DVectorTy, targetAttrs);
127-
LLVM::detail::setNativeProperties(newOp, overflowFlags);
123+
OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp),
124+
operands, llvm1DVectorTy, targetAttrs);
125+
state.propertiesAttr = propertiesAttr;
126+
Operation *newOp = rewriter.create(state);
128127
return newOp->getResult(0);
129128
};
130129

0 commit comments

Comments
 (0)