@@ -657,6 +657,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
657657 scaleOperand = b.create <arith::TruncFOp>(scaleTy, scaleOperand, nullptr ,
658658 op.getFastmathAttr ());
659659 }
660+ // Catch scale types like f8E5M2.
660661 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
661662 return rewriter.notifyMatchFailure (
662663 op, " scaling_extf is using scales of type which can not be converted "
@@ -777,7 +778,7 @@ struct ArithExpandOpsPass
777778 if (includeBf16)
778779 legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
779780 if (includeF8E8M0)
780- legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
781+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
781782 if (includeF4E2M1)
782783 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
783784 return legalTypes;
@@ -832,7 +833,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
832833 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
833834 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
834835 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
835- MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
836+ MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
836837 >(patterns.getContext ());
837838 // clang-format on
838839}
0 commit comments