Skip to content

Commit 8c72844

Browse files
author
Xiang Li
committed
Limit to fastmask nnan | ninf
1 parent 8d49f3f commit 8c72844

File tree

3 files changed

+17
-26
lines changed

3 files changed

+17
-26
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,12 +1281,14 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
12811281
// mulf(x, 1) -> x
12821282
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
12831283
return getLhs();
1284-
// mulf(NaN, x) -> NaN
1285-
if (matchPattern(adaptor.getLhs(), m_NaNFloat()))
1286-
return getLhs();
1287-
// mulf(x, 0) -> 0
1288-
if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1289-
return getRhs();
1284+
1285+
arith::FastMathFlags fmf = getFastmath();
1286+
if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1287+
arith::FastMathFlags::ninf)) {
1288+
// mulf(x, 0) -> 0
1289+
if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1290+
return getRhs();
1291+
}
12901292

12911293
return constFoldBinaryOp<FloatAttr>(
12921294
adaptor.getOperands(),

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2221,20 +2221,9 @@ func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
22212221
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
22222222
// CHECK-NEXT: return %[[C0]], %[[C0]]
22232223
%c0 = arith.constant 0.0 : f32
2224-
%0 = arith.mulf %arg0, %c0 : f32
2225-
%1 = arith.mulf %c0, %arg1 : f32
2226-
return %0, %1 : f32, f32
2227-
}
2228-
2229-
// CHECK-LABEL: @test_mulf3(
2230-
func.func @test_mulf3(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
2231-
// CHECK-NEXT: %[[NAN:.+]] = arith.constant 0x7FC00000 : f32
2232-
// CHECK-NEXT: return %[[NAN]], %[[NAN]]
2233-
%c0 = arith.constant 0.0 : f32
2234-
%nan = arith.constant 0x7FC00000 : f32
2235-
%0 = arith.mulf %nan, %c0 : f32
2236-
%1 = arith.mulf %c0, %nan : f32
2237-
return %0, %1 : f32, f32
2224+
%0 = arith.mulf %arg0, %c0 fastmath<nnan,ninf> : f32
2225+
%1 = arith.mulf %c0, %arg1 fastmath<nnan,ninf> : f32
2226+
return %0, %1 : f32, f32
22382227
}
22392228

22402229
// -----

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -930,31 +930,31 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
930930
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
931931
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
932932
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
933-
// CHECK-DAG: %[[CST10:.*]] = arith.constant 1.000000e+01 : f32
933+
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
934934
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
935935
// Prologue:
936936
// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
937937
// Kernel:
938938
// CHECK-NEXT: %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
939939
// CHECK-SAME: step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
940940
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
941-
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST10]] : f32
941+
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
942942
// CHECK-NEXT: memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
943943
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
944944
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
945-
// CHECK-NEXT: scf.yield %[[CST10]], %[[L2]] : f32
945+
// CHECK-NEXT: scf.yield %[[CST0]], %[[L2]] : f32
946946
// CHECK-NEXT: }
947947
// Epilogue:
948-
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST10]] : f32
949-
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST10]] : f32
948+
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
949+
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
950950
// CHECK-NEXT: memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
951951
// CHECK-NEXT: return %[[L1]]#0 : f32
952952

953953
func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
954954
%c0 = arith.constant 0 : index
955955
%c1 = arith.constant 1 : index
956956
%c4 = arith.constant 4 : index
957-
%cf0 = arith.constant 10.0 : f32
957+
%cf0 = arith.constant 0.0 : f32
958958
%cf2 = arith.constant 2.0 : f32
959959
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
960960
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>

0 commit comments

Comments
 (0)