Skip to content

Commit 814d53b

Browse files
improving extf implementation
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 607d6b1 commit 814d53b

File tree

1 file changed

+102
-148
lines changed

1 file changed

+102
-148
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 102 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1212
#include "mlir/IR/BuiltinTypeInterfaces.h"
1313
#include "mlir/IR/ImplicitLocOpBuilder.h"
14+
#include "mlir/IR/Location.h"
1415
#include "mlir/IR/TypeUtilities.h"
1516
#include "mlir/Transforms/DialectConversion.h"
1617
#include "llvm/ADT/SmallVectorExtras.h"
18+
#include <cstdint>
1719

1820
namespace mlir {
1921
namespace arith {
@@ -333,133 +335,92 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
333335
}
334336
};
335337

338+
/// In this implementation of extf we take advantage of some key patterns we
339+
/// notice between the binary representation of an F4E2M1 value and its
340+
/// corresponding value in F32.
341+
///
342+
/// Note: x is sign bit
343+
/// | Binary | F4E2M1 | f32[23:32]
344+
/// | x000 | 0.0 | x000 0000 00
345+
/// | x001 | 0.5 | x011 1111 00
346+
/// | x010 | 1.0 | x011 1111 10
347+
/// | x011 | 1.5 | x011 1111 11
348+
/// | x100 | 2.0 | x010 0000 00
349+
/// | x101 | 3.0 | x010 0000 01
350+
/// | x110 | 4.0 | x010 0000 10
351+
/// | x111 | 6.0 | x010 0000 11
352+
///
353+
/// 1) There are only two versions of bits [25:31] in the f32 result
354+
/// F4E2M1 bits[2:3] decide whether:
355+
/// - F32 bits[25:31] = 0011 1111
356+
/// - F32 bits[25:31] = 0010 0000
357+
/// Exception is zero where
358+
/// - F32 bits[25:31] = 0000 0000
359+
///
360+
/// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
361+
/// Exception is 0.5 where
362+
/// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
363+
///
364+
/// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
365+
///
366+
/// 4) F32 bits[1:22] = 0
336367
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
337368
using OpRewritePattern::OpRewritePattern;
338-
F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
339-
: OpRewritePattern<arith::ExtFOp>(context, benefit) {}
340369
LogicalResult matchAndRewrite(arith::ExtFOp op,
341370
PatternRewriter &rewriter) const final {
342-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
343-
Value operand = op.getOperand();
344-
Type operandTy = operand.getType();
345-
Type resultTy = op.getType();
346-
Type operandETy = getElementTypeOrSelf(operandTy);
347-
Type resultETy = getElementTypeOrSelf(resultTy);
348-
349-
if (!isa<Float4E2M1FNType>(operandETy)) {
350-
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
351-
}
352-
353-
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
354-
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
355-
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
356-
357-
Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
358-
359-
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
360-
Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
361-
Value c0x00000015 = createConst(op->getLoc(), i32Ty, 23, rewriter);
362-
Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
363-
Value cZero =
364-
createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
365-
Value cHalf =
366-
createFloatConst(op->getLoc(), f32Ty, APFloat(0.5f), rewriter);
367-
368-
Value mantissaBitmask = c0x1;
369-
Value exponentBitmask = createConst(op.getLoc(), i4Ty, 0x6, rewriter);
370-
Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
371-
372-
Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
373-
Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
374-
f32Bits = b.create<arith::ShLIOp>(f32Bits, c0x0000001c);
375-
376-
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
377-
Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
378-
f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
379-
Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
380-
f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
381-
Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000015);
382-
f32Bits = b.create<arith::AddIOp>(f32Bits, f32Exp);
383-
384-
Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
385-
Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
386-
f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000014);
387-
f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
388-
389-
// Special consideration for subnormal exponent (exp == 00).
390-
Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
391-
f32ExpBits, biasAdjustment);
392-
Value isManSet =
393-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
394-
Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
395-
396-
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
397-
result = b.create<arith::SelectOp>(isSubnormal, subnormalVal, result);
398-
if (!isa<Float32Type>(resultETy)) {
399-
result = b.create<arith::TruncFOp>(resultETy, operand);
400-
}
401-
rewriter.replaceOp(op, result);
402-
return success();
403-
}
404-
};
405-
406-
struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
407-
using OpRewritePattern::OpRewritePattern;
408-
ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
409-
: OpRewritePattern<arith::ExtFOp>(context, benefit) {}
410-
LogicalResult matchAndRewrite(arith::ExtFOp op,
411-
PatternRewriter &rewriter) const final {
412-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
371+
Location loc = op.getLoc();
372+
ImplicitLocOpBuilder b(loc, rewriter);
413373
Value operand = op.getOperand();
414374
Type operandTy = operand.getType();
415375
Type resultTy = op.getType();
416376
Type operandETy = getElementTypeOrSelf(operandTy);
417377
Type resultETy = getElementTypeOrSelf(resultTy);
418378

419-
if (isa<ShapedType>(operandTy))
420-
return failure();
421-
422379
if (!isa<Float4E2M1FNType>(operandETy))
423380
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
424381

425-
SmallVector<int> values = {
426-
0x00000000, // 0.0
427-
0x3f000000, // 0.5
428-
0x3f800000, // 1.0
429-
0x3fc00000, // 1.5
430-
0x40000000, // 2.0
431-
0x40400000, // 3.0
432-
0x40800000, // 4.0
433-
0x40c00000 // 6.0
434-
};
435-
// auto type = RankedTensorType::get({8}, b.getI32Type());
436-
VectorType type = VectorType::get({8}, b.getI32Type());
437-
SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
438-
values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
439-
Value lookupTable = b.create<arith::ConstantOp>(
440-
DenseIntElementsAttr::get(type, lookupTableAttr));
441-
442382
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
443383
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
444384
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
445-
Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
446-
447385
Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
448386

449-
Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
450-
Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
451-
Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
452-
Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
453-
454-
Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
455-
Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
456-
Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
457-
Value signBitI32 = b.create<arith::ExtUIOp>(i32Ty, signBitI4);
458-
signBitI32 = b.create<arith::ShLIOp>(signBitI32, c0x0000001c);
459-
460-
Value unsignedBits = b.create<vector::ExtractOp>(lookupTable, index);
461-
Value f32Bits = b.create<arith::OrIOp>(signBitI32, unsignedBits);
462-
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
387+
Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
388+
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
389+
Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
390+
Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
391+
392+
// Set last Exponent bit and Mantissa.
393+
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
394+
Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
395+
Value isHalf =
396+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
397+
bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
398+
bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
399+
bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
400+
401+
// Set first 7 bits of Exponent.
402+
Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
403+
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
404+
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
405+
Value useLargerExp =
406+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
407+
Value bits25To31 =
408+
b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
409+
Value zeroExp =
410+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
411+
bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
412+
413+
// Set sign.
414+
Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
415+
Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
416+
Value negative =
417+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
418+
Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
419+
420+
// Add segments together.
421+
Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
422+
Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
423+
Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
463424
if (!isa<Float32Type>(resultETy))
464425
result = b.create<arith::TruncFOp>(resultETy, operand);
465426

@@ -522,15 +483,15 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
522483
/// Table of representable values in F4E2M1:
523484
///
524485
/// Note: x is sign bit
525-
/// | Binary | Value ( + / - )
526-
/// | x000 | 0.0
527-
/// | x001 | 0.5
528-
/// | x010 | 1.0
529-
/// | x011 | 1.5
530-
/// | x100 | 2.0
531-
/// | x101 | 3.0
532-
/// | x110 | 4.0
533-
/// | x111 | 6.0
486+
/// | Binary | F4E2M1 | F32[23:32]
487+
/// | x000 | 0.0 | x000 0000 00
488+
/// | x001 | 0.5 | x011 1111 00
489+
/// | x010 | 1.0 | x011 1111 10
490+
/// | x011 | 1.5 | x011 1111 11
491+
/// | x100 | 2.0 | x010 0000 00
492+
/// | x101 | 3.0 | x010 0000 01
493+
/// | x110 | 4.0 | x010 0000 10
494+
/// | x111 | 6.0 | x010 0000 11
534495
///
535496
/// Conversion procedure:
536497
/// Step 1: Clamp to representable bounds.
@@ -543,7 +504,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
543504
using OpRewritePattern::OpRewritePattern;
544505
LogicalResult matchAndRewrite(arith::TruncFOp op,
545506
PatternRewriter &rewriter) const final {
546-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
507+
Location loc = op.getLoc();
508+
ImplicitLocOpBuilder b(loc, rewriter);
547509
Value operand = op.getOperand();
548510
Type operandTy = operand.getType();
549511
Type resultTy = op.getType();
@@ -560,34 +522,30 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
560522
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
561523
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
562524

563-
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
564-
Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
565-
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
566-
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
567-
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
568-
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
525+
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
526+
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
527+
Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
528+
Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
529+
Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
530+
Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
569531

570532
// Step 0: Clamp to bounds.
571-
Value cHigherBound =
572-
createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
573-
Value cLowerBound =
574-
createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
533+
Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
534+
Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
575535
Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
576536
operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
577537
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
578538

579539
// Step 1: Set sign bit.
580-
Value cF32ExpManWidth =
581-
createConst(op->getLoc(), i32Ty, 31, rewriter); // 23
540+
Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
582541
Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
583542
Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
584543
Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
585544

586545
// Step 2: Convert exponent by adjusting bias.
587-
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
588-
Value cF4MantissaWidth = c0x1; // 1
589-
Value cF32MantissaWidth =
590-
createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
546+
Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
547+
Value cF4MantissaWidth = c0x1; // 1
548+
Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
591549
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
592550
Value biasAdjustedSignExp =
593551
b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
@@ -596,39 +554,35 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
596554
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
597555

598556
// Step 3: Set mantissa to first bit.
599-
Value cF32FirstBitMask =
600-
createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
557+
Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
601558
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
602559
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
603560
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
604561
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
605562

606563
// Step 4: Special consideration for conversion to 0.5.
607-
Value cF32MantissaMask =
608-
createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
564+
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
609565
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
610566
Value isSubnormal =
611567
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
612568
Value isNegOneExp =
613569
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
614570
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
615571
Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
616-
man23Bits, c0x00000000);
572+
man23Bits, zeroExpBits);
617573
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
618574
Value isZeroExp =
619575
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
620-
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
621-
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
576+
Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
577+
Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
622578
Value subResult =
623579
b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
624580
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
625581
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
626582

627583
// Step 5: Round up if necessary.
628-
Value cF32Last22BitMask =
629-
createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
630-
Value cRound =
631-
createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
584+
Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
585+
Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
632586
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
633587
Value shouldRound =
634588
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
@@ -848,8 +802,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
848802
}
849803

850804
void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
851-
patterns.add<F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
852-
F4E2M1TruncFOpConverter>(patterns.getContext());
805+
patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
806+
patterns.getContext());
853807
}
854808

855809
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {

0 commit comments

Comments
 (0)