Skip to content

Commit 63c10e6

Browse files
intermediate commit
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 36b5905 commit 63c10e6

File tree

1 file changed

+66
-13
lines changed

1 file changed

+66
-13
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
3434
return rewriter.create<arith::ConstantOp>(loc, attr);
3535
}
3636

37+
/// Create an float constant.
38+
static Value createFloatConst(Location loc, Type type, float value,
39+
PatternRewriter &rewriter) {
40+
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
41+
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42+
return rewriter.create<arith::ConstantOp>(
43+
loc, DenseElementsAttr::get(shapedTy, attr));
44+
}
45+
46+
return rewriter.create<arith::ConstantOp>(loc, attr);
47+
}
48+
3749
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
3850
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
3951
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -439,6 +451,13 @@ Note: x is sign bit
439451
| x111 | 6.0
440452
441453
Conversion procedure:
454+
455+
Step 1: Clamp to max f4 value
456+
457+
Step 2: convert exponent, if signed int comparison <= 0, set 0
458+
459+
Step 3: if mantissa[1:] greater than 1000000, add 1
460+
442461
Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
443462
Create bias adjusted exponent, E_1 <- E_0 - 126
444463
If E_0 <= 0111 1110
@@ -485,32 +504,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
485504

486505
// Constants
487506
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
488-
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
489507
Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
508+
Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
509+
490510
Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
511+
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
491512
Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
513+
Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
492514
Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
493515
Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
494516
Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
517+
Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
495518
Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
496-
497-
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
498-
499519
Value cF32MantissaWidth = c0x00000017; // 23
500-
Value cF4MantissaWidth = c0x1; // 1
501-
Value cF32SignExpWidth = c0x00000009; // 9
520+
Value cF4MantissaWidth = c0x1; // 1
521+
Value cF32SignExpWidth = c0x00000009; // 9
522+
Value cF32FirstBitMask = c0x00400000;
523+
Value cF32Last22BitMask = c0x003fffff;
502524
Value cF32MantissaMask = c0x007fffff;
525+
526+
// Step 1: Clamp to bounds.
527+
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
528+
Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
529+
Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
530+
Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
531+
Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
532+
operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
533+
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
534+
535+
// Step 2: Convert exponent by adjusting bias.
503536
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
504-
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
505-
Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
537+
Value biasAdjustment = c0x0000007e; // 126
538+
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
539+
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
540+
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
506541

507-
Value cSubnormalExp = c0x7e; // 126
542+
// Step 0: Special consideration for conversion to 0.5.
543+
Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
544+
Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
545+
Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
546+
Value isSubnormal =
547+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
548+
549+
// Step 3: Set mantissa to first bit.
550+
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
551+
Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
552+
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
553+
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
554+
555+
// Step 4: Round up if necessary.
556+
Value cRound = c0x00200000; // 010 0000...
557+
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
558+
Value shouldRound =
559+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
560+
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
561+
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
508562

509563
// Regular case
510-
Value biasAdjustment = c0x7e; // 126
511-
Value cRoundUp = c0x00600000; // 110 0000...
512-
Value cRoundDown = c0x00200000; // 010 0000...
513-
Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
564+
565+
566+
514567
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
515568
Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
516569
// If we round up or down to even, set mantissa to 0

0 commit comments

Comments
 (0)