Skip to content

Commit f11addf

Browse files
Initial implementation of truncf
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 63c10e6 commit f11addf

File tree

1 file changed

+66
-199
lines changed

1 file changed

+66
-199
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 66 additions & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -334,57 +334,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
334334
}
335335
};
336336

337-
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
338-
using OpRewritePattern::OpRewritePattern;
339-
LogicalResult matchAndRewrite(arith::ExtFOp op,
340-
PatternRewriter &rewriter) const final {
341-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
342-
Value operand = op.getOperand();
343-
Type operandTy = operand.getType();
344-
Type resultTy = op.getType();
345-
Type operandETy = getElementTypeOrSelf(operandTy);
346-
Type resultETy = getElementTypeOrSelf(resultTy);
347-
348-
if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
349-
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
350-
}
351-
352-
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
353-
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
354-
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
355-
356-
Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
357-
358-
// create constants to extract mantissa / exponent
359-
Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
360-
Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
361-
// Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
362-
363-
// create constants for NaNs
364-
Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
365-
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
366-
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
367-
368-
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
369-
Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
370-
Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
371-
Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
372-
373-
Value isNan =
374-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
375-
// select for NaNs
376-
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
377-
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
378-
if (resultETy.getIntOrFloatBitWidth() < 32) {
379-
result = b.create<arith::TruncFOp>(resultTy, result);
380-
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
381-
result = b.create<arith::ExtFOp>(resultTy, result);
382-
}
383-
rewriter.replaceOp(op, result);
384-
return success();
385-
}
386-
};
387-
388337
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
389338
using OpRewritePattern::OpRewritePattern;
390339
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -428,60 +377,34 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
428377
return success();
429378
}
430379
};
431-
/*
432-
Conversion from F32 to F4E2M1 according to the OCP Spec:
433-
www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
434-
435-
The spec requiers us to perform Round to Nearest, Ties to Even.
436-
437-
This means that after rounding, we should break ties by choosing the option
438-
which results in a mantissa of 0 in the least significant digit.
439-
440-
Table of representable values in F4E2M1:
441-
442-
Note: x is sign bit
443-
| Binary | Value ( + / - )
444-
| x000 | 0.0
445-
| x001 | 0.5
446-
| x010 | 1.0
447-
| x011 | 1.5
448-
| x100 | 2.0
449-
| x101 | 3.0
450-
| x110 | 4.0
451-
| x111 | 6.0
452-
453-
Conversion procedure:
454-
455-
Step 1: Clamp to max f4 value
456-
457-
Step 2: convert exponent, if signed int comparison <= 0, set 0
458-
459-
Step 3: if mantissa[1:] greater than 1000000, add 1
460-
461-
Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
462-
Create bias adjusted exponent, E_1 <- E_0 - 126
463-
If E_0 <= 0111 1110
464-
M_1 <- 0, E_1 <- 00
465-
end
466-
if E_1 == 00 (special case for almost subnormal)
467-
if we must round up (M_0 >= 10000000000000000000000)
468-
M_1 <- 0
469-
E_1 <- 01
470-
else
471-
M_1 <- 1
472-
end
473-
Else if E_1 > 00
474-
roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
475-
if roundToEven
476-
M_1 <- 0
477-
else
478-
M_1 <- 1
479-
If M_0 >= 11000000000000000000000
480-
increment E_1
481-
If E_1 > 11 (saturate if beyond range)
482-
M_1 <- 1, E_1 <- 11
483-
end
484-
*/
380+
381+
/// Conversion from F32 to F4E2M1 according to the OCP Spec:
382+
/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
383+
///
384+
/// The spec requiers us to perform Round to Nearest, Ties to Even.
385+
///
386+
/// This means that after rounding, we should break ties by choosing the option
387+
/// which results in a mantissa of 0 in the least significant digit.
388+
///
389+
/// Table of representable values in F4E2M1:
390+
///
391+
/// Note: x is sign bit
392+
/// | Binary | Value ( + / - )
393+
/// | x000 | 0.0
394+
/// | x001 | 0.5
395+
/// | x010 | 1.0
396+
/// | x011 | 1.5
397+
/// | x100 | 2.0
398+
/// | x101 | 3.0
399+
/// | x110 | 4.0
400+
/// | x111 | 6.0
401+
///
402+
/// Conversion procedure:
403+
/// Step 1: Clamp to representable bounds.
404+
/// Step 2: Convert exponent by adjusting bias.
405+
/// Step 3: Set mantissa to first bit.
406+
/// Step 4: Special consideration for subnormal and zero exponent.
407+
/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
485408
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
486409
using OpRewritePattern::OpRewritePattern;
487410
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -504,122 +427,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
504427

