Skip to content

Commit 6e2e4d4

Browse files
committed
Revert "[MLIR][Arith] Add denormal attribute to binary/unary operations (#112700)"
This reverts commit 4a7b56e. There is no agreement.
1 parent 3083acc commit 6e2e4d4

File tree

13 files changed

+50
-263
lines changed

13 files changed

+50
-263
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
5151
template <typename SourceOp, typename TargetOp>
5252
class AttrConvertFastMathToLLVM {
5353
public:
54-
explicit AttrConvertFastMathToLLVM(SourceOp srcOp) {
54+
AttrConvertFastMathToLLVM(SourceOp srcOp) {
5555
// Copy the source attributes.
5656
convertedAttr = NamedAttrList{srcOp->getAttrs()};
5757
// Get the name of the arith fastmath attribute.
@@ -81,7 +81,7 @@ class AttrConvertFastMathToLLVM {
8181
template <typename SourceOp, typename TargetOp>
8282
class AttrConvertOverflowToLLVM {
8383
public:
84-
explicit AttrConvertOverflowToLLVM(SourceOp srcOp) {
84+
AttrConvertOverflowToLLVM(SourceOp srcOp) {
8585
// Copy the source attributes.
8686
convertedAttr = NamedAttrList{srcOp->getAttrs()};
8787
// Get the name of the arith overflow attribute.
@@ -109,7 +109,7 @@ class AttrConverterConstrainedFPToLLVM {
109109
"LLVM::FPExceptionBehaviorOpInterface");
110110

111111
public:
112-
explicit AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
112+
AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
113113
// Copy the source attributes.
114114
convertedAttr = NamedAttrList{srcOp->getAttrs()};
115115

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -181,37 +181,4 @@ def Arith_RoundingModeAttr : I32EnumAttr<
181181
let cppNamespace = "::mlir::arith";
182182
}
183183

184-
//===----------------------------------------------------------------------===//
185-
// Arith_DenormalMode
186-
//===----------------------------------------------------------------------===//
187-
188-
// Denormal mode is applied on operands and results. For example, if denormal =
189-
// preserve_sign, operands and results will be flushed to sign preserving zero.
190-
// We do not distinguish between operands and results.
191-
192-
// The default mode. Denormals are preserved and processed as defined
193-
// by IEEE 754 rules.
194-
def Arith_DenormalModeIEEE : I32EnumAttrCase<"ieee", 0>;
195-
196-
// A mode where denormal numbers are flushed to zero, but the sign of the zero
197-
// (+0 or -0) is preserved.
198-
def Arith_DenormalModePreserveSign : I32EnumAttrCase<"preserve_sign", 1>;
199-
200-
// A mode where all denormal numbers are flushed to positive zero (+0),
201-
// ignoring the sign of the original number.
202-
def Arith_DenormalModePositiveZero : I32EnumAttrCase<"positive_zero", 2>;
203-
204-
def Arith_DenormalMode : I32EnumAttr<
205-
"DenormalMode", "denormal mode arith",
206-
[Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign,
207-
Arith_DenormalModePositiveZero]> {
208-
let cppNamespace = "::mlir::arith";
209-
let genSpecializedAttr = 0;
210-
}
211-
212-
def Arith_DenormalModeAttr :
213-
EnumAttr<Arith_Dialect, Arith_DenormalMode, "denormal"> {
214-
let assemblyFormat = "`<` $value `>`";
215-
}
216-
217184
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,35 +61,26 @@ class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
6161
// Base class for floating point unary operations.
6262
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
6363
Arith_UnaryOp<mnemonic,
64-
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>,
65-
DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
64+
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
6665
traits)>,
6766
Arguments<(ins FloatLike:$operand,
6867
DefaultValuedAttr<
69-
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
70-
DefaultValuedAttr<
71-
Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
68+
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
7269
Results<(outs FloatLike:$result)> {
7370
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
74-
(`denormal` `` $denormal^)?
7571
attr-dict `:` type($result) }];
7672
}
7773

7874
// Base class for floating point binary operations.
7975
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
8076
Arith_BinaryOp<mnemonic,
81-
!listconcat([Pure,
82-
DeclareOpInterfaceMethods<ArithFastMathInterface>,
83-
DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
77+
!listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
8478
traits)>,
8579
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
8680
DefaultValuedAttr<
87-
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
88-
DefaultValuedAttr<
89-
Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
81+
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
9082
Results<(outs FloatLike:$result)> {
91-
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
92-
(`denormal` `` $denormal^)?
83+
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
9384
attr-dict `:` type($result) }];
9485
}
9586

@@ -1094,6 +1085,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
10941085
let hasFolder = 1;
10951086
}
10961087

1088+
10971089
//===----------------------------------------------------------------------===//
10981090
// MulFOp
10991091
//===----------------------------------------------------------------------===//
@@ -1119,6 +1111,8 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
11191111
%x = arith.mulf %y, %z : tensor<4x?xbf16>
11201112
```
11211113

1114+
TODO: In the distant future, this will accept optional attributes for fast
1115+
math, contraction, rounding mode, and other controls.
11221116
}];
11231117
let hasFolder = 1;
11241118
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
4545
return "fastmath";
4646
}]
4747
>
48+
4849
];
4950
}
5051

5152
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
5253
let description = [{
53-
Access to operation integer overflow flags.
54+
Access to op integer overflow flags.
5455
}];
5556

5657
let cppNamespace = "::mlir::arith";
@@ -107,7 +108,7 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
107108

108109
def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
109110
let description = [{
110-
Access to operation rounding mode.
111+
Access to op rounding mode.
111112
}];
112113

113114
let cppNamespace = "::mlir::arith";
@@ -138,39 +139,4 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
138139
];
139140
}
140141

141-
142-
def ArithDenormalModeInterface : OpInterface<"ArithDenormalModeInterface"> {
143-
let description = [{
144-
Access the operation denormal modes.
145-
}];
146-
147-
let cppNamespace = "::mlir::arith";
148-
149-
let methods = [
150-
InterfaceMethod<
151-
/*desc=*/ "Returns a DenormalModeAttr attribute for the operation",
152-
/*returnType=*/ "DenormalModeAttr",
153-
/*methodName=*/ "getDenormalModeAttr",
154-
/*args=*/ (ins),
155-
/*methodBody=*/ [{}],
156-
/*defaultImpl=*/ [{
157-
auto op = cast<ConcreteOp>(this->getOperation());
158-
return op.getDenormalAttr();
159-
}]
160-
>,
161-
StaticInterfaceMethod<
162-
/*desc=*/ [{Returns the name of the DenormalModeAttr attribute for
163-
the operation}],
164-
/*returnType=*/ "StringRef",
165-
/*methodName=*/ "getDenormalModeAttrName",
166-
/*args=*/ (ins),
167-
/*methodBody=*/ [{}],
168-
/*defaultImpl=*/ [{
169-
return "denormal";
170-
}]
171-
>
172-
];
173-
}
174-
175-
176142
#endif // ARITH_OPS_INTERFACES

mlir/include/mlir/IR/Matchers.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,6 @@ inline detail::constant_float_predicate_matcher m_NegInfFloat() {
438438
}};
439439
}
440440

441-
/// Matches a constant scalar / vector splat / tensor splat with denormal
442-
/// values.
443-
inline detail::constant_float_predicate_matcher m_isDenormalFloat() {
444-
return {[](const APFloat &value) { return value.isDenormal(); }};
445-
}
446-
447441
/// Matches a constant scalar / vector splat / tensor splat integer zero.
448442
inline detail::constant_int_predicate_matcher m_Zero() {
449443
return {[](const APInt &value) { return 0 == value; }};

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -53,49 +53,22 @@ struct ConstrainedVectorConvertToLLVMPattern
5353
}
5454
};
5555

56-
template <typename SourceOp, typename TargetOp,
57-
template <typename, typename> typename AttrConvert =
58-
AttrConvertPassThrough>
59-
struct DenormalOpConversionToLLVMPattern
60-
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
61-
using VectorConvertToLLVMPattern<SourceOp, TargetOp,
62-
AttrConvert>::VectorConvertToLLVMPattern;
63-
64-
LogicalResult
65-
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
66-
ConversionPatternRewriter &rewriter) const override {
67-
// TODO: Here, we need a legalization step. LLVM provides a function-level
68-
// attribute for denormal; here, we need to move this information from the
69-
// operation to the function, making sure all the operations in the same
70-
// function are consistent.
71-
if (op.getDenormalModeAttr().getValue() != arith::DenormalMode::ieee)
72-
return rewriter.notifyMatchFailure(
73-
op, "only ieee denormal mode is supported at the moment");
74-
75-
StringRef arithDenormalAttrName = SourceOp::getDenormalModeAttrName();
76-
op->removeAttr(arithDenormalAttrName);
77-
return VectorConvertToLLVMPattern<SourceOp, TargetOp,
78-
AttrConvert>::matchAndRewrite(op, adaptor,
79-
rewriter);
80-
}
81-
};
82-
8356
//===----------------------------------------------------------------------===//
8457
// Straightforward Op Lowerings
8558
//===----------------------------------------------------------------------===//
8659

8760
using AddFOpLowering =
88-
DenormalOpConversionToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
89-
arith::AttrConvertFastMathToLLVM>;
61+
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
62+
arith::AttrConvertFastMathToLLVM>;
9063
using AddIOpLowering =
9164
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
9265
arith::AttrConvertOverflowToLLVM>;
9366
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
9467
using BitcastOpLowering =
9568
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
9669
using DivFOpLowering =
97-
DenormalOpConversionToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
98-
arith::AttrConvertFastMathToLLVM>;
70+
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
71+
arith::AttrConvertFastMathToLLVM>;
9972
using DivSIOpLowering =
10073
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
10174
using DivUIOpLowering =
@@ -110,38 +83,38 @@ using FPToSIOpLowering =
11083
using FPToUIOpLowering =
11184
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
11285
using MaximumFOpLowering =
113-
DenormalOpConversionToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
114-
arith::AttrConvertFastMathToLLVM>;
86+
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
87+
arith::AttrConvertFastMathToLLVM>;
11588
using MaxNumFOpLowering =
116-
DenormalOpConversionToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
117-
arith::AttrConvertFastMathToLLVM>;
89+
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
90+
arith::AttrConvertFastMathToLLVM>;
11891
using MaxSIOpLowering =
11992
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
12093
using MaxUIOpLowering =
12194
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
12295
using MinimumFOpLowering =
123-
DenormalOpConversionToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
124-
arith::AttrConvertFastMathToLLVM>;
96+
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
97+
arith::AttrConvertFastMathToLLVM>;
12598
using MinNumFOpLowering =
126-
DenormalOpConversionToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
127-
arith::AttrConvertFastMathToLLVM>;
99+
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
100+
arith::AttrConvertFastMathToLLVM>;
128101
using MinSIOpLowering =
129102
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
130103
using MinUIOpLowering =
131104
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
132105
using MulFOpLowering =
133-
DenormalOpConversionToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
134-
arith::AttrConvertFastMathToLLVM>;
106+
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
107+
arith::AttrConvertFastMathToLLVM>;
135108
using MulIOpLowering =
136109
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
137110
arith::AttrConvertOverflowToLLVM>;
138111
using NegFOpLowering =
139-
DenormalOpConversionToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
140-
arith::AttrConvertFastMathToLLVM>;
112+
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
113+
arith::AttrConvertFastMathToLLVM>;
141114
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
142115
using RemFOpLowering =
143-
DenormalOpConversionToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
144-
arith::AttrConvertFastMathToLLVM>;
116+
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
117+
arith::AttrConvertFastMathToLLVM>;
145118
using RemSIOpLowering =
146119
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
147120
using RemUIOpLowering =
@@ -158,8 +131,8 @@ using ShRUIOpLowering =
158131
using SIToFPOpLowering =
159132
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
160133
using SubFOpLowering =
161-
DenormalOpConversionToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
162-
arith::AttrConvertFastMathToLLVM>;
134+
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
135+
arith::AttrConvertFastMathToLLVM>;
163136
using SubIOpLowering =
164137
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
165138
arith::AttrConvertOverflowToLLVM>;

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,23 +422,21 @@ def TruncIShrUIMulIToMulUIExtended :
422422
//===----------------------------------------------------------------------===//
423423

424424
// mulf(negf(x), negf(y)) -> mulf(x,y)
425-
// (retain fastmath flags and denormal mode of the original divf)
425+
// (retain fastmath flags of original mulf)
426426
def MulFOfNegF :
427-
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_, $_),
428-
(Arith_NegFOp $y, $_, $_), $fmf, $mode),
429-
(Arith_MulFOp $x, $y, $fmf, $mode),
427+
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
428+
(Arith_MulFOp $x, $y, $fmf),
430429
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
431430

432431
//===----------------------------------------------------------------------===//
433432
// DivFOp
434433
//===----------------------------------------------------------------------===//
435434

436435
// divf(negf(x), negf(y)) -> divf(x,y)
437-
// (retain fastmath flags and denormal mode of the original divf)
436+
// (retain fastmath flags of original divf)
438437
def DivFOfNegF :
439-
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_, $_),
440-
(Arith_NegFOp $y, $_, $_), $fmf, $mode),
441-
(Arith_DivFOp $x, $y, $fmf, $mode),
438+
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
439+
(Arith_DivFOp $x, $y, $fmf),
442440
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
443441

444442
#endif // ARITH_PATTERNS

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
956956
//===----------------------------------------------------------------------===//
957957

958958
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
959-
// negf(negf(x)) -> x
959+
/// negf(negf(x)) -> x
960960
if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
961961
return op.getOperand();
962962
return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
@@ -986,14 +986,6 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
986986
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
987987
return getLhs();
988988

989-
// Simplifies subf(x, rhs) to x if the following conditions are met:
990-
// 1. `rhs` is a denormal floating-point value.
991-
// 2. The denormal mode for the operation is set to positive zero.
992-
bool isPositiveZeroMode =
993-
getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
994-
if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat()))
995-
return getLhs();
996-
997989
return constFoldBinaryOp<FloatAttr>(
998990
adaptor.getOperands(),
999991
[](const APFloat &a, const APFloat &b) { return a - b; });

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,17 +1498,15 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) {
14981498

14991499
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
15001500
SmallVector<StringRef> elidedAttrs;
1501+
std::string attrToElide;
15011502
p << " { " << payloadOp->getName().getStringRef();
15021503
for (const auto &attr : payloadOp->getAttrs()) {
1503-
if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
1504-
if (fastAttr.getValue() == arith::FastMathFlags::none) {
1505-
elidedAttrs.push_back(attr.getName());
1506-
}
1507-
}
1508-
if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
1509-
if (denormAttr.getValue() == arith::DenormalMode::ieee) {
1510-
elidedAttrs.push_back(attr.getName());
1511-
}
1504+
auto fastAttr =
1505+
llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1506+
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1507+
attrToElide = attr.getName().str();
1508+
elidedAttrs.push_back(attrToElide);
1509+
break;
15121510
}
15131511
}
15141512
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);

0 commit comments

Comments
 (0)