@@ -243,8 +243,9 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
243243 Type operandETy = getElementTypeOrSelf (operandTy);
244244 Type resultETy = getElementTypeOrSelf (resultTy);
245245
246- if (!operandETy.isBF16 () || !resultETy.isF32 ())
246+ if (!operandETy.isBF16 () || !resultETy.isF32 ()) {
247247 return rewriter.notifyMatchFailure (op, " not a ext of bf16 to f32." );
248+ }
248249
249250 Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
250251 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
@@ -272,8 +273,9 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
272273 Type operandETy = getElementTypeOrSelf (operandTy);
273274 Type resultETy = getElementTypeOrSelf (resultTy);
274275
275- if (!operandETy.isF32 () || !resultETy.isBF16 ())
276+ if (!operandETy.isF32 () || !resultETy.isBF16 ()) {
276277 return rewriter.notifyMatchFailure (op, " not a trunc of f32 to bf16." );
278+ }
277279
278280 if (op.getRoundingmodeAttr ()) {
279281 return rewriter.notifyMatchFailure (
@@ -422,7 +424,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
422424 Value bits1To32 = b.create <arith::AddIOp>(bits1To31, bit32);
423425 Value result = b.create <arith::BitcastOp>(f32Ty, bits1To32);
424426 if (!isa<Float32Type>(resultETy))
425- result = b.create <arith::TruncFOp>(resultETy, operand );
427+ result = b.create <arith::TruncFOp>(resultTy, result );
426428
427429 rewriter.replaceOp (op, result);
428430 return success ();
@@ -440,8 +442,9 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
440442 Type operandETy = getElementTypeOrSelf (operandTy);
441443 Type resultETy = getElementTypeOrSelf (resultTy);
442444
443- if (!llvm::isa<Float8E8M0FNUType>(operandETy))
445+ if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
444446 return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
447+ }
445448
446449 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
447450 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
@@ -512,16 +515,16 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
512515 Type operandETy = getElementTypeOrSelf (operandTy);
513516 Type resultETy = getElementTypeOrSelf (resultTy);
514517
515- if (!isa<Float32Type>(operandETy))
516- operand = b.create <arith::ExtFOp>(b.getF32Type (), operand);
517- if (!isa<Float4E2M1FNType>(resultETy))
518- return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
519-
520518 Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
521519 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
522520 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
523521 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
524522
523+ if (!isa<Float32Type>(operandETy))
524+ operand = b.create <arith::ExtFOp>(f32Ty, operand);
525+ if (!isa<Float4E2M1FNType>(resultETy))
526+ return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
527+
525528 Value c0x1 = createConst (loc, i4Ty, 1 , rewriter);
526529 Value c0x3 = createConst (loc, i4Ty, 3 , rewriter);
527530 Value c0x00000016 = createConst (loc, i32Ty, 22 , rewriter);
@@ -611,12 +614,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
611614 Type operandETy = getElementTypeOrSelf (operandTy);
612615 Type resultTy = op.getType ();
613616 Type resultETy = getElementTypeOrSelf (resultTy);
614- if (!llvm::isa<Float8E8M0FNUType>(resultETy))
617+ if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
615618 return rewriter.notifyMatchFailure (op, " not a truncf to f8E8M0FNU" );
619+ }
616620
617- if (op.getRoundingmodeAttr ())
621+ if (op.getRoundingmodeAttr ()) {
618622 return rewriter.notifyMatchFailure (
619623 op, " only applicable to default rounding mode." );
624+ }
620625
621626 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
622627 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
0 commit comments