505428
// Constants
506429
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
507-
Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
508-
Value c0x0000007e = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
509-
510-
Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
511-
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
512-
Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
513-
Value c0x0000001f = createConst(op->getLoc(), i32Ty, 31, rewriter);
514-
Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
515-
Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
516-
Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
517-
Value c0x003fffff = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
518-
Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
519-
Value cF32MantissaWidth = c0x00000017; // 23
430+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
520431
Value cF4MantissaWidth = c0x1; // 1
521-
Value cF32SignExpWidth = c0x00000009; // 9
522-
Value cF32FirstBitMask = c0x00400000;
523-
Value cF32Last22BitMask = c0x003fffff;
524-
Value cF32MantissaMask = c0x007fffff;
525-
432+
Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
433+
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
434+
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
435+
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
436+
Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
437+
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
438+
Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
439+
526440
// Step 1: Clamp to bounds.
527441
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
528442
Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
529-
Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
530-
Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
531-
Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
532-
operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
443+
Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
444+
operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
533445
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
534446

535447
// Step 2: Convert exponent by adjusting bias.
536448
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
537-
Value biasAdjustment = c0x0000007e; // 126
449+
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
538450
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
539451
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
540452
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
541453

542-
// Step 0: Special consideration for conversion to 0.5.
543-
Value cSubnormalLowerBound = createFloatConst(op->getLoc(), f32Ty, 0.25, rewriter);
544-
Value cSubnormalHigherBound = createFloatConst(op->getLoc(), f32Ty, 0.75, rewriter);
545-
Value cLowerBound = createConst(op->getLoc(), f32Ty, -6.0, rewriter);
546-
Value isSubnormal =
547-
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
548-
549454
// Step 3: Set mantissa to first bit.
550-
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
551-
Value man1Bit = b.create<arith::ShRUIOp>(man23Bits, c0x00000016);
455+
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
456+
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
552457
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
553458
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
459+
460+
// Step 4: Special consideration for conversion to 0.5.
461+
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462+
Value isSubnormal =
463+
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464+
Value isNegOneExp =
465+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466+
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
467+
Value isNonZeroMan =
468+
b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
469+
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
470+
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
471+
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
472+
Value isZeroExp =
473+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
474+
475+
Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476+
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477+
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
554478

555-
// Step 4: Round up if necessary.
556-
Value cRound = c0x00200000; // 010 0000...
479+
// Step 5: Round up if necessary.
480+
Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
557481
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
558482
Value shouldRound =
559483
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
484+
shouldRound =
485+
b.create<arith::OrIOp>(shouldRound, isSubnormal);
560486
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
561487
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
562488

563-
// Regular case
564-
565-
566-
567-
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
568-
Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
569-
// If we round up or down to even, set mantissa to 0
570-
Value shouldRoundUp =
571-
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
572-
Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
573-
man23Bits, cRoundDown);
574-
// dont need to worry about saturation this way
575-
f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
576-
Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
577-
Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
578-
f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
579-
f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
580-
581-
// Bordering subnormal
582-
Value cSubnormalRoundUp =
583-
createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
584-
Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
585-
Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
586-
Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
587-
man23Bits, cSubnormalRoundUp);
588-
f4EdgeRounded =
589-
b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
590-
Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
591-
cSubnormalExp);
592-
593-
// Subnormal
594-
Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
595-
Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
596-
cSubnormalExp);
597-
598-
// create constants to extract mantissa / exponent
599-
Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
600-
Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
601-
// Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
602-
603-
// create constants for NaNs
604-
Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
605-
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
606-
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
607-
608-
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
609-
Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
610-
Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
611-
Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
612-
613-
Value isNan =
614-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
615-
// select for NaNs
616-
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
617-
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
618-
if (resultETy.getIntOrFloatBitWidth() < 32) {
619-
result = b.create<arith::TruncFOp>(resultTy, result);
620-
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
621-
result = b.create<arith::ExtFOp>(resultTy, result);
622-
}
489+
Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
623490
rewriter.replaceOp(op, result);
624491
return success();
625492
}

0 commit comments

Comments
 (0)