@@ -334,57 +334,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
334334 }
335335};
336336
337- struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
338- using OpRewritePattern::OpRewritePattern;
339- LogicalResult matchAndRewrite (arith::ExtFOp op,
340- PatternRewriter &rewriter) const final {
341- ImplicitLocOpBuilder b (op.getLoc (), rewriter);
342- Value operand = op.getOperand ();
343- Type operandTy = operand.getType ();
344- Type resultTy = op.getType ();
345- Type operandETy = getElementTypeOrSelf (operandTy);
346- Type resultETy = getElementTypeOrSelf (resultTy);
347-
348- if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
349- return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
350- }
351-
352- Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
353- Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
354- Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
355-
356- Value bitcast = b.create <arith::BitcastOp>(i4Ty, operand);
357-
358- // create constants to extract mantissa / exponent
359- Value cF4MantissaWidth = createConst (op->getLoc (), i32Ty, 1 , rewriter);
360- Value cF4SignAndExpWidth = createConst (op->getLoc (), i32Ty, 3 , rewriter);
361- // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
362-
363- // create constants for NaNs
364- Value cF4NaN = createConst (op.getLoc (), i4Ty, 0xf , rewriter);
365- Value cF32NaN = createConst (op.getLoc (), i32Ty, 0xffffffff , rewriter);
366- Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
367-
368- Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
369- Value f32Exponent = b.create <arith::ShLIOp>(exti, cF4MantissaWidth);
370- Value f32Mantissa = b.create <arith::ShRUIOp>(exti, cF4SignAndExpWidth);
371- Value f32Bits = b.create <arith::ShLIOp>(exti, cF32MantissaWidth);
372-
373- Value isNan =
374- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
375- // select for NaNs
376- f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
377- Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
378- if (resultETy.getIntOrFloatBitWidth () < 32 ) {
379- result = b.create <arith::TruncFOp>(resultTy, result);
380- } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
381- result = b.create <arith::ExtFOp>(resultTy, result);
382- }
383- rewriter.replaceOp (op, result);
384- return success ();
385- }
386- };
387-
388337struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
389338 using OpRewritePattern::OpRewritePattern;
390339 LogicalResult matchAndRewrite (arith::ExtFOp op,
@@ -428,60 +377,34 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
428377 return success ();
429378 }
430379};
431- /*
432- Conversion from F32 to F4E2M1 according to the OCP Spec:
433- www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
434-
435- The spec requiers us to perform Round to Nearest, Ties to Even.
436-
437- This means that after rounding, we should break ties by choosing the option
438- which results in a mantissa of 0 in the least significant digit.
439-
440- Table of representable values in F4E2M1:
441-
442- Note: x is sign bit
443- | Binary | Value ( + / - )
444- | x000 | 0.0
445- | x001 | 0.5
446- | x010 | 1.0
447- | x011 | 1.5
448- | x100 | 2.0
449- | x101 | 3.0
450- | x110 | 4.0
451- | x111 | 6.0
452-
453- Conversion procedure:
454-
455- Step 1: Clamp to max f4 value
456-
457- Step 2: convert exponent, if signed int comparison <= 0, set 0
458-
459- Step 3: if mantissa[1:] greater than 1000000, add 1
460-
461- Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
462- Create bias adjusted exponent, E_1 <- E_0 - 126
463- If E_0 <= 0111 1110
464- M_1 <- 0, E_1 <- 00
465- end
466- if E_1 == 00 (special case for almost subnormal)
467- if we must round up (M_0 >= 10000000000000000000000)
468- M_1 <- 0
469- E_1 <- 01
470- else
471- M_1 <- 1
472- end
473- Else if E_1 > 00
474- roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
475- if roundToEven
476- M_1 <- 0
477- else
478- M_1 <- 1
479- If M_0 >= 11000000000000000000000
480- increment E_1
481- If E_1 > 11 (saturate if beyond range)
482- M_1 <- 1, E_1 <- 11
483- end
484- */
380+
381+ // / Conversion from F32 to F4E2M1 according to the OCP Spec:
382+ // / www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
383+ // /
384+ // / The spec requiers us to perform Round to Nearest, Ties to Even.
385+ // /
386+ // / This means that after rounding, we should break ties by choosing the option
387+ // / which results in a mantissa of 0 in the least significant digit.
388+ // /
389+ // / Table of representable values in F4E2M1:
390+ // /
391+ // / Note: x is sign bit
392+ // / | Binary | Value ( + / - )
393+ // / | x000 | 0.0
394+ // / | x001 | 0.5
395+ // / | x010 | 1.0
396+ // / | x011 | 1.5
397+ // / | x100 | 2.0
398+ // / | x101 | 3.0
399+ // / | x110 | 4.0
400+ // / | x111 | 6.0
401+ // /
402+ // / Conversion procedure:
403+ // / Step 1: Clamp to representable bounds.
404+ // / Step 2: Convert exponent by adjusting bias.
405+ // / Step 3: Set mantissa to first bit.
406+ // / Step 4: Special consideration for subnormal and zero exponent.
407+ // / Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
485408struct F4E2M1TruncFOpConverter : public OpRewritePattern <arith::TruncFOp> {
486409 using OpRewritePattern::OpRewritePattern;
487410 LogicalResult matchAndRewrite (arith::TruncFOp op,
@@ -504,122 +427,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
504427
505428 // Constants
506429 Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
507- Value c0x7e = createConst (op.getLoc (), i8Ty, 0x7e , rewriter);
508- Value c0x0000007e = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
509-
510- Value c0x00000009 = createConst (op->getLoc (), i32Ty, 9 , rewriter);
511- Value c0x00000016 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
512- Value c0x00000017 = createConst (op->getLoc (), i32Ty, 23 , rewriter);
513- Value c0x0000001f = createConst (op->getLoc (), i32Ty, 31 , rewriter);
514- Value c0x00200000 = createConst (op.getLoc (), i32Ty, 0x200000 , rewriter);
515- Value c0x00400000 = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
516- Value c0x00600000 = createConst (op.getLoc (), i32Ty, 0x600000 , rewriter);
517- Value c0x003fffff = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
518- Value c0x007fffff = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
519- Value cF32MantissaWidth = c0x00000017; // 23
430+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
520431 Value cF4MantissaWidth = c0x1; // 1
521- Value cF32SignExpWidth = c0x00000009; // 9
522- Value cF32FirstBitMask = c0x00400000;
523- Value cF32Last22BitMask = c0x003fffff;
524- Value cF32MantissaMask = c0x007fffff;
525-
432+ Value cF32FirstBitMask = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
433+ Value c0x00000016 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
434+ Value c0x00 = createConst (op.getLoc (), i8Ty, 0x00 , rewriter);
435+ Value c0xff = createConst (op.getLoc (), i8Ty, 0xff , rewriter);
436+ Value cF32MantissaMask = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
437+ Value c0x00000000 = createConst (op.getLoc (), i32Ty, 0 , rewriter);
438+ Value cF32Last22BitMask = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);;
439+
526440 // Step 1: Clamp to bounds.
527441 Value cHigherBound = createFloatConst (op->getLoc (), f32Ty, 6.0 , rewriter);
528442 Value cLowerBound = createFloatConst (op->getLoc (), f32Ty, -6.0 , rewriter);
529- Value clampHigh = b.create <arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
530- Value clampLow = b.create <arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
531- Value operandClamped = b.create <arith::SelectOp>(clampHigh, cHigherBound, operand);
532- operandClamped = b.create <arith::SelectOp>(clampLow, cLowerBound, operandClamped);
443+ Value operandClamped = b.create <arith::MinimumFOp>(clampLow, operand);
444+ operandClamped = b.create <arith::MaximumFOp>(clampHigh, operandClamped);
533445 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
534446
535447 // Step 2: Convert exponent by adjusting bias.
536448 Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
537- Value biasAdjustment = c0x0000007e; // 126
449+ Value biasAdjustment = createConst (op. getLoc (), i32Ty, 0x7e , rewriter);
538450 Value biasAdjustedSignExp = b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
539451 Value f4SignExp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
540452 f4SignExp = b.create <arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
541453
542- // Step 0: Special consideration for conversion to 0.5.
543- Value cSubnormalLowerBound = createFloatConst (op->getLoc (), f32Ty, 0.25 , rewriter);
544- Value cSubnormalHigherBound = createFloatConst (op->getLoc (), f32Ty, 0.75 , rewriter);
545- Value cLowerBound = createConst (op->getLoc (), f32Ty, -6.0 , rewriter);
546- Value isSubnormal =
547- b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
548-
549454 // Step 3: Set mantissa to first bit.
550- Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
551- Value man1Bit = b.create <arith::ShRUIOp>(man23Bits , c0x00000016);
455+ Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
456+ man1Bit = b.create <arith::ShRUIOp>(man1Bit , c0x00000016);
552457 Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
553458 Value f4Bits = b.create <arith::AddIOp>(f4SignExp, f4Man);
459+
460+ // Step 4: Special consideration for conversion to 0.5.
461+ Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462+ Value isSubnormal =
463+ b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464+ Value isNegOneExp =
465+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466+ Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
467+ Value isNonZeroMan =
468+ b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
469+ Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
470+ Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
471+ Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
472+ Value isZeroExp =
473+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
474+
475+ Value subResult = b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476+ subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477+ f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
554478
555- // Step 4 : Round up if necessary.
556- Value cRound = c0x00200000 ; // 010 0000...
479+ // Step 5 : Round up if necessary.
480+ Value cRound = createConst (op. getLoc (), i32Ty, 0x200000 , rewriter) ; // 010 0000...
557481 Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
558482 Value shouldRound =
559483 b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
484+ shouldRound =
485+ b.create <arith::OrIOp>(shouldRound, isSubnormal);
560486 Value roundedF4Bits = b.create <arith::AddIOp>(f4Bits, c0x1);
561487 f4Bits = b.create <arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
562488
563- // Regular case
564-
565-
566-
567- Value f4Exp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedExp);
568- Value f4ExpRounded = b.create <arith::AddIOp>(f4Exp, c0x1);
569- // If we round up or down to even, set mantissa to 0
570- Value shouldRoundUp =
571- b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
572- Value shouldRoundDown = b.create <arith::CmpIOp>(arith::CmpIPredicate::ule,
573- man23Bits, cRoundDown);
574- // dont need to worry about saturation this way
575- f4Exp = b.create <arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
576- Value f4BitsMan0 = b.create <arith::ShLIOp>(f4Exp, cF4MantissaWidth);
577- Value f4Bits = b.create <arith::AddIOp>(f4BitsMan0, c0x1);
578- f4Bits = b.create <arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
579- f4Bits = b.create <arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
580-
581- // Bordering subnormal
582- Value cSubnormalRoundUp =
583- createConst (op.getLoc (), i32Ty, 0x4fffff , rewriter);
584- Value f4Edge = createConst (op.getLoc (), i4Ty, 0x1 , rewriter);
585- Value f4EdgeRounded = createConst (op.getLoc (), i4Ty, 0x2 , rewriter);
586- Value isEdgeRounded = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge,
587- man23Bits, cSubnormalRoundUp);
588- f4EdgeRounded =
589- b.create <arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
590- Value isEdge = b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
591- cSubnormalExp);
592-
593- // Subnormal
594- Value f4Zero = createConst (op.getLoc (), i4Ty, 0x0 , rewriter);
595- Value isZero = b.create <arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
596- cSubnormalExp);
597-
598- // create constants to extract mantissa / exponent
599- Value cF4MantissaWidth = createConst (op->getLoc (), i32Ty, 1 , rewriter);
600- Value cF4SignAndExpWidth = createConst (op->getLoc (), i32Ty, 3 , rewriter);
601- // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
602-
603- // create constants for NaNs
604- Value cF4NaN = createConst (op.getLoc (), i4Ty, 0xf , rewriter);
605- Value cF32NaN = createConst (op.getLoc (), i32Ty, 0xffffffff , rewriter);
606- Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
607-
608- Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
609- Value f32Exponent = b.create <arith::ShLIOp>(exti, cF4MantissaWidth);
610- Value f32Mantissa = b.create <arith::ShRUIOp>(exti, cF4SignAndExpWidth);
611- Value f32Bits = b.create <arith::ShLIOp>(exti, cF32MantissaWidth);
612-
613- Value isNan =
614- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
615- // select for NaNs
616- f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
617- Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
618- if (resultETy.getIntOrFloatBitWidth () < 32 ) {
619- result = b.create <arith::TruncFOp>(resultTy, result);
620- } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
621- result = b.create <arith::ExtFOp>(resultTy, result);
622- }
489+ Value result = b.create <arith::BitcastOp>(resultTy, f4Bits);
623490 rewriter.replaceOp (op, result);
624491 return success ();
625492 }
0 commit comments