@@ -334,6 +334,70 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
334334 }
335335};
336336
337+ struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
338+ using OpRewritePattern::OpRewritePattern;
339+ LogicalResult matchAndRewrite (arith::ExtFOp op,
340+ PatternRewriter &rewriter) const final {
341+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
342+ Value operand = op.getOperand ();
343+ Type operandTy = operand.getType ();
344+ Type resultTy = op.getType ();
345+ Type operandETy = getElementTypeOrSelf (operandTy);
346+ Type resultETy = getElementTypeOrSelf (resultTy);
347+
348+ if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
349+ !llvm::isa<Float32Type>(resultETy)) {
350+ return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN to F32" );
351+ }
352+
353+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
354+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
355+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
356+
357+ Value bitcast = b.create <arith::BitcastOp>(i4Ty, operand);
358+
359+ Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
360+ Value c0x0000001c = createConst (op->getLoc (), i32Ty, 28 , rewriter);
361+ Value c0x00000014 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
362+ Value cZero =
363+ createFloatConst (op->getLoc (), f32Ty, APFloat (0 .0f ), rewriter);
364+ Value cHalf =
365+ createFloatConst (op->getLoc (), f32Ty, APFloat (0 .5f ), rewriter);
366+
367+ Value mantissaBitmask = c0x1;
368+ Value exponentBitmask = createConst (op.getLoc (), i4Ty, 0x6 , rewriter);
369+ Value signBitmask = createConst (op.getLoc (), i4Ty, 0x8 , rewriter);
370+
371+ Value f4SignBit = b.create <arith::AndIOp>(bitcast, signBitmask);
372+ Value f32Bits = b.create <arith::ExtUIOp>(i32Ty, f4SignBit);
373+ f32Bits = b.create <arith::ShRUIOp>(f32Bits, c0x0000001c);
374+
375+ Value biasAdjustment = createConst (op.getLoc (), i32Ty, 126 , rewriter);
376+ Value f4ExpBits = b.create <arith::AndIOp>(bitcast, exponentBitmask);
377+ f4ExpBits = b.create <arith::ShRUIOp>(f4ExpBits, c0x1);
378+ Value f32ExpBits = b.create <arith::ExtUIOp>(i32Ty, f4ExpBits);
379+ f32ExpBits = b.create <arith::AddIOp>(f32ExpBits, biasAdjustment);
380+ f32ExpBits = b.create <arith::ShLIOp>(f32ExpBits, c0x00000014);
381+ f32Bits = b.create <arith::AddIOp>(f32Bits, f32ExpBits);
382+
383+ Value f4ManBit = b.create <arith::AndIOp>(bitcast, mantissaBitmask);
384+ Value f32ManBit = b.create <arith::ExtUIOp>(i32Ty, f4ManBit);
385+ f32Bits = b.create <arith::AddIOp>(f32Bits, f32ManBit);
386+
387+ // Special consideration for subnormal exp (exp == 0).
388+ Value isSubnormal = b.create <arith::CmpIOp>(arith::CmpIPredicate::eq,
389+ f32ExpBits, biasAdjustment);
390+ Value isManSet =
391+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
392+ Value subnormalVal = b.create <arith::SelectOp>(isManSet, cHalf, cZero);
393+ f32Bits = b.create <arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
394+
395+ Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
396+ rewriter.replaceOp (op, result);
397+ return success ();
398+ }
399+ };
400+
337401struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
338402 using OpRewritePattern::OpRewritePattern;
339403 LogicalResult matchAndRewrite (arith::ExtFOp op,
@@ -389,7 +453,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
389453// / Table of representable values in F4E2M1:
390454// /
391455// / Note: x is sign bit
392- // / | Binary | Value ( + / - )
456+ // / | Binary | Value ( + / - )
393457// / | x000 | 0.0
394458// / | x001 | 0.5
395459// / | x010 | 1.0
@@ -399,12 +463,13 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
399463// / | x110 | 4.0
400464// / | x111 | 6.0
401465// /
402- // / Conversion procedure:
466+ // / Conversion procedure:
403467// / Step 1: Clamp to representable bounds.
404468// / Step 2: Convert exponent by adjusting bias.
405469// / Step 3: Set mantissa to first bit.
406470// / Step 4: Special consideration for subnormal and zero exponent.
407- // / Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
471+ // / Step 5: Round up if necessary, i.e. if mantissa[1:] is greater than or
472+ // / equal to 1000000... or the value is subnormal.
408473struct F4E2M1TruncFOpConverter : public OpRewritePattern <arith::TruncFOp> {
409474 using OpRewritePattern::OpRewritePattern;
410475 LogicalResult matchAndRewrite (arith::TruncFOp op,
@@ -442,48 +507,54 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
442507
443508 // Step 2: Convert exponent by adjusting bias.
444509 Value biasAdjustment = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
445- Value cF4MantissaWidth = c0x1; // 1
446- Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
510+ Value cF4MantissaWidth = c0x1; // 1
511+ Value cF32MantissaWidth =
512+ createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
447513 Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
448- Value biasAdjustedSignExp = b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
514+ Value biasAdjustedSignExp =
515+ b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
449516 Value f4SignExp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
450517 f4SignExp = b.create <arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
451-
518+
452519 // Step 3: Set mantissa to first bit.
453- Value cF32FirstBitMask = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
520+ Value cF32FirstBitMask =
521+ createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
454522 Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
455523 man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
456524 Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
457525 Value f4Bits = b.create <arith::AddIOp>(f4SignExp, f4Man);
458-
526+
459527 // Step 4: Special consideration for conversion to 0.5.
460- Value cF32MantissaMask = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
528+ Value cF32MantissaMask =
529+ createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
461530 Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462531 Value isSubnormal =
463- b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
532+ b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464533 Value isNegOneExp =
465- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
534+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466535 Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
467- Value isNonZeroMan =
468- b. create <arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
536+ Value isNonZeroMan = b. create <arith::CmpIOp>(arith::CmpIPredicate::ugt,
537+ man23Bits, c0x00000000);
469538 Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
470- Value isZeroExp =
471- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
472-
539+ Value isZeroExp =
540+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
541+
473542 Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
474543 Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
475- Value subResult = b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
544+ Value subResult =
545+ b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476546 subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477547 f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
478-
548+
479549 // Step 5: Round up if necessary.
480- Value cF32Last22BitMask = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
481- Value cRound = createConst (op.getLoc (), i32Ty, 0x200000 , rewriter); // 010 0000...
550+ Value cF32Last22BitMask =
551+ createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
552+ Value cRound =
553+ createConst (op.getLoc (), i32Ty, 0x200000 , rewriter); // 010 0000...
482554 Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
483555 Value shouldRound =
484556 b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
485- shouldRound =
486- b.create <arith::OrIOp>(shouldRound, isSubnormal);
557+ shouldRound = b.create <arith::OrIOp>(shouldRound, isSubnormal);
487558 Value roundedF4Bits = b.create <arith::AddIOp>(f4Bits, c0x1);
488559 f4Bits = b.create <arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
489560
0 commit comments