Skip to content

Commit e6c78b5

Browse files
committed
Add 'exact' flag to arith.shrui/shrsi/divsi/divui operations
1 parent f76c132 commit e6c78b5

File tree

13 files changed

+203
-76
lines changed

13 files changed

+203
-76
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,8 @@ class AttrConvertFastMathToLLVM {
6565
convertArithFastMathAttrToLLVM(arithFMFAttr));
6666
}
6767
}
68-
6968
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
70-
LLVM::IntegerOverflowFlags getOverflowFlags() const {
71-
return LLVM::IntegerOverflowFlags::none;
72-
}
69+
Attribute getPropAttr() const { return {}; }
7370

7471
private:
7572
NamedAttrList convertedAttr;
@@ -82,23 +79,37 @@ template <typename SourceOp, typename TargetOp>
8279
class AttrConvertOverflowToLLVM {
8380
public:
8481
AttrConvertOverflowToLLVM(SourceOp srcOp) {
82+
using IntegerOverflowFlagsAttr = LLVM::IntegerOverflowFlagsAttr;
83+
8584
// Copy the source attributes.
8685
convertedAttr = NamedAttrList{srcOp->getAttrs()};
8786
// Get the name of the arith overflow attribute.
8887
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
89-
// Remove the source overflow attribute.
88+
// Remove the source overflow attribute from the set that will be present
89+
// in the target.
9090
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
9191
convertedAttr.erase(arithAttrName))) {
92-
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
92+
auto llvmFlag = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
93+
// Create a dictionary attribute holding the overflow flags property.
94+
// (In the LLVM dialect, the overflow flags are a property, not an
95+
// attribute.)
96+
MLIRContext *ctx = srcOp.getOperation()->getContext();
97+
Builder b(ctx);
98+
auto llvmFlagAttr = IntegerOverflowFlagsAttr::get(ctx, llvmFlag);
99+
StringRef llvmAttrName = TargetOp::getOverflowFlagsAttrName();
100+
SmallVector<NamedAttribute> attrs;
101+
attrs.push_back(b.getNamedAttr(llvmAttrName, llvmFlagAttr));
102+
// Set the properties attribute of the operation state so that the
103+
// property can be updated when the operation is created.
104+
propertiesAttr = b.getDictionaryAttr(attrs);
93105
}
94106
}
95-
96107
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
97-
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
108+
Attribute getPropAttr() const { return propertiesAttr; }
98109

99110
private:
100111
NamedAttrList convertedAttr;
101-
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
112+
DictionaryAttr propertiesAttr;
102113
};
103114

104115
template <typename SourceOp, typename TargetOp>
@@ -129,9 +140,7 @@ class AttrConverterConstrainedFPToLLVM {
129140
}
130141

131142
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
132-
LLVM::IntegerOverflowFlags getOverflowFlags() const {
133-
return LLVM::IntegerOverflowFlags::none;
134-
}
143+
Attribute getPropAttr() const { return {}; }
135144

136145
private:
137146
NamedAttrList convertedAttr;

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,14 @@ class CallOpInterface;
1919

