Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,14 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();

arith::FastMathFlags fmf = getFastmath();
if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
// mulf(x, 0) -> 0
if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
return getRhs();
}

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a * b; });
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,16 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
return %2 : f32
}

// CHECK-LABEL: @test_mulf2(
func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %[[C0]], %[[C0]]
%c0 = arith.constant 0.0 : f32
%0 = arith.mulf %arg0, %c0 fastmath<nnan,ninf> : f32
%1 = arith.mulf %c0, %arg1 fastmath<nnan,ninf> : f32
return %0, %1 : f32, f32
}

// -----

// CHECK-LABEL: @test_divf(
Expand Down