Skip to content

Commit eb55412

Browse files
authored
[mlir][Arith] Prevent IR modification for non-matching pattern (llvm#150103)
The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles. Also, fix a pair of whitespace issues that snuck in.
1 parent 796f551 commit eb55412

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,10 +518,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
518518
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
519519
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
520520

521-
if (!isa<Float32Type>(operandETy))
522-
operand = b.create<arith::ExtFOp>(f32Ty, operand);
523521
if (!isa<Float4E2M1FNType>(resultETy))
524522
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
523+
if (!isa<Float32Type>(operandETy))
524+
operand = b.create<arith::ExtFOp>(f32Ty, operand);
525525

526526
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
527527
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
@@ -657,6 +657,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
657657
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
658658
op.getFastmathAttr());
659659
}
660+
// Catch scale types like f8E5M2.
660661
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
661662
return rewriter.notifyMatchFailure(
662663
op, "scaling_extf is using scales of type which can not be converted "
@@ -777,7 +778,7 @@ struct ArithExpandOpsPass
777778
if (includeBf16)
778779
legalTypes &= !(inETy.isF32() && outETy.isBF16());
779780
if (includeF8E8M0)
780-
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
781+
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
781782
if (includeF4E2M1)
782783
legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
783784
return legalTypes;
@@ -832,7 +833,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
832833
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
833834
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
834835
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
835-
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
836+
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
836837
>(patterns.getContext());
837838
// clang-format on
838839
}

0 commit comments

Comments
 (0)