@@ -330,21 +330,16 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
330330 LogicalResult matchAndRewrite (arith::ExtFOp op,
331331 PatternRewriter &rewriter) const final {
332332 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
333- auto operand = op.getOperand ();
333+ Value operand = op.getOperand ();
334334 Type operandTy = operand.getType ();
335335 Type resultTy = op.getType ();
336336 Type operandETy = getElementTypeOrSelf (operandTy);
337337 Type resultETy = getElementTypeOrSelf (resultTy);
338338
339- if (!operandETy. isF8E8M0FNU ( )) {
339+ if (!llvm::isa<Float8E8M0FNUType>(operandETy )) {
340340 return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
341341 }
342342
343- if (!resultETy.isBF16 () && !resultETy.isF16 () && !resultETy.isF32 ()) {
344- return rewriter.notifyMatchFailure (
345- op, " not a ext of F8M0FNU on a larger 16-bit or 32-bit width float." );
346- }
347-
348343 Type i8Ty = b.getI8Type ();
349344 Type i32Ty = b.getI32Type ();
350345 Type f32Ty = b.getF32Type ();
@@ -368,10 +363,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
368363 // select for NaNs
369364 f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
370365 Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
371- if (resultETy.isBF16 ()) {
372- result = b.create <arith::TruncFOp>(resultTy, result);
373- } else if (resultETy.isF16 ()) {
366+ if (resultETy.getIntOrFloatBitWidth () < 32 ) {
374367 result = b.create <arith::TruncFOp>(resultTy, result);
368+ } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
369+ result = b.create <arith::ExtFOp>(resultTy, result);
375370 }
376371 rewriter.replaceOp (op, result);
377372 return success ();
@@ -388,18 +383,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
388383 LogicalResult matchAndRewrite (arith::TruncFOp op,
389384 PatternRewriter &rewriter) const final {
390385 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
391- auto operand = op.getOperand ();
386+ Value operand = op.getOperand ();
392387 Type operandTy = operand.getType ();
393388 Type operandETy = getElementTypeOrSelf (operandTy);
394389 Type resultTy = op.getType ();
395390 Type resultETy = getElementTypeOrSelf (resultTy);
396- if (!resultETy. isF8E8M0FNU ( )) {
391+ if (!llvm::isa<Float8E8M0FNUType>(resultETy )) {
397392 return rewriter.notifyMatchFailure (op, " not a truncf to f8E8M0FNU" );
398393 }
399- if (!operandETy.isBF16 () && !operandETy.isF16 () && !operandETy.isF32 ()) {
400- return rewriter.notifyMatchFailure (
401- op, " not a truncf of 16-bit or 32-bit float to f8E8M0FNU." );
402- }
403394
404395 if (op.getRoundingmodeAttr ()) {
405396 return rewriter.notifyMatchFailure (
@@ -414,8 +405,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
414405 i32Ty = shapedTy.clone (i32Ty);
415406 f32Ty = shapedTy.clone (f32Ty);
416407 }
417- if (! operandETy.isF32 () ) {
408+ if (operandETy.getIntOrFloatBitWidth () < 32 ) {
418409 operand = b.create <arith::ExtFOp>(f32Ty, operand);
410+ } else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
411+ operand = b.create <arith::TruncFOp>(f32Ty, operand);
419412 }
420413 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
421414 Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
@@ -453,36 +446,37 @@ struct ArithExpandOpsPass
453446 arith::MinNumFOp
454447 >();
455448
456- if (includeBf16) {
449+ if (includeBf16) {
457450 arith::populateExpandBFloat16Patterns (patterns);
458451 }
459- if (includeF8E8M0) {
452+ if (includeF8E8M0) {
460453 arith::populateExpandF8E8M0Patterns (patterns);
461454 }
462- if (includeBf16 || includeF8E8M0) {
463- target.addDynamicallyLegalOp <arith::ExtFOp>(
464- [=](arith::ExtFOp op) {
465- Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
466- Type outETy = getElementTypeOrSelf (op.getType ());
467- if (includeBf16 && includeF8E8M0)
468- return !(inETy.isBF16 () && outETy.isF32 ()) && !(inETy.isF8E8M0FNU () && (outETy.isF32 () || outETy.isBF16 () || outETy.isF16 ()));
469- if (includeBf16)
470- return !(inETy.isBF16 () && outETy.isF32 ());
471- return !(inETy.isF8E8M0FNU () && (outETy.isF32 () || outETy.isBF16 () || outETy.isF16 ()));
472- });
473-
474- target.addDynamicallyLegalOp <arith::TruncFOp>(
475- [=](arith::TruncFOp op) {
476- Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
477- Type outETy = getElementTypeOrSelf (op.getType ());
478- if (includeBf16 && includeF8E8M0)
479- return !(inETy.isF32 () && outETy.isBF16 ()) && !(outETy.isF8E8M0FNU () && (inETy.isF32 () || inETy.isF16 () || inETy.isBF16 ()));
480- if (includeBf16)
481- return !(inETy.isF32 () && outETy.isBF16 ());
482- return
483- !(outETy.isF8E8M0FNU () && (inETy.isF32 () || inETy.isF16 () || inETy.isBF16 ()));
484- });
485- }
455+
456+ target.addDynamicallyLegalOp <arith::ExtFOp>(
457+ [=](arith::ExtFOp op) {
458+ Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
459+ Type outETy = getElementTypeOrSelf (op.getType ());
460+ bool legalTypes = true ;
461+ if (includeBf16)
462+ legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
463+ if (includeF8E8M0)
464+ legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
465+ return legalTypes;
466+ });
467+
468+ target.addDynamicallyLegalOp <arith::TruncFOp>(
469+ [=](arith::TruncFOp op) {
470+ Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
471+ Type outETy = getElementTypeOrSelf (op.getType ());
472+ bool legalTypes = true ;
473+ if (includeBf16)
474+ legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
475+ if (includeF8E8M0)
476+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
477+ return legalTypes;
478+ });
479+
486480 // clang-format on
487481 if (failed (applyPartialConversion (getOperation (), target,
488482 std::move (patterns))))
0 commit comments