Skip to content

Commit 2c20ce6

Browse files
adding extf implementation
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 4c672c3 commit 2c20ce6

File tree

1 file changed

+94
-23
lines changed

1 file changed

+94
-23
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 94 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,70 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
334334
}
335335
};
336336

337+
/// Expands `arith.extf` from F4E2M1FN to F32 into integer bit manipulation.
///
/// F4E2M1FN is laid out as 1 sign bit (bit 3), 2 exponent bits (bits 2..1,
/// bias 1) and 1 mantissa bit (bit 0). F32 is 1 sign bit (bit 31), 8 exponent
/// bits (bits 30..23, bias 127) and 23 mantissa bits (bits 22..0).
///
/// Conversion procedure:
///   Step 1: Move the sign bit from bit 3 to bit 31.
///   Step 2: Convert the exponent by adjusting the bias (+126) and place it
///           at bits 30..23.
///   Step 3: Place the single mantissa bit at bit 22.
///   Step 4: Special consideration for a zero exponent field, which encodes
///           only 0.0 (mantissa clear) or 0.5 (mantissa set).
///
/// The pattern only matches when the operand element type is F4E2M1FN and the
/// result element type is F32; anything else is reported as a match failure.
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(op.getType());

    if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
        !llvm::isa<Float32Type>(resultETy)) {
      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
    }

    // Integer / float types with the same shape as the operand, so the
    // expansion works elementwise on shaped operands as well as scalars.
    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

    Value f4Bits = b.create<arith::BitcastOp>(i4Ty, operand);

    // 4-bit constants: field masks for the F4E2M1FN encoding.
    Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
    Value mantissaBitmask = c0x1;
    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);

    // 32-bit constants: shift amounts and F32 bit patterns.
    Value cSignShift = createConst(loc, i32Ty, 28, rewriter); // bit 3 -> 31
    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter);
    Value cF32FirstManBit = createConst(loc, i32Ty, 22, rewriter);
    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter); // 127 - 1
    Value cZeroBits = createConst(loc, i32Ty, 0x00000000, rewriter); // 0.0f
    Value cHalfBits = createConst(loc, i32Ty, 0x3f000000, rewriter); // 0.5f

    // Step 1: Move the sign bit into the F32 sign position.
    Value f4SignBit = b.create<arith::AndIOp>(f4Bits, signBitmask);
    Value f32SignBit = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
    f32SignBit = b.create<arith::ShLIOp>(f32SignBit, cSignShift);

    // Step 2: Convert exponent by adjusting bias.
    Value f4ExpBits = b.create<arith::AndIOp>(f4Bits, exponentBitmask);
    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
    f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, cF32MantissaWidth);

    // Step 3: The only F4 mantissa bit becomes the top F32 mantissa bit.
    Value f4ManBit = b.create<arith::AndIOp>(f4Bits, mantissaBitmask);
    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, cF32FirstManBit);
    Value normalBits = b.create<arith::OrIOp>(f32ExpBits, f32ManBit);

    // Step 4: Special consideration for subnormal exp (exp == 0). The check
    // is done on the unbiased 4-bit exponent, and the replacement values are
    // i32 bit patterns so both select operands share a type.
    Value isSubnormal =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ExpBits, c0x0);
    Value isManSet =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
    Value subnormalBits =
        b.create<arith::SelectOp>(isManSet, cHalfBits, cZeroBits);
    Value magnitudeBits =
        b.create<arith::SelectOp>(isSubnormal, subnormalBits, normalBits);

    // Re-attach the sign last so that -0.0 and -0.5 keep their sign.
    Value f32Bits = b.create<arith::OrIOp>(f32SignBit, magnitudeBits);

    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
    rewriter.replaceOp(op, result);
    return success();
  }
};
400+
337401
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
338402
using OpRewritePattern::OpRewritePattern;
339403
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -389,7 +453,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
389453
/// Table of representable values in F4E2M1:
390454
///
391455
/// Note: x is sign bit
392-
/// | Binary | Value ( + / - )
456+
/// | Binary | Value ( + / - )
393457
/// | x000 | 0.0
394458
/// | x001 | 0.5
395459
/// | x010 | 1.0
@@ -399,12 +463,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
399463
/// | x110 | 4.0
400464
/// | x111 | 6.0
401465
///
402-
/// Conversion procedure:
466+
/// Conversion procedure:
403467
/// Step 1: Clamp to representable bounds.
404468
/// Step 2: Convert exponent by adjusting bias.
405469
/// Step 3: Set mantissa to first bit.
406470
/// Step 4: Special consideration for subnormal and zero exponent.
407-
/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
471+
/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
472+
/// subnormal.
408473
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
409474
using OpRewritePattern::OpRewritePattern;
410475
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -442,48 +507,54 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
442507

443508
// Step 2: Convert exponent by adjusting bias.
444509
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
445-
Value cF4MantissaWidth = c0x1; // 1
446-
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
510+
Value cF4MantissaWidth = c0x1; // 1
511+
Value cF32MantissaWidth =
512+
createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
447513
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
448-
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
514+
Value biasAdjustedSignExp =
515+
b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
449516
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
450517
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
451-
518+
452519
// Step 3: Set mantissa to first bit.
453-
Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
520+
Value cF32FirstBitMask =
521+
createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
454522
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
455523
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
456524
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
457525
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
458-
526+
459527
// Step 4: Special consideration for conversion to 0.5.
460-
Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
528+
Value cF32MantissaMask =
529+
createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
461530
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462531
Value isSubnormal =
463-
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
532+
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464533
Value isNegOneExp =
465-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
534+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466535
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
467-
Value isNonZeroMan =
468-
b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
536+
Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
537+
man23Bits, c0x00000000);
469538
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
470-
Value isZeroExp =
471-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
472-
539+
Value isZeroExp =
540+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
541+
473542
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
474543
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
475-
Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
544+
Value subResult =
545+
b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476546
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477547
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
478-
548+
479549
// Step 5: Round up if necessary.
480-
Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
481-
Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
550+
Value cF32Last22BitMask =
551+
createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
552+
Value cRound =
553+
createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
482554
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
483555
Value shouldRound =
484556
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
485-
shouldRound =
486-
b.create<arith::OrIOp>(shouldRound, isSubnormal);
557+
shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
487558
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
488559
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
489560

0 commit comments

Comments
 (0)