Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();

if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
arith::FastMathFlags::nsz)) {
Comment on lines +1285 to +1286
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't it need also ninf? inf * 0 -> Nan

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to check this with Alive: https://alive2.llvm.org/ce/z/wvNkdy

Copy link
Member

@kuhar kuhar Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because nnan applies to the result as well:

nnan
No NaNs - Allow optimizations to assume the arguments and result are not NaN.

// mulf(x, 0) -> 0
if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
return getRhs();
}

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a * b; });
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
return %2 : f32
}

// CHECK-LABEL: @test_mulf2(
func.func @test_mulf2(%arg0 : f32) -> (f32, f32) {
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32
// CHECK-NEXT: return %[[C0]], %[[C0n]]
%c0 = arith.constant 0.0 : f32
%c0n = arith.constant -0.0 : f32
%0 = arith.mulf %c0, %arg0 fastmath<nnan,nsz> : f32
%1 = arith.mulf %c0n, %arg0 fastmath<nnan,nsz> : f32
return %0, %1 : f32, f32
}

// -----

// CHECK-LABEL: @test_divf(
Expand Down