From 864392b3dc7e8afe91580dc4a56234a0762702e0 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 22 Jul 2025 20:25:46 +0000 Subject: [PATCH 1/4] [mlir][Arith] Prevent IR modification for non-matching pattern The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles. --- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index f497d2db3bf7c..5e575de4065ca 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -518,10 +518,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern { Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); - if (!isa(operandETy)) - operand = b.create(f32Ty, operand); if (!isa(resultETy)) return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN"); + if (!isa(operandETy)) + operand = b.create(f32Ty, operand); Value c0x1 = createConst(loc, i4Ty, 1, rewriter); Value c0x3 = createConst(loc, i4Ty, 3, rewriter); From 7c1b38b13fb26cdc38c7071581948282b403fb3b Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 22 Jul 2025 20:41:04 +0000 Subject: [PATCH 2/4] Remove a can't-happen match failure --- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 5e575de4065ca..bd6947278aac2 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -657,11 +657,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern { scaleOperand = b.create(scaleTy, scaleOperand, nullptr, op.getFastmathAttr()); } - if (!llvm::isa(scaleETy)) { - return rewriter.notifyMatchFailure( - op, "scaling_extf is using scales of type which can not be converted " - "to f8E8M0FNU"); - } + Type resultTy = op.getType(); // extf on scale will essentially create floating point number // of type resulTy that is 2^scale and will also propagate NaNs From e7b511636e1cad45531dab2c4f2e7f477f13d07b Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 22 Jul 2025 21:24:27 +0000 Subject: [PATCH 3/4] Revert "Remove a can't-happen match failure" This reverts commit 7c1b38b13fb26cdc38c7071581948282b403fb3b. --- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index bd6947278aac2..5e575de4065ca 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -657,7 +657,11 @@ struct ScalingExtFOpConverter : public OpRewritePattern { scaleOperand = b.create(scaleTy, scaleOperand, nullptr, op.getFastmathAttr()); } - + if (!llvm::isa(scaleETy)) { + return rewriter.notifyMatchFailure( + op, "scaling_extf is using scales of type which can not be converted " + "to f8E8M0FNU"); + } Type resultTy = op.getType(); // extf on scale will essentially create floating point number // of type resulTy that is 2^scale and will also propagate NaNs From 93646f252b52115cb9c57be476bf478e5e811d67 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 22 Jul 2025 21:27:35 +0000 Subject: [PATCH 4/4] Add explanatory comment for guard, fix misc. whitespace issues --- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 5e575de4065ca..ab57557f3f13d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -657,6 +657,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern { scaleOperand = b.create(scaleTy, scaleOperand, nullptr, op.getFastmathAttr()); } + // Catch scale types like f8E5M2. if (!llvm::isa(scaleETy)) { return rewriter.notifyMatchFailure( op, "scaling_extf is using scales of type which can not be converted " @@ -777,7 +778,7 @@ struct ArithExpandOpsPass if (includeBf16) legalTypes &= !(inETy.isF32() && outETy.isBF16()); if (includeF8E8M0) - legalTypes &= !(llvm::isa(outETy)); + legalTypes &= !(llvm::isa(outETy)); if (includeF4E2M1) legalTypes &= !llvm::isa(outETy); return legalTypes; @@ -832,7 +833,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { MaximumMinimumFOpConverter, MaximumMinimumFOpConverter, MaxNumMinNumFOpConverter, - MaxNumMinNumFOpConverter + MaxNumMinNumFOpConverter >(patterns.getContext()); // clang-format on }