@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
3434 return rewriter.create <arith::ConstantOp>(loc, attr);
3535}
3636
37+ // / Create an float constant.
38+ static Value createFloatConst (Location loc, Type type, float value,
39+ PatternRewriter &rewriter) {
40+ auto attr = rewriter.getFloatAttr (getElementTypeOrSelf (type), value);
41+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42+ return rewriter.create <arith::ConstantOp>(
43+ loc, DenseElementsAttr::get (shapedTy, attr));
44+ }
45+
46+ return rewriter.create <arith::ConstantOp>(loc, attr);
47+ }
48+
3749// / Creates shapedType using shape from cloneFrom and base type from cloneTo
3850static Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
3951 if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -439,6 +451,13 @@ Note: x is sign bit
439451| x111 | 6.0
440452
441453Conversion procedure:
454+
455+ Step 1: Clamp to max f4 value
456+
457+ Step 2: convert exponent, if signed int comparison <= 0, set 0
458+
459+ Step 3: if mantissa[1:] greater than 1000000, add 1
460+
442461Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
443462Create bias adjusted exponent, E_1 <- E_0 - 126
444463If E_0 <= 0111 1110
@@ -485,32 +504,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
485504
486505 // Constants
487506 Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
488- Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
489507 Value c0x7e = createConst (op.getLoc (), i8Ty, 0x7e , rewriter);
508+ Value c0x0000007e = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
509+
490510 Value c0x00000009 = createConst (op->getLoc (), i32Ty, 9 , rewriter);
511+ Value c0x00000016 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
491512 Value c0x00000017 = createConst (op->getLoc (), i32Ty, 23 , rewriter);
513+ Value c0x0000001f = createConst (op->getLoc (), i32Ty, 31 , rewriter);
492514 Value c0x00200000 = createConst (op.getLoc (), i32Ty, 0x200000 , rewriter);
493515 Value c0x00400000 = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
494516 Value c0x00600000 = createConst (op.getLoc (), i32Ty, 0x600000 , rewriter);
517+ Value c0x003fffff = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
495518 Value c0x007fffff = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
496-
497- Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
498-
499519 Value cF32MantissaWidth = c0x00000017; // 23
500- Value cF4MantissaWidth = c0x1; // 1
501- Value cF32SignExpWidth = c0x00000009; // 9
520+ Value cF4MantissaWidth = c0x1; // 1
521+ Value cF32SignExpWidth = c0x00000009; // 9
522+ Value cF32FirstBitMask = c0x00400000;
523+ Value cF32Last22BitMask = c0x003fffff;
502524 Value cF32MantissaMask = c0x007fffff;
525+
526+ // Step 1: Clamp to bounds.
527+ Value cHigherBound = createFloatConst (op->getLoc (), f32Ty, 6.0 , rewriter);
528+ Value cLowerBound = createFloatConst (op->getLoc (), f32Ty, -6.0 , rewriter);
529+ Value clampHigh = b.create <arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
530+ Value clampLow = b.create <arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
531+ Value operandClamped = b.create <arith::SelectOp>(clampHigh, cHigherBound, operand);
532+ operandClamped = b.create <arith::SelectOp>(clampLow, cLowerBound, operandClamped);
533+ Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
534+
535+ // Step 2: Convert exponent by adjusting bias.
503536 Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
504- Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
505- Value exp8Bits = b.create <arith::TruncIOp>(i8Ty, f32SignExp);
537+ Value biasAdjustment = c0x0000007e; // 126
538+ Value biasAdjustedSignExp = b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
539+ Value f4SignExp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
540+ f4SignExp = b.create <arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
506541
507- Value cSubnormalExp = c0x7e; // 126
542+ // Step 0: Special consideration for conversion to 0.5.
543+ Value cSubnormalLowerBound = createFloatConst (op->getLoc (), f32Ty, 0.25 , rewriter);
544+ Value cSubnormalHigherBound = createFloatConst (op->getLoc (), f32Ty, 0.75 , rewriter);
545+ Value cLowerBound = createConst (op->getLoc (), f32Ty, -6.0 , rewriter);
546+ Value isSubnormal =
547+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
548+
549+ // Step 3: Set mantissa to first bit.
550+ Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
551+ Value man1Bit = b.create <arith::ShRUIOp>(man23Bits, c0x00000016);
552+ Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
553+ Value f4Bits = b.create <arith::AddIOp>(f4SignExp, f4Man);
554+
555+ // Step 4: Round up if necessary.
556+ Value cRound = c0x00200000; // 010 0000...
557+ Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
558+ Value shouldRound =
559+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
560+ Value roundedF4Bits = b.create <arith::AddIOp>(f4Bits, c0x1);
561+ f4Bits = b.create <arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
508562
509563 // Regular case
510- Value biasAdjustment = c0x7e; // 126
511- Value cRoundUp = c0x00600000; // 110 0000...
512- Value cRoundDown = c0x00200000; // 010 0000...
513- Value biasAdjustedExp = b.create <arith::SubIOp>(exp8Bits, biasAdjustment);
564+
565+
566+
514567 Value f4Exp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedExp);
515568 Value f4ExpRounded = b.create <arith::AddIOp>(f4Exp, c0x1);
516569 // If we round up or down to even, set mantissa to 0
0 commit comments