1111#include " mlir/Dialect/Vector/IR/VectorOps.h"
1212#include " mlir/IR/BuiltinTypeInterfaces.h"
1313#include " mlir/IR/ImplicitLocOpBuilder.h"
14+ #include " mlir/IR/Location.h"
1415#include " mlir/IR/TypeUtilities.h"
1516#include " mlir/Transforms/DialectConversion.h"
1617#include " llvm/ADT/SmallVectorExtras.h"
18+ #include < cstdint>
1719
1820namespace mlir {
1921namespace arith {
@@ -333,133 +335,92 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
333335 }
334336};
335337
338+ // / In this implementation of extf we take advantage of some key patterns we
339+ // / notice between the binary representation of an F4E2M1 value and its
340+ // / corresponding value in F32.
341+ // /
342+ // / Note: x is sign bit
343+ // / | Binary | F4E2M1 | f32[23:32]
344+ // / | x000 | 0.0 | x000 0000 00
345+ // / | x001 | 0.5 | x011 1111 00
346+ // / | x010 | 1.0 | x011 1111 10
347+ // / | x011 | 1.5 | x011 1111 11
348+ // / | x100 | 2.0 | x010 0000 00
349+ // / | x101 | 3.0 | x010 0000 01
350+ // / | x110 | 4.0 | x010 0000 10
351+ // / | x111 | 6.0 | x010 0000 11
352+ // /
353+ // / 1) There are only two versions of bits [25:31] in the f32 result
354+ // / F4E2M1 bits[2:3] decide whether:
355+ // / - F32 bits[25:31] = 0011 1111
356+ // / - F32 bits[25:31] = 0010 0000
357+ // / Exception is zero where
358+ // / - F32 bits[25:31] = 0000 0000
359+ // /
360+ // / 2) F4E2M1 bits[1:2] = F32 bits[23:24]
361+ // / Exception is 0.5 where
362+ // / - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
363+ // /
364+ // / 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
365+ // /
366+ // / 4) F32 bits[1:22] = 0
336367struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
337368 using OpRewritePattern::OpRewritePattern;
338- F4E2M1ExtFOpConverter (MLIRContext *context, PatternBenefit benefit = 1 )
339- : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
340369 LogicalResult matchAndRewrite (arith::ExtFOp op,
341370 PatternRewriter &rewriter) const final {
342- ImplicitLocOpBuilder b (op.getLoc (), rewriter);
343- Value operand = op.getOperand ();
344- Type operandTy = operand.getType ();
345- Type resultTy = op.getType ();
346- Type operandETy = getElementTypeOrSelf (operandTy);
347- Type resultETy = getElementTypeOrSelf (resultTy);
348-
349- if (!isa<Float4E2M1FNType>(operandETy)) {
350- return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
351- }
352-
353- Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
354- Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
355- Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
356-
357- Value bitcast = b.create <arith::BitcastOp>(i4Ty, operand);
358-
359- Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
360- Value c0x00000014 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
361- Value c0x00000015 = createConst (op->getLoc (), i32Ty, 23 , rewriter);
362- Value c0x0000001c = createConst (op->getLoc (), i32Ty, 28 , rewriter);
363- Value cZero =
364- createFloatConst (op->getLoc (), f32Ty, APFloat (0 .0f ), rewriter);
365- Value cHalf =
366- createFloatConst (op->getLoc (), f32Ty, APFloat (0 .5f ), rewriter);
367-
368- Value mantissaBitmask = c0x1;
369- Value exponentBitmask = createConst (op.getLoc (), i4Ty, 0x6 , rewriter);
370- Value signBitmask = createConst (op.getLoc (), i4Ty, 0x8 , rewriter);
371-
372- Value f4SignBit = b.create <arith::AndIOp>(bitcast, signBitmask);
373- Value f32Bits = b.create <arith::ExtUIOp>(i32Ty, f4SignBit);
374- f32Bits = b.create <arith::ShLIOp>(f32Bits, c0x0000001c);
375-
376- Value biasAdjustment = createConst (op.getLoc (), i32Ty, 126 , rewriter);
377- Value f4ExpBits = b.create <arith::AndIOp>(bitcast, exponentBitmask);
378- f4ExpBits = b.create <arith::ShRUIOp>(f4ExpBits, c0x1);
379- Value f32ExpBits = b.create <arith::ExtUIOp>(i32Ty, f4ExpBits);
380- f32ExpBits = b.create <arith::AddIOp>(f32ExpBits, biasAdjustment);
381- Value f32Exp = b.create <arith::ShLIOp>(f32ExpBits, c0x00000015);
382- f32Bits = b.create <arith::AddIOp>(f32Bits, f32Exp);
383-
384- Value f4ManBit = b.create <arith::AndIOp>(bitcast, mantissaBitmask);
385- Value f32ManBit = b.create <arith::ExtUIOp>(i32Ty, f4ManBit);
386- f32ManBit = b.create <arith::ShLIOp>(f32ManBit, c0x00000014);
387- f32Bits = b.create <arith::AddIOp>(f32Bits, f32ManBit);
388-
389- // Special consideration for subnormal exponent (exp == 00).
390- Value isSubnormal = b.create <arith::CmpIOp>(arith::CmpIPredicate::eq,
391- f32ExpBits, biasAdjustment);
392- Value isManSet =
393- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
394- Value subnormalVal = b.create <arith::SelectOp>(isManSet, cHalf, cZero);
395-
396- Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
397- result = b.create <arith::SelectOp>(isSubnormal, subnormalVal, result);
398- if (!isa<Float32Type>(resultETy)) {
399- result = b.create <arith::TruncFOp>(resultETy, operand);
400- }
401- rewriter.replaceOp (op, result);
402- return success ();
403- }
404- };
405-
406- struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
407- using OpRewritePattern::OpRewritePattern;
408- ScalarF4E2M1ExtFOpConverter (MLIRContext *context, PatternBenefit benefit = 2 )
409- : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
410- LogicalResult matchAndRewrite (arith::ExtFOp op,
411- PatternRewriter &rewriter) const final {
412- ImplicitLocOpBuilder b (op.getLoc (), rewriter);
371+ Location loc = op.getLoc ();
372+ ImplicitLocOpBuilder b (loc, rewriter);
413373 Value operand = op.getOperand ();
414374 Type operandTy = operand.getType ();
415375 Type resultTy = op.getType ();
416376 Type operandETy = getElementTypeOrSelf (operandTy);
417377 Type resultETy = getElementTypeOrSelf (resultTy);
418378
419- if (isa<ShapedType>(operandTy))
420- return failure ();
421-
422379 if (!isa<Float4E2M1FNType>(operandETy))
423380 return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
424381
425- SmallVector<int > values = {
426- 0x00000000 , // 0.0
427- 0x3f000000 , // 0.5
428- 0x3f800000 , // 1.0
429- 0x3fc00000 , // 1.5
430- 0x40000000 , // 2.0
431- 0x40400000 , // 3.0
432- 0x40800000 , // 4.0
433- 0x40c00000 // 6.0
434- };
435- // auto type = RankedTensorType::get({8}, b.getI32Type());
436- VectorType type = VectorType::get ({8 }, b.getI32Type ());
437- SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector (
438- values, [&](int v) -> Attribute { return b.getI32IntegerAttr (v); });
439- Value lookupTable = b.create <arith::ConstantOp>(
440- DenseIntElementsAttr::get (type, lookupTableAttr));
441-
442382 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
443383 Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
444384 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
445- Type i64Ty = cloneToShapedType (operandTy, b.getI64Type ());
446-
447385 Value i4Bits = b.create <arith::BitcastOp>(i4Ty, operand);
448386
449- Value expManBitmask = createConst (op.getLoc (), i4Ty, 0x7 , rewriter);
450- Value indexI4 = b.create <arith::AndIOp>(i4Bits, expManBitmask);
451- Value indexI64 = b.create <arith::ExtUIOp>(i64Ty, indexI4);
452- Value index = b.create <arith::IndexCastOp>(b.getIndexType (), indexI64);
453-
454- Value c0x0000001c = createConst (op->getLoc (), i32Ty, 28 , rewriter);
455- Value signBitmask = createConst (op.getLoc (), i4Ty, 0x8 , rewriter);
456- Value signBitI4 = b.create <arith::AndIOp>(i4Bits, signBitmask);
457- Value signBitI32 = b.create <arith::ExtUIOp>(i32Ty, signBitI4);
458- signBitI32 = b.create <arith::ShLIOp>(signBitI32, c0x0000001c);
459-
460- Value unsignedBits = b.create <vector::ExtractOp>(lookupTable, index);
461- Value f32Bits = b.create <arith::OrIOp>(signBitI32, unsignedBits);
462- Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
387+ Value c0x0 = createConst (loc, i4Ty, 0x0 , rewriter);
388+ Value c0x1 = createConst (loc, i4Ty, 0x1 , rewriter);
389+ Value c0x2 = createConst (loc, i4Ty, 0x2 , rewriter);
390+ Value c0x4 = createConst (loc, i4Ty, 0x4 , rewriter);
391+
392+ // Set last Exponent bit and Mantissa.
393+ Value c0x00000014 = createConst (loc, i32Ty, 0x14 , rewriter);
394+ Value bits1To24 = b.create <arith::ShLIOp>(i4Bits, c0x2);
395+ Value isHalf =
396+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
397+ bits1To24 = b.create <arith::SelectOp>(isHalf, c0x0, bits1To24);
398+ bits1To24 = b.create <arith::ExtUIOp>(i32Ty, bits1To24);
399+ bits1To24 = b.create <arith::ShLIOp>(bits1To24, c0x00000014);
400+
401+ // Set first 7 bits of Exponent.
402+ Value zeroExpBits = createConst (loc, i32Ty, 0x00000000 , rewriter);
403+ Value highExpBits = createConst (loc, i32Ty, 0x40000000 , rewriter);
404+ Value lowExpBits = createConst (loc, i32Ty, 0x3f000000 , rewriter);
405+ Value useLargerExp =
406+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
407+ Value bits25To31 =
408+ b.create <arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
409+ Value zeroExp =
410+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
411+ bits25To31 = b.create <arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
412+
413+ // Set sign.
414+ Value c0x80000000 = createConst (loc, i32Ty, 0x80000000 , rewriter);
415+ Value c0x8 = createConst (loc, i4Ty, 0x8 , rewriter);
416+ Value negative =
417+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
418+ Value bit32 = b.create <arith::SelectOp>(negative, c0x80000000, zeroExpBits);
419+
420+ // Add segments together.
421+ Value bits1To31 = b.create <arith::AddIOp>(bits1To24, bits25To31);
422+ Value bits1To32 = b.create <arith::AddIOp>(bits1To31, bit32);
423+ Value result = b.create <arith::BitcastOp>(f32Ty, bits1To32);
463424 if (!isa<Float32Type>(resultETy))
464425 result = b.create <arith::TruncFOp>(resultETy, operand);
465426
@@ -522,15 +483,15 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
522483// / Table of representable values in F4E2M1:
523484// /
524485// / Note: x is sign bit
525- // / | Binary | Value ( + / - )
526- // / | x000 | 0.0
527- // / | x001 | 0.5
528- // / | x010 | 1.0
529- // / | x011 | 1.5
530- // / | x100 | 2.0
531- // / | x101 | 3.0
532- // / | x110 | 4.0
533- // / | x111 | 6.0
486+ // / | Binary | F4E2M1 | F32[23:32]
487+ // / | x000 | 0.0 | x000 0000 00
488+ // / | x001 | 0.5 | x011 1111 00
489+ // / | x010 | 1.0 | x011 1111 10
490+ // / | x011 | 1.5 | x011 1111 11
491+ // / | x100 | 2.0 | x010 0000 00
492+ // / | x101 | 3.0 | x010 0000 01
493+ // / | x110 | 4.0 | x010 0000 10
494+ // / | x111 | 6.0 | x010 0000 11
534495// /
535496// / Conversion procedure:
536497// / Step 1: Clamp to representable bounds.
@@ -543,7 +504,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
543504 using OpRewritePattern::OpRewritePattern;
544505 LogicalResult matchAndRewrite (arith::TruncFOp op,
545506 PatternRewriter &rewriter) const final {
546- ImplicitLocOpBuilder b (op.getLoc (), rewriter);
507+ Location loc = op.getLoc ();
508+ ImplicitLocOpBuilder b (loc, rewriter);
547509 Value operand = op.getOperand ();
548510 Type operandTy = operand.getType ();
549511 Type resultTy = op.getType ();
@@ -560,34 +522,30 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
560522 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
561523 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
562524
563- Value c0x1 = createConst (op-> getLoc () , i4Ty, 1 , rewriter);
564- Value c0x3 = createConst (op-> getLoc () , i4Ty, 3 , rewriter);
565- Value c0x00000016 = createConst (op-> getLoc () , i32Ty, 22 , rewriter);
566- Value c0x00 = createConst (op. getLoc () , i8Ty, 0x00 , rewriter);
567- Value c0xff = createConst (op. getLoc () , i8Ty, 0xff , rewriter);
568- Value c0x00000000 = createConst (op. getLoc () , i32Ty, 0 , rewriter);
525+ Value c0x1 = createConst (loc , i4Ty, 1 , rewriter);
526+ Value c0x3 = createConst (loc , i4Ty, 3 , rewriter);
527+ Value c0x00000016 = createConst (loc , i32Ty, 22 , rewriter);
528+ Value c0x00 = createConst (loc , i8Ty, 0x00 , rewriter);
529+ Value c0xff = createConst (loc , i8Ty, 0xff , rewriter);
530+ Value zeroExpBits = createConst (loc , i32Ty, 0 , rewriter);
569531
570532 // Step 0: Clamp to bounds.
571- Value cHigherBound =
572- createFloatConst (op->getLoc (), f32Ty, APFloat (6 .0f ), rewriter);
573- Value cLowerBound =
574- createFloatConst (op->getLoc (), f32Ty, APFloat (-6 .0f ), rewriter);
533+ Value cHigherBound = createFloatConst (loc, f32Ty, APFloat (6 .0f ), rewriter);
534+ Value cLowerBound = createFloatConst (loc, f32Ty, APFloat (-6 .0f ), rewriter);
575535 Value operandClamped = b.create <arith::MinNumFOp>(cHigherBound, operand);
576536 operandClamped = b.create <arith::MaxNumFOp>(cLowerBound, operandClamped);
577537 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
578538
579539 // Step 1: Set sign bit.
580- Value cF32ExpManWidth =
581- createConst (op->getLoc (), i32Ty, 31 , rewriter); // 23
540+ Value cF32ExpManWidth = createConst (loc, i32Ty, 31 , rewriter); // 23
582541 Value f32Sign = b.create <arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
583542 Value f4Sign = b.create <arith::TruncIOp>(i4Ty, f32Sign);
584543 Value f4Bits = b.create <arith::ShLIOp>(f4Sign, c0x3);
585544
586545 // Step 2: Convert exponent by adjusting bias.
587- Value biasAdjustment = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
588- Value cF4MantissaWidth = c0x1; // 1
589- Value cF32MantissaWidth =
590- createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
546+ Value biasAdjustment = createConst (loc, i32Ty, 0x7e , rewriter);
547+ Value cF4MantissaWidth = c0x1; // 1
548+ Value cF32MantissaWidth = createConst (loc, i32Ty, 23 , rewriter); // 23
591549 Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
592550 Value biasAdjustedSignExp =
593551 b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
@@ -596,39 +554,35 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
596554 f4Bits = b.create <arith::AddIOp>(f4Bits, f4Exp);
597555
598556 // Step 3: Set mantissa to first bit.
599- Value cF32FirstBitMask =
600- createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
557+ Value cF32FirstBitMask = createConst (loc, i32Ty, 0x400000 , rewriter);
601558 Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
602559 man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
603560 Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
604561 f4Bits = b.create <arith::AddIOp>(f4Bits, f4Man);
605562
606563 // Step 4: Special consideration for conversion to 0.5.
607- Value cF32MantissaMask =
608- createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
564+ Value cF32MantissaMask = createConst (loc, i32Ty, 0x7fffff , rewriter);
609565 Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
610566 Value isSubnormal =
611567 b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
612568 Value isNegOneExp =
613569 b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
614570 Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
615571 Value isNonZeroMan = b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt,
616- man23Bits, c0x00000000 );
572+ man23Bits, zeroExpBits );
617573 Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
618574 Value isZeroExp =
619575 b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
620- Value subnormalF4Bits = createConst (op-> getLoc () , i4Ty, 0xf , rewriter);
621- Value halfF4Bits = createConst (op-> getLoc () , i4Ty, 0x0 , rewriter);
576+ Value subnormalF4Bits = createConst (loc , i4Ty, 0xf , rewriter);
577+ Value halfF4Bits = createConst (loc , i4Ty, 0x0 , rewriter);
622578 Value subResult =
623579 b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
624580 subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
625581 f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
626582
627583 // Step 5: Round up if necessary.
628- Value cF32Last22BitMask =
629- createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
630- Value cRound =
631- createConst (op.getLoc (), i32Ty, 0x200000 , rewriter); // 010 0000...
584+ Value cF32Last22BitMask = createConst (loc, i32Ty, 0x3fffff , rewriter);
585+ Value cRound = createConst (loc, i32Ty, 0x200000 , rewriter); // 010 0000...
632586 Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
633587 Value shouldRound =
634588 b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
@@ -848,8 +802,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
848802}
849803
850804void mlir::arith::populateExpandF4E2M1Patterns (RewritePatternSet &patterns) {
851- patterns.add <F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
852- F4E2M1TruncFOpConverter>( patterns.getContext ());
805+ patterns.add <F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
806+ patterns.getContext ());
853807}
854808
855809void mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
0 commit comments