Skip to content

Commit 864392b

Browse files
committed
[mlir][Arith] Prevent IR modification for non-matching pattern
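The F4E2M1 truncation emulation was extending or truncating the operand to F32 before checking that the result type is F4E2M1FN, so the IR was modified even when the pattern did not apply and reported a match failure. This caused non-convergent rewrites when operating on doubles.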
1 parent 7234ae6 commit 864392b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,10 +518,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
518518
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
519519
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
520520

521-
if (!isa<Float32Type>(operandETy))
522-
operand = b.create<arith::ExtFOp>(f32Ty, operand);
523521
if (!isa<Float4E2M1FNType>(resultETy))
524522
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
523+
if (!isa<Float32Type>(operandETy))
524+
operand = b.create<arith::ExtFOp>(f32Ty, operand);
525525

526526
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
527527
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);

0 commit comments

Comments
 (0)