@@ -345,9 +345,8 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
345345 Type operandETy = getElementTypeOrSelf (operandTy);
346346 Type resultETy = getElementTypeOrSelf (resultTy);
347347
348- if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
349- !llvm::isa<Float32Type>(resultETy)) {
350- return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN to F32" );
348+ if (!isa<Float4E2M1FNType>(operandETy)) {
349+ return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
351350 }
352351
353352 Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
@@ -357,8 +356,9 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
357356 Value bitcast = b.create <arith::BitcastOp>(i4Ty, operand);
358357
359358 Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
360- Value c0x0000001c = createConst (op->getLoc (), i32Ty, 28 , rewriter);
361359 Value c0x00000014 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
360+ Value c0x00000015 = createConst (op->getLoc (), i32Ty, 23 , rewriter);
361+ Value c0x0000001c = createConst (op->getLoc (), i32Ty, 28 , rewriter);
362362 Value cZero =
363363 createFloatConst (op->getLoc (), f32Ty, APFloat (0 .0f ), rewriter);
364364 Value cHalf =
@@ -370,29 +370,33 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
370370
371371 Value f4SignBit = b.create <arith::AndIOp>(bitcast, signBitmask);
372372 Value f32Bits = b.create <arith::ExtUIOp>(i32Ty, f4SignBit);
373- f32Bits = b.create <arith::ShRUIOp >(f32Bits, c0x0000001c);
373+ f32Bits = b.create <arith::ShLIOp >(f32Bits, c0x0000001c);
374374
375375 Value biasAdjustment = createConst (op.getLoc (), i32Ty, 126 , rewriter);
376376 Value f4ExpBits = b.create <arith::AndIOp>(bitcast, exponentBitmask);
377377 f4ExpBits = b.create <arith::ShRUIOp>(f4ExpBits, c0x1);
378378 Value f32ExpBits = b.create <arith::ExtUIOp>(i32Ty, f4ExpBits);
379379 f32ExpBits = b.create <arith::AddIOp>(f32ExpBits, biasAdjustment);
380- f32ExpBits = b.create <arith::ShLIOp>(f32ExpBits, c0x00000014 );
381- f32Bits = b.create <arith::AddIOp>(f32Bits, f32ExpBits );
380+ Value f32Exp = b.create <arith::ShLIOp>(f32ExpBits, c0x00000015 );
381+ f32Bits = b.create <arith::AddIOp>(f32Bits, f32Exp );
382382
383383 Value f4ManBit = b.create <arith::AndIOp>(bitcast, mantissaBitmask);
384384 Value f32ManBit = b.create <arith::ExtUIOp>(i32Ty, f4ManBit);
385+ f32ManBit = b.create <arith::ShLIOp>(f32ManBit, c0x00000014);
385386 f32Bits = b.create <arith::AddIOp>(f32Bits, f32ManBit);
386387
387- // Special consideration for subnormal exp (exp == 0 ).
388+ // Special consideration for subnormal exponent (exp == 00 ).
388389 Value isSubnormal = b.create <arith::CmpIOp>(arith::CmpIPredicate::eq,
389390 f32ExpBits, biasAdjustment);
390391 Value isManSet =
391392 b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
392393 Value subnormalVal = b.create <arith::SelectOp>(isManSet, cHalf, cZero);
393- f32Bits = b.create <arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
394394
395395 Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
396+ result = b.create <arith::SelectOp>(isSubnormal, subnormalVal, result);
397+ if (!isa<Float32Type>(resultETy)) {
398+ result = b.create <arith::TruncFOp>(resultETy, operand);
399+ }
396400 rewriter.replaceOp (op, result);
397401 return success ();
398402 }
@@ -481,8 +485,11 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
481485 Type operandETy = getElementTypeOrSelf (operandTy);
482486 Type resultETy = getElementTypeOrSelf (resultTy);
483487
484- if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
485- return rewriter.notifyMatchFailure (op, " not a trunc of F32 to F4E2M1FN" );
488+ if (!isa<Float32Type>(operandETy)) {
489+ operand = b.create <arith::ExtFOp>(b.getF32Type (), operand);
490+ }
491+ if (!isa<Float4E2M1FNType>(resultETy)) {
492+ return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
486493 }
487494
488495 Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
@@ -491,20 +498,28 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
491498 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
492499
493500 Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
501+ Value c0x3 = createConst (op->getLoc (), i4Ty, 3 , rewriter);
494502 Value c0x00000016 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
495503 Value c0x00 = createConst (op.getLoc (), i8Ty, 0x00 , rewriter);
496504 Value c0xff = createConst (op.getLoc (), i8Ty, 0xff , rewriter);
497505 Value c0x00000000 = createConst (op.getLoc (), i32Ty, 0 , rewriter);
498506
499- // Step 1 : Clamp to bounds.
507+ // Step 0 : Clamp to bounds.
500508 Value cHigherBound =
501509 createFloatConst (op->getLoc (), f32Ty, APFloat (6 .0f ), rewriter);
502510 Value cLowerBound =
503511 createFloatConst (op->getLoc (), f32Ty, APFloat (-6 .0f ), rewriter);
504- Value operandClamped = b.create <arith::MinimumFOp>(cLowerBound , operand);
505- operandClamped = b.create <arith::MaximumFOp>(cHigherBound , operandClamped);
512+ Value operandClamped = b.create <arith::MinimumFOp>(cHigherBound , operand);
513+ operandClamped = b.create <arith::MaximumFOp>(cLowerBound , operandClamped);
506514 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
507515
516+ // Step 1: Set sign bit.
517+ Value cF32ExpManWidth =
518+ createConst (op->getLoc (), i32Ty, 31 , rewriter); // 31 (8 exponent + 23 mantissa bits)
519+ Value f32Sign = b.create <arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
520+ Value f4Sign = b.create <arith::TruncIOp>(i4Ty, f32Sign);
521+ Value f4Bits = b.create <arith::ShLIOp>(f4Sign, c0x3);
522+
508523 // Step 2: Convert exponent by adjusting bias.
509524 Value biasAdjustment = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
510525 Value cF4MantissaWidth = c0x1; // 1
@@ -513,16 +528,17 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
513528 Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
514529 Value biasAdjustedSignExp =
515530 b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
516- Value f4SignExp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
517- f4SignExp = b.create <arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
531+ Value f4Exp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
532+ f4Exp = b.create <arith::ShLIOp>(f4Exp, cF4MantissaWidth);
533+ f4Bits = b.create <arith::AddIOp>(f4Bits, f4Exp);
518534
519535 // Step 3: Set mantissa to first bit.
520536 Value cF32FirstBitMask =
521537 createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
522538 Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
523539 man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
524540 Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
525- Value f4Bits = b.create <arith::AddIOp>(f4SignExp , f4Man);
541+ f4Bits = b.create <arith::AddIOp>(f4Bits , f4Man);
526542
527543 // Step 4: Special consideration for conversion to 0.5.
528544 Value cF32MantissaMask =
@@ -538,7 +554,6 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
538554 Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
539555 Value isZeroExp =
540556 b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
541-
542557 Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
543558 Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
544559 Value subResult =
@@ -719,16 +734,24 @@ struct ArithExpandOpsPass
719734 if (includeF8E8M0) {
720735 arith::populateExpandF8E8M0Patterns (patterns);
721736 }
737+ if (includeF4E2M1) {
738+ arith::populateExpandF4E2M1Patterns (patterns);
739+ }
722740
723741 target.addDynamicallyLegalOp <arith::ExtFOp>(
724742 [=](arith::ExtFOp op) {
725743 Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
726744 Type outETy = getElementTypeOrSelf (op.getType ());
727745 bool legalTypes = true ;
728- if (includeBf16)
746+ if (includeBf16) {
729747 legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
730- if (includeF8E8M0)
748+ }
749+ if (includeF8E8M0) {
731750 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
751+ }
752+ if (includeF4E2M1) {
753+ legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
754+ }
732755 return legalTypes;
733756 });
734757
@@ -737,10 +760,15 @@ struct ArithExpandOpsPass
737760 Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
738761 Type outETy = getElementTypeOrSelf (op.getType ());
739762 bool legalTypes = true ;
740- if (includeBf16)
763+ if (includeBf16) {
741764 legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
742- if (includeF8E8M0)
765+ }
766+ if (includeF8E8M0) {
743767 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
768+ }
769+ if (includeF4E2M1) {
770+ legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
771+ }
744772 return legalTypes;
745773 });
746774
@@ -765,6 +793,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
765793 patterns.getContext ());
766794}
767795
796+ void mlir::arith::populateExpandF4E2M1Patterns (RewritePatternSet &patterns) { // Adds the F4E2M1FN expansion rewrites to `patterns`.
797+ patterns.add <F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>( // One pattern rewrites arith.extf, the other arith.truncf.
798+ patterns.getContext ()); // Both patterns are built with the context owned by the pattern set.
799+ }
800+
768801void mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
769802 patterns.add <F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
770803 patterns.getContext ());
0 commit comments