1111#include " mlir/Dialect/Vector/IR/VectorOps.h"
1212#include " mlir/IR/BuiltinTypeInterfaces.h"
1313#include " mlir/IR/ImplicitLocOpBuilder.h"
14+ #include " mlir/IR/Location.h"
1415#include " mlir/IR/TypeUtilities.h"
1516#include " mlir/Transforms/DialectConversion.h"
17+ #include " llvm/ADT/SmallVectorExtras.h"
18+ #include < cstdint>
1619
1720namespace mlir {
1821namespace arith {
@@ -34,6 +37,18 @@ static Value createConst(Location loc, Type type, int value,
3437 return rewriter.create <arith::ConstantOp>(loc, attr);
3538}
3639
40+ // / Create a float constant.
41+ static Value createFloatConst (Location loc, Type type, APFloat value,
42+ PatternRewriter &rewriter) {
43+ auto attr = rewriter.getFloatAttr (getElementTypeOrSelf (type), value);
44+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45+ return rewriter.create <arith::ConstantOp>(
46+ loc, DenseElementsAttr::get (shapedTy, attr));
47+ }
48+
49+ return rewriter.create <arith::ConstantOp>(loc, attr);
50+ }
51+
3752// / Creates shapedType using shape from cloneFrom and base type from cloneTo
3853static Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
3954 if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -322,6 +337,100 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
322337 }
323338};
324339
340+ // / In this implementation of extf we take advantage of some key patterns we
341+ // / notice between the binary representation of an F4E2M1 value and its
342+ // / corresponding value in F32.
343+ // /
344+ // / Note: x is sign bit
345+ // / | Binary | F4E2M1 | f32[23:32]
346+ // / | x000 | 0.0 | x000 0000 00
347+ // / | x001 | 0.5 | x011 1111 00
348+ // / | x010 | 1.0 | x011 1111 10
349+ // / | x011 | 1.5 | x011 1111 11
350+ // / | x100 | 2.0 | x010 0000 00
351+ // / | x101 | 3.0 | x010 0000 01
352+ // / | x110 | 4.0 | x010 0000 10
353+ // / | x111 | 6.0 | x010 0000 11
354+ // /
355+ // / 1) There are only two versions of bits [25:31] in the f32 result
356+ // / F4E2M1 bits[2:3] decide whether:
357+ // / - F32 bits[25:31] = 0011 1111
358+ // / - F32 bits[25:31] = 0010 0000
359+ // / Exception is zero where
360+ // / - F32 bits[25:31] = 0000 0000
361+ // /
362+ // / 2) F4E2M1 bits[1:2] = F32 bits[23:24]
363+ // / Exception is 0.5 where
364+ // / - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
365+ // /
366+ // / 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
367+ // /
368+ // / 4) F32 bits[1:22] = 0
369+ struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
370+ using OpRewritePattern::OpRewritePattern;
371+ LogicalResult matchAndRewrite (arith::ExtFOp op,
372+ PatternRewriter &rewriter) const final {
373+ Location loc = op.getLoc ();
374+ ImplicitLocOpBuilder b (loc, rewriter);
375+ Value operand = op.getOperand ();
376+ Type operandTy = operand.getType ();
377+ Type resultTy = op.getType ();
378+ Type operandETy = getElementTypeOrSelf (operandTy);
379+ Type resultETy = getElementTypeOrSelf (resultTy);
380+
381+ if (!isa<Float4E2M1FNType>(operandETy))
382+ return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
383+
384+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
385+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
386+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
387+ Value i4Bits = b.create <arith::BitcastOp>(i4Ty, operand);
388+
389+ Value c0x0 = createConst (loc, i4Ty, 0x0 , rewriter);
390+ Value c0x1 = createConst (loc, i4Ty, 0x1 , rewriter);
391+ Value c0x2 = createConst (loc, i4Ty, 0x2 , rewriter);
392+ Value c0x4 = createConst (loc, i4Ty, 0x4 , rewriter);
393+
394+ // Set last Exponent bit and Mantissa.
395+ Value c0x00000014 = createConst (loc, i32Ty, 0x14 , rewriter);
396+ Value bits1To24 = b.create <arith::ShLIOp>(i4Bits, c0x2);
397+ Value isHalf =
398+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
399+ bits1To24 = b.create <arith::SelectOp>(isHalf, c0x0, bits1To24);
400+ bits1To24 = b.create <arith::ExtUIOp>(i32Ty, bits1To24);
401+ bits1To24 = b.create <arith::ShLIOp>(bits1To24, c0x00000014);
402+
403+ // Set first 7 bits of Exponent.
404+ Value zeroExpBits = createConst (loc, i32Ty, 0x00000000 , rewriter);
405+ Value highExpBits = createConst (loc, i32Ty, 0x40000000 , rewriter);
406+ Value lowExpBits = createConst (loc, i32Ty, 0x3f000000 , rewriter);
407+ Value useLargerExp =
408+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
409+ Value bits25To31 =
410+ b.create <arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
411+ Value zeroExp =
412+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
413+ bits25To31 = b.create <arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
414+
415+ // Set sign.
416+ Value c0x80000000 = createConst (loc, i32Ty, 0x80000000 , rewriter);
417+ Value c0x8 = createConst (loc, i4Ty, 0x8 , rewriter);
418+ Value negative =
419+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
420+ Value bit32 = b.create <arith::SelectOp>(negative, c0x80000000, zeroExpBits);
421+
422+ // Add segments together.
423+ Value bits1To31 = b.create <arith::AddIOp>(bits1To24, bits25To31);
424+ Value bits1To32 = b.create <arith::AddIOp>(bits1To31, bit32);
425+ Value result = b.create <arith::BitcastOp>(f32Ty, bits1To32);
426+ if (!isa<Float32Type>(resultETy))
427+ result = b.create <arith::TruncFOp>(resultTy, result);
428+
429+ rewriter.replaceOp (op, result);
430+ return success ();
431+ }
432+ };
433+
325434struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
326435 using OpRewritePattern::OpRewritePattern;
327436 LogicalResult matchAndRewrite (arith::ExtFOp op,
@@ -366,6 +475,130 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
366475 }
367476};
368477
478+ // / Conversion from F32 to F4E2M1 according to the OCP Spec:
479+ // / www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
480+ // /
481+ // / The spec requiers us to perform Round to Nearest, Ties to Even.
482+ // /
483+ // / This means that after rounding, we should break ties by choosing the option
484+ // / which results in a mantissa of 0 in the least significant digit.
485+ // /
486+ // / Table of representable values in F4E2M1:
487+ // /
488+ // / Note: x is sign bit
489+ // / | Binary | F4E2M1 | F32[23:32]
490+ // / | x000 | 0.0 | x000 0000 00
491+ // / | x001 | 0.5 | x011 1111 00
492+ // / | x010 | 1.0 | x011 1111 10
493+ // / | x011 | 1.5 | x011 1111 11
494+ // / | x100 | 2.0 | x010 0000 00
495+ // / | x101 | 3.0 | x010 0000 01
496+ // / | x110 | 4.0 | x010 0000 10
497+ // / | x111 | 6.0 | x010 0000 11
498+ // /
499+ // / Conversion procedure:
500+ // / Step 1: Clamp to representable bounds.
501+ // / Step 2: Convert exponent by adjusting bias.
502+ // / Step 3: Set mantissa to first bit.
503+ // / Step 4: Special consideration for subnormal and zero exponent.
504+ // / Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
505+ // / subnormal.
506+ struct F4E2M1TruncFOpConverter : public OpRewritePattern <arith::TruncFOp> {
507+ using OpRewritePattern::OpRewritePattern;
508+ LogicalResult matchAndRewrite (arith::TruncFOp op,
509+ PatternRewriter &rewriter) const final {
510+ Location loc = op.getLoc ();
511+ ImplicitLocOpBuilder b (loc, rewriter);
512+ Value operand = op.getOperand ();
513+ Type operandTy = operand.getType ();
514+ Type resultTy = op.getType ();
515+ Type operandETy = getElementTypeOrSelf (operandTy);
516+ Type resultETy = getElementTypeOrSelf (resultTy);
517+
518+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
519+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
520+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
521+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
522+
523+ if (!isa<Float32Type>(operandETy))
524+ operand = b.create <arith::ExtFOp>(f32Ty, operand);
525+ if (!isa<Float4E2M1FNType>(resultETy))
526+ return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
527+
528+ Value c0x1 = createConst (loc, i4Ty, 1 , rewriter);
529+ Value c0x3 = createConst (loc, i4Ty, 3 , rewriter);
530+ Value c0x00000016 = createConst (loc, i32Ty, 22 , rewriter);
531+ Value c0x00 = createConst (loc, i8Ty, 0x00 , rewriter);
532+ Value c0xff = createConst (loc, i8Ty, 0xff , rewriter);
533+ Value zeroExpBits = createConst (loc, i32Ty, 0 , rewriter);
534+
535+ // Step 0: Clamp to bounds.
536+ Value cHigherBound = createFloatConst (loc, f32Ty, APFloat (6 .0f ), rewriter);
537+ Value cLowerBound = createFloatConst (loc, f32Ty, APFloat (-6 .0f ), rewriter);
538+ Value operandClamped = b.create <arith::MinNumFOp>(cHigherBound, operand);
539+ operandClamped = b.create <arith::MaxNumFOp>(cLowerBound, operandClamped);
540+ Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
541+
542+ // Step 1: Set sign bit.
543+ Value cF32ExpManWidth = createConst (loc, i32Ty, 31 , rewriter); // 23
544+ Value f32Sign = b.create <arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
545+ Value f4Sign = b.create <arith::TruncIOp>(i4Ty, f32Sign);
546+ Value f4Bits = b.create <arith::ShLIOp>(f4Sign, c0x3);
547+
548+ // Step 2: Convert exponent by adjusting bias.
549+ Value biasAdjustment = createConst (loc, i32Ty, 0x7e , rewriter);
550+ Value cF4MantissaWidth = c0x1; // 1
551+ Value cF32MantissaWidth = createConst (loc, i32Ty, 23 , rewriter); // 23
552+ Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
553+ Value biasAdjustedSignExp =
554+ b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
555+ Value f4Exp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
556+ f4Exp = b.create <arith::ShLIOp>(f4Exp, cF4MantissaWidth);
557+ f4Bits = b.create <arith::AddIOp>(f4Bits, f4Exp);
558+
559+ // Step 3: Set mantissa to first bit.
560+ Value cF32FirstBitMask = createConst (loc, i32Ty, 0x400000 , rewriter);
561+ Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
562+ man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
563+ Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
564+ f4Bits = b.create <arith::AddIOp>(f4Bits, f4Man);
565+
566+ // Step 4: Special consideration for conversion to 0.5.
567+ Value cF32MantissaMask = createConst (loc, i32Ty, 0x7fffff , rewriter);
568+ Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
569+ Value isSubnormal =
570+ b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
571+ Value isNegOneExp =
572+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
573+ Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
574+ Value isNonZeroMan = b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt,
575+ man23Bits, zeroExpBits);
576+ Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
577+ Value isZeroExp =
578+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
579+ Value subnormalF4Bits = createConst (loc, i4Ty, 0xf , rewriter);
580+ Value halfF4Bits = createConst (loc, i4Ty, 0x0 , rewriter);
581+ Value subResult =
582+ b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
583+ subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
584+ f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
585+
586+ // Step 5: Round up if necessary.
587+ Value cF32Last22BitMask = createConst (loc, i32Ty, 0x3fffff , rewriter);
588+ Value cRound = createConst (loc, i32Ty, 0x200000 , rewriter); // 010 0000...
589+ Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
590+ Value shouldRound =
591+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
592+ shouldRound = b.create <arith::OrIOp>(shouldRound, isSubnormal);
593+ Value roundedF4Bits = b.create <arith::AddIOp>(f4Bits, c0x1);
594+ f4Bits = b.create <arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
595+
596+ Value result = b.create <arith::BitcastOp>(resultTy, f4Bits);
597+ rewriter.replaceOp (op, result);
598+ return success ();
599+ }
600+ };
601+
369602/*
370603TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
371604Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
@@ -498,6 +731,8 @@ struct ArithExpandOpsPass
498731 arith::populateArithExpandOpsPatterns (patterns);
499732
500733 target.addLegalDialect <arith::ArithDialect>();
734+ target.addLegalDialect <vector::VectorDialect>();
735+
501736 // clang-format off
502737 target.addIllegalOp <
503738 arith::CeilDivSIOp,
@@ -515,22 +750,24 @@ struct ArithExpandOpsPass
515750 arith::ScalingTruncFOp
516751 >();
517752
518- if (includeBf16) {
753+ if (includeBf16)
519754 arith::populateExpandBFloat16Patterns (patterns);
520- }
521- if (includeF8E8M0) {
755+ if (includeF8E8M0)
522756 arith::populateExpandF8E8M0Patterns (patterns);
523- }
757+ if (includeF4E2M1)
758+ arith::populateExpandF4E2M1Patterns (patterns);
524759
525760 target.addDynamicallyLegalOp <arith::ExtFOp>(
526761 [=](arith::ExtFOp op) {
527762 Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
528763 Type outETy = getElementTypeOrSelf (op.getType ());
529764 bool legalTypes = true ;
530- if (includeBf16)
765+ if (includeBf16)
531766 legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
532767 if (includeF8E8M0)
533768 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
769+ if (includeF4E2M1)
770+ legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
534771 return legalTypes;
535772 });
536773
@@ -539,10 +776,12 @@ struct ArithExpandOpsPass
539776 Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
540777 Type outETy = getElementTypeOrSelf (op.getType ());
541778 bool legalTypes = true ;
542- if (includeBf16)
779+ if (includeBf16)
543780 legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
544- if (includeF8E8M0)
781+ if (includeF8E8M0)
545782 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
783+ if (includeF4E2M1)
784+ legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
546785 return legalTypes;
547786 });
548787
@@ -567,6 +806,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
567806 patterns.getContext ());
568807}
569808
809+ void mlir::arith::populateExpandF4E2M1Patterns (RewritePatternSet &patterns) {
810+ patterns.add <F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
811+ patterns.getContext ());
812+ }
813+
570814void mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
571815 patterns.add <F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
572816 patterns.getContext ());
0 commit comments