@@ -322,6 +322,57 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
322322 }
323323};
324324
325+ struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
326+ using OpRewritePattern::OpRewritePattern;
327+ LogicalResult matchAndRewrite (arith::ExtFOp op,
328+ PatternRewriter &rewriter) const final {
329+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
330+ Value operand = op.getOperand ();
331+ Type operandTy = operand.getType ();
332+ Type resultTy = op.getType ();
333+ Type operandETy = getElementTypeOrSelf (operandTy);
334+ Type resultETy = getElementTypeOrSelf (resultTy);
335+
336+ if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
337+ return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
338+ }
339+
340+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
341+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
342+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
343+
344+ Value bitcast = b.create <arith::BitcastOp>(i4Ty, operand);
345+
346+ // create constants to extract mantissa / exponent
347+ Value cF4MantissaWidth = createConst (op->getLoc (), i32Ty, 1 , rewriter);
348+ Value cF4SignAndExpWidth = createConst (op->getLoc (), i32Ty, 3 , rewriter);
349+ // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
350+
351+ // create constants for NaNs
352+ Value cF4NaN = createConst (op.getLoc (), i4Ty, 0xf , rewriter);
353+ Value cF32NaN = createConst (op.getLoc (), i32Ty, 0xffffffff , rewriter);
354+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
355+
356+ Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
357+ Value f32Exponent = b.create <arith::ShLIOp>(exti, cF4MantissaWidth);
358+ Value f32Mantissa = b.create <arith::ShRUIOp>(exti, cF4SignAndExpWidth);
359+ Value f32Bits = b.create <arith::ShLIOp>(exti, cF32MantissaWidth);
360+
361+ Value isNan =
362+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
363+ // select for NaNs
364+ f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
365+ Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
366+ if (resultETy.getIntOrFloatBitWidth () < 32 ) {
367+ result = b.create <arith::TruncFOp>(resultTy, result);
368+ } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
369+ result = b.create <arith::ExtFOp>(resultTy, result);
370+ }
371+ rewriter.replaceOp (op, result);
372+ return success ();
373+ }
374+ };
375+
325376struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
326377 using OpRewritePattern::OpRewritePattern;
327378 LogicalResult matchAndRewrite (arith::ExtFOp op,
@@ -365,6 +416,161 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
365416 return success ();
366417 }
367418};
419+ /*
420+ Conversion from F32 to F4E2M1 according to the OCP Spec:
421+ www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
422+
423+ The spec requiers us to perform Round to Nearest, Ties to Even.
424+
425+ This means that after rounding, we should break ties by choosing the option
426+ which results in a mantissa of 0 in the least significant digit.
427+
428+ Table of representable values in F4E2M1:
429+
430+ Note: x is sign bit
431+ | Binary | Value ( + / - )
432+ | x000 | 0.0
433+ | x001 | 0.5
434+ | x010 | 1.0
435+ | x011 | 1.5
436+ | x100 | 2.0
437+ | x101 | 3.0
438+ | x110 | 4.0
439+ | x111 | 6.0
440+
441+ Conversion procedure:
442+ Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
443+ Create bias adjusted exponent, E_1 <- E_0 - 126
444+ If E_0 <= 0111 1110
445+ M_1 <- 0, E_1 <- 00
446+ end
447+ if E_1 == 00 (special case for almost subnormal)
448+ if we must round up (M_0 >= 10000000000000000000000)
449+ M_1 <- 0
450+ E_1 <- 01
451+ else
452+ M_1 <- 1
453+ end
454+ Else if E_1 > 00
455+ roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
456+ if roundToEven
457+ M_1 <- 0
458+ else
459+ M_1 <- 1
460+ If M_0 >= 11000000000000000000000
461+ increment E_1
462+ If E_1 > 11 (saturate if beyond range)
463+ M_1 <- 1, E_1 <- 11
464+ end
465+ */
466+ struct F4E2M1TruncFOpConverter : public OpRewritePattern <arith::TruncFOp> {
467+ using OpRewritePattern::OpRewritePattern;
468+ LogicalResult matchAndRewrite (arith::TruncFOp op,
469+ PatternRewriter &rewriter) const final {
470+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
471+ Value operand = op.getOperand ();
472+ Type operandTy = operand.getType ();
473+ Type resultTy = op.getType ();
474+ Type operandETy = getElementTypeOrSelf (operandTy);
475+ Type resultETy = getElementTypeOrSelf (resultTy);
476+
477+ if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
478+ return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
479+ }
480+
481+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
482+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
483+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
484+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
485+
486+ // Constants
487+ Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
488+ Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
489+ Value c0x7e = createConst (op.getLoc (), i8Ty, 0x7e , rewriter);
490+ Value c0x00000009 = createConst (op->getLoc (), i32Ty, 9 , rewriter);
491+ Value c0x00000017 = createConst (op->getLoc (), i32Ty, 23 , rewriter);
492+ Value c0x00200000 = createConst (op.getLoc (), i32Ty, 0x200000 , rewriter);
493+ Value c0x00400000 = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
494+ Value c0x00600000 = createConst (op.getLoc (), i32Ty, 0x600000 , rewriter);
495+ Value c0x007fffff = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
496+
497+ Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
498+
499+ Value cF32MantissaWidth = c0x00000017; // 23
500+ Value cF4MantissaWidth = c0x1; // 1
501+ Value cF32SignExpWidth = c0x00000009; // 9
502+ Value cF32MantissaMask = c0x007fffff;
503+ Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
504+ Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
505+ Value exp8Bits = b.create <arith::TruncIOp>(i8Ty, f32SignExp);
506+
507+ Value cSubnormalExp = c0x7e; // 126
508+
509+ // Regular case
510+ Value biasAdjustment = c0x7e; // 126
511+ Value cRoundUp = c0x00600000; // 110 0000...
512+ Value cRoundDown = c0x00200000; // 010 0000...
513+ Value biasAdjustedExp = b.create <arith::SubIOp>(exp8Bits, biasAdjustment);
514+ Value f4Exp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedExp);
515+ Value f4ExpRounded = b.create <arith::AddIOp>(f4Exp, c0x1);
516+ // If we round up or down to even, set mantissa to 0
517+ Value shouldRoundUp =
518+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
519+ Value shouldRoundDown = b.create <arith::CmpIOp>(arith::CmpIPredicate::ule,
520+ man23Bits, cRoundDown);
521+ // dont need to worry about saturation this way
522+ f4Exp = b.create <arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
523+ Value f4BitsMan0 = b.create <arith::ShLIOp>(f4Exp, cF4MantissaWidth);
524+ Value f4Bits = b.create <arith::AddIOp>(f4BitsMan0, c0x1);
525+ f4Bits = b.create <arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
526+ f4Bits = b.create <arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
527+
528+ // Bordering subnormal
529+ Value cSubnormalRoundUp =
530+ createConst (op.getLoc (), i32Ty, 0x4fffff , rewriter);
531+ Value f4Edge = createConst (op.getLoc (), i4Ty, 0x1 , rewriter);
532+ Value f4EdgeRounded = createConst (op.getLoc (), i4Ty, 0x2 , rewriter);
533+ Value isEdgeRounded = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge,
534+ man23Bits, cSubnormalRoundUp);
535+ f4EdgeRounded =
536+ b.create <arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
537+ Value isEdge = b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
538+ cSubnormalExp);
539+
540+ // Subnormal
541+ Value f4Zero = createConst (op.getLoc (), i4Ty, 0x0 , rewriter);
542+ Value isZero = b.create <arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
543+ cSubnormalExp);
544+
545+ // create constants to extract mantissa / exponent
546+ Value cF4MantissaWidth = createConst (op->getLoc (), i32Ty, 1 , rewriter);
547+ Value cF4SignAndExpWidth = createConst (op->getLoc (), i32Ty, 3 , rewriter);
548+ // Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
549+
550+ // create constants for NaNs
551+ Value cF4NaN = createConst (op.getLoc (), i4Ty, 0xf , rewriter);
552+ Value cF32NaN = createConst (op.getLoc (), i32Ty, 0xffffffff , rewriter);
553+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
554+
555+ Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
556+ Value f32Exponent = b.create <arith::ShLIOp>(exti, cF4MantissaWidth);
557+ Value f32Mantissa = b.create <arith::ShRUIOp>(exti, cF4SignAndExpWidth);
558+ Value f32Bits = b.create <arith::ShLIOp>(exti, cF32MantissaWidth);
559+
560+ Value isNan =
561+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
562+ // select for NaNs
563+ f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
564+ Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
565+ if (resultETy.getIntOrFloatBitWidth () < 32 ) {
566+ result = b.create <arith::TruncFOp>(resultTy, result);
567+ } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
568+ result = b.create <arith::ExtFOp>(resultTy, result);
569+ }
570+ rewriter.replaceOp (op, result);
571+ return success ();
572+ }
573+ };
368574
369575/*
370576TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
0 commit comments