Skip to content

Commit 2d06374

Browse files
authored
[mlir][arith] Add mulf(x, 0) -> 0 to mulf folder (#161395)
Fold `mulf(x, 0) -> 0` when (nnan | nsz)
1 parent a2330a3 commit 2d06374

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
12821282
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
12831283
return getLhs();
12841284

1285+
if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1286+
arith::FastMathFlags::nsz)) {
1287+
// mulf(x, 0) -> 0
1288+
if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1289+
return getRhs();
1290+
}
1291+
12851292
return constFoldBinaryOp<FloatAttr>(
12861293
adaptor.getOperands(),
12871294
[](const APFloat &a, const APFloat &b) { return a * b; });

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
22162216
return %2 : f32
22172217
}
22182218

2219+
// CHECK-LABEL: @test_mulf2(
2220+
func.func @test_mulf2(%arg0 : f32) -> (f32, f32) {
2221+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
2222+
// CHECK-DAG: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32
2223+
// CHECK-NEXT: return %[[C0]], %[[C0n]]
2224+
%c0 = arith.constant 0.0 : f32
2225+
%c0n = arith.constant -0.0 : f32
2226+
%0 = arith.mulf %c0, %arg0 fastmath<nnan,nsz> : f32
2227+
%1 = arith.mulf %c0n, %arg0 fastmath<nnan,nsz> : f32
2228+
return %0, %1 : f32, f32
2229+
}
2230+
22192231
// -----
22202232

22212233
// CHECK-LABEL: @test_divf(

0 commit comments

Comments
 (0)