2020
namespace LLVM {
2121
namespace detail {
22-
/// Handle generically setting flags as native properties on LLVM operations.
23-
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24-
2522
/// Replaces the given operation "op" with a new operation of type "targetOp"
2623
/// and given operands.
27-
LogicalResult oneToOneRewrite(
28-
Operation *op, StringRef targetOp, ValueRange operands,
29-
ArrayRef<NamedAttribute> targetAttrs,
30-
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31-
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
24+
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
25+
ValueRange operands,
26+
ArrayRef<NamedAttribute> targetAttrs,
27+
Attribute propertiesAttr,
28+
const LLVMTypeConverter &typeConverter,
29+
ConversionPatternRewriter &rewriter);
3230

3331
/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
3432
/// specified name "intrinsic" and operands.
@@ -307,9 +305,9 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
307305
LogicalResult
308306
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
309307
ConversionPatternRewriter &rewriter) const override {
310-
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
311-
adaptor.getOperands(), op->getAttrs(),
312-
*this->getTypeConverter(), rewriter);
308+
return LLVM::detail::oneToOneRewrite(
309+
op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(),
310+
/*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter);
313311
}
314312
};
315313

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,36 +54,40 @@ LogicalResult handleMultidimensionalVectors(
5454
std::function<Value(Type, ValueRange)> createOperand,
5555
ConversionPatternRewriter &rewriter);
5656

57-
LogicalResult vectorOneToOneRewrite(
58-
Operation *op, StringRef targetOp, ValueRange operands,
59-
ArrayRef<NamedAttribute> targetAttrs,
60-
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
61-
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
57+
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
58+
ValueRange operands,
59+
ArrayRef<NamedAttribute> targetAttrs,
60+
Attribute propertiesAttr,
61+
const LLVMTypeConverter &typeConverter,
62+
ConversionPatternRewriter &rewriter);
6263
} // namespace detail
6364
} // namespace LLVM
6465

6566
// Default attribute conversion class, which passes all source attributes
66-
// through to the target op, unmodified.
67+
// through to the target op, unmodified. The attribute to set properties of the
68+
// target operation will be nullptr (i.e. any properties that exist in will have
69+
// default values).
6770
template <typename SourceOp, typename TargetOp>
6871
class AttrConvertPassThrough {
6972
public:
7073
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
7174

7275
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
73-
LLVM::IntegerOverflowFlags getOverflowFlags() const {
74-
return LLVM::IntegerOverflowFlags::none;
75-
}
76+
Attribute getPropAttr() const { return {}; }
7677

7778
private:
7879
ArrayRef<NamedAttribute> srcAttrs;
7980
};
8081

8182
/// Basic lowering implementation to rewrite Ops with just one result to the
8283
/// LLVM Dialect. This supports higher-dimensional vector types.
83-
/// The AttrConvert template template parameter should be a template class
84-
/// with SourceOp and TargetOp type parameters, a constructor that takes
85-
/// a SourceOp instance, and a getAttrs() method that returns
86-
/// ArrayRef<NamedAttribute>.
84+
/// The AttrConvert template template parameter should:
85+
// - be a template class with SourceOp and TargetOp type parameters
86+
// - have a constructor that takes a SourceOp instance
87+
// - a getAttrs() method that returns ArrayRef<NamedAttribute> containing
88+
// attributes that the target operation will have
89+
// - a getPropAttr() method that returns either a NULL attribute or a
90+
// DictionaryAttribute with properties that exist for the target operation
8791
template <typename SourceOp, typename TargetOp,
8892
template <typename, typename> typename AttrConvert =
8993
AttrConvertPassThrough>
@@ -134,8 +138,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
134138

135139
return LLVM::detail::vectorOneToOneRewrite(
136140
op, TargetOp::getOperationName(), adaptor.getOperands(),
137-
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
138-
attrConvert.getOverflowFlags());
141+
attrConvert.getAttrs(), attrConvert.getPropAttr(),
142+
*this->getTypeConverter(), rewriter);
139143
}
140144
};
141145
} // namespace mlir

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,19 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
158158
attr-dict `:` type($result) }];
159159
}
160160

161+
class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
162+
Arith_BinaryOp<mnemonic, traits #
163+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
164+
DeclareOpInterfaceMethods<ArithExactFlagInterface>]>,
165+
Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
166+
SignlessIntegerOrIndexLike:$rhs,
167+
UnitAttr:$isExact)>,
168+
Results<(outs SignlessIntegerOrIndexLike:$result)> {
169+
170+
let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
171+
attr-dict `:` type($result) }];
172+
}
173+
161174
//===----------------------------------------------------------------------===//
162175
// ConstantOp
163176
//===----------------------------------------------------------------------===//
@@ -482,7 +495,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
482495
// DivUIOp
483496
//===----------------------------------------------------------------------===//
484497

485-
def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
498+
def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui",
499+
[ConditionallySpeculatable]> {
486500
let summary = "unsigned integer division operation";
487501
let description = [{
488502
Unsigned integer division. Rounds towards zero. Treats the leading bit as
@@ -493,12 +507,18 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
493507
`tensor` values, the behavior is undefined if _any_ elements are divided by
494508
zero.
495509

510+
If the `exact` attribute is present, the result value is poison if `lhs` is
511+
not a multiple of `rhs`.
512+
496513
Example:
497514

498515
```mlir
499516
// Scalar unsigned integer division.
500517
%a = arith.divui %b, %c : i64
501518

519+
// Scalar unsigned integer division where %b is known to be a multiple of %c.
520+
%a = arith.divui %b, %c exact : i64
521+
502522
// SIMD vector element-wise division.
503523
%f = arith.divui %g, %h : vector<4xi32>
504524

@@ -519,7 +539,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
519539
// DivSIOp
520540
//===----------------------------------------------------------------------===//
521541

522-
def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
542+
def Arith_DivSIOp : Arith_IntBinaryOpWithExactFlag<"divsi",
543+
[ConditionallySpeculatable]> {
523544
let summary = "signed integer division operation";
524545
let description = [{
525546
Signed integer division. Rounds towards zero. Treats the leading bit as
@@ -530,12 +551,18 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
530551
behavior is undefined if _any_ of its elements are divided by zero or has a
531552
signed division overflow.
532553

554+
If the `exact` attribute is present, the result value is poison if `lhs` is
555+
not a multiple of `rhs`.
556+
533557
Example:
534558

535559
```mlir
536560
// Scalar signed integer division.
537561
%a = arith.divsi %b, %c : i64
538562

563+
// Scalar signed integer division where %b is known to be a multiple of %c.
564+
%a = arith.divsi %b, %c exact : i64
565+
539566
// SIMD vector element-wise division.
540567
%f = arith.divsi %g, %h : vector<4xi32>
541568

@@ -821,7 +848,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
821848
// ShRUIOp
822849
//===----------------------------------------------------------------------===//
823850

824-
def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
851+
def Arith_ShRUIOp : Arith_IntBinaryOpWithExactFlag<"shrui", [Pure]> {
825852
let summary = "unsigned integer right-shift";
826853
let description = [{
827854
The `shrui` operation shifts an integer value of the first operand to the right
@@ -830,12 +857,17 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
830857
filled with zeros. If the value of the second operand is greater or equal than the
831858
bitwidth of the first operand, then the operation returns poison.
832859

860+
If the `exact` keyword is present, the result value of shrui is a poison
861+
value if any of the bits shifted out are non-zero.
862+
833863
Example:
834864

835865
```mlir
836-
%1 = arith.constant 160 : i8 // %1 is 0b10100000
866+
%1 = arith.constant 160 : i8 // %1 is 0b10100000
837867
%2 = arith.constant 3 : i8
838-
%3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
868+
%3 = arith.constant 6 : i8
869+
%4 = arith.shrui %1, %2 exact : i8 // %4 is 0b00010100
870+
%5 = arith.shrui %1, %3 : i8 // %3 is 0b00000010
839871
```
840872
}];
841873
let hasFolder = 1;
@@ -845,7 +877,7 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
845877
// ShRSIOp
846878
//===----------------------------------------------------------------------===//
847879

848-
def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
880+
def Arith_ShRSIOp : Arith_IntBinaryOpWithExactFlag<"shrsi", [Pure]> {
849881
let summary = "signed integer right-shift";
850882
let description = [{
851883
The `shrsi` operation shifts an integer value of the first operand to the right
@@ -856,14 +888,17 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
856888
operand is greater or equal than bitwidth of the first operand, then the operation
857889
returns poison.
858890

891+
If the `exact` keyword is present, the result value of shrsi is a poison
892+
value if any of the bits shifted out are non-zero.
893+
859894
Example:
860895

861896
```mlir
862-
%1 = arith.constant 160 : i8 // %1 is 0b10100000
897+
%1 = arith.constant 160 : i8 // %1 is 0b10100000
863898
%2 = arith.constant 3 : i8
864-
%3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
865-
%4 = arith.constant 96 : i8 // %4 is 0b01100000
866-
%5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
899+
%3 = arith.shrsi %1, %2 exact : i8 // %3 is 0b11110100
900+
%4 = arith.constant 98 : i8 // %4 is 0b01100010
901+
%5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100
867902
```
868903
}];
869904
let hasFolder = 1;

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,38 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
139139
];
140140
}
141141

142+
def ArithExactFlagInterface : OpInterface<"ArithExactFlagInterface"> {
143+
let description = [{
144+
Access to op exact flag.
145+
}];
146+
147+
let cppNamespace = "::mlir::arith";
148+
149+
let methods = [
150+
InterfaceMethod<
151+
/*desc=*/ "Returns whether the operation has the exact flag",
152+
/*returnType=*/ "bool",
153+
/*methodName=*/ "hasExactFlag",
154+
/*args=*/ (ins),
155+
/*methodBody=*/ [{}],
156+
/*defaultImpl=*/ [{
157+
auto op = cast<ConcreteOp>(this->getOperation());
158+
return op.getIsExact();
159+
}]
160+
>,
161+
StaticInterfaceMethod<
162+
/*desc=*/ [{Returns the name of the 'exact' attribute
163+
for the operation}],
164+
/*returnType=*/ "StringRef",
165+
/*methodName=*/ "getExactAttrName",
166+
/*args=*/ (ins),
167+
/*methodBody=*/ [{}],
168+
/*defaultImpl=*/ [{
169+
return "isExact";
170+
}]
171+
>
172+
173+
];
174+
}
175+
142176
#endif // ARITH_OPS_INTERFACES

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
258258
ConversionPatternRewriter &rewriter) const {
259259
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
260260
adaptor.getOperands(), op->getAttrs(),
261+
/*propAttr=*/Attribute{},
261262
*getTypeConverter(), rewriter);
262263
}
263264

mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
9696
ConversionPatternRewriter &rewriter) const override {
9797
return LLVM::detail::oneToOneRewrite(
9898
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99-
op->getAttrs(), *getTypeConverter(), rewriter);
99+
op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
100+
rewriter);
100101
}
101102
};
102103

0 commit comments

Comments
 (0)