@@ -34,9 +34,9 @@ static Value createConst(Location loc, Type type, int value,
3434 return rewriter.create <arith::ConstantOp>(loc, attr);
3535}
3636
37- // / Create an float constant.
38- static Value createFloatConst (Location loc, Type type, float value,
39- PatternRewriter &rewriter) {
37+ /// Create a float constant.
38+ static Value createFloatConst (Location loc, Type type, APFloat value,
39+ PatternRewriter &rewriter) {
4040 auto attr = rewriter.getFloatAttr (getElementTypeOrSelf (type), value);
4141 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
4242 return rewriter.create <arith::ConstantOp>(
@@ -416,67 +416,68 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
416416 Type operandETy = getElementTypeOrSelf (operandTy);
417417 Type resultETy = getElementTypeOrSelf (resultTy);
418418
419- if (!llvm:: isa<Float4E2M1FNType>(resultETy)) {
420- return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
419+ if (!isa<Float32Type>(operandETy) || ! isa<Float4E2M1FNType>(resultETy)) {
420+ return rewriter.notifyMatchFailure (op, " not a trunc of F32 to F4E2M1FN" );
421421 }
422422
423423 Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
424424 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
425425 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
426426 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
427427
428- // Constants
429428 Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
430- Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
431- Value cF4MantissaWidth = c0x1; // 1
432- Value cF32FirstBitMask = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
433429 Value c0x00000016 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
434430 Value c0x00 = createConst (op.getLoc (), i8Ty, 0x00 , rewriter);
435431 Value c0xff = createConst (op.getLoc (), i8Ty, 0xff , rewriter);
436- Value cF32MantissaMask = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
437432 Value c0x00000000 = createConst (op.getLoc (), i32Ty, 0 , rewriter);
438- Value cF32Last22BitMask = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);;
439-
433+
440434 // Step 1: Clamp to bounds.
441- Value cHigherBound = createFloatConst (op->getLoc (), f32Ty, 6.0 , rewriter);
442- Value cLowerBound = createFloatConst (op->getLoc (), f32Ty, -6.0 , rewriter);
443- Value operandClamped = b.create <arith::MinimumFOp>(clampLow, operand);
444- operandClamped = b.create <arith::MaximumFOp>(clampHigh, operandClamped);
435+ Value cHigherBound =
436+ createFloatConst (op->getLoc (), f32Ty, APFloat (6 .0f ), rewriter);
437+ Value cLowerBound =
438+ createFloatConst (op->getLoc (), f32Ty, APFloat (-6 .0f ), rewriter);
439+ Value operandClamped = b.create <arith::MinimumFOp>(cLowerBound, operand);
440+ operandClamped = b.create <arith::MaximumFOp>(cHigherBound, operandClamped);
445441 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
446-
442+
447443 // Step 2: Convert exponent by adjusting bias.
448- Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
449444 Value biasAdjustment = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
445+ Value cF4MantissaWidth = c0x1; // 1
446+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
447+ Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
450448 Value biasAdjustedSignExp = b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
451449 Value f4SignExp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
452450 f4SignExp = b.create <arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
453451
454452 // Step 3: Set mantissa to first bit.
453+ Value cF32FirstBitMask = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
455454 Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
456455 man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
457456 Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
458457 Value f4Bits = b.create <arith::AddIOp>(f4SignExp, f4Man);
459458
460459 // Step 4: Special consideration for conversion to 0.5.
460+ Value cF32MantissaMask = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
461461 Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462462 Value isSubnormal =
463- b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
463+ b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464464 Value isNegOneExp =
465- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
465+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466466 Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
467467 Value isNonZeroMan =
468- b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
468+ b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
469469 Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
470- Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
471- Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
472470 Value isZeroExp =
473- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
471+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
474472
473+ Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
474+ Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
475475 Value subResult = b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476476 subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477477 f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
478-
478+
479479 // Step 5: Round up if necessary.
480+ Value cF32Last22BitMask = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
480481 Value cRound = createConst (op.getLoc (), i32Ty, 0x200000 , rewriter); // 010 0000...
481482 Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
482483 Value shouldRound =
0 commit comments