Skip to content

Commit 36b5905

Browse files
initial implementation to fix
1 parent 9fd22cb commit 36b5905

File tree

1 file changed

+206
-0
lines changed

1 file changed

+206
-0
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,57 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
322322
}
323323
};
324324

325+
/// Expands `arith.extf` whose operand element type is F4E2M1FN into integer
/// bit manipulation on the 4-bit encoding, producing an f32 that is then
/// truncated or extended to the requested result type if necessary.
///
/// F4E2M1FN is a finite-only format (no Inf / NaN encodings): 1 sign bit,
/// 2 exponent bits (bias 1) and 1 mantissa bit. The representable magnitudes
/// and their f32 encodings are:
///
///   | f4 bits | value | f32 bits   |
///   | x000    | 0.0   | 0x00000000 |
///   | x001    | 0.5   | 0x3f000000 |  (the only subnormal)
///   | x010    | 1.0   | 0x3f800000 |
///   | x011    | 1.5   | 0x3fc00000 |
///   | x100    | 2.0   | 0x40000000 |
///   | x101    | 3.0   | 0x40400000 |
///   | x110    | 4.0   | 0x40800000 |
///   | x111    | 6.0   | 0x40c00000 |
///
/// For the normal encodings (magnitude >= 0b010) the three magnitude bits are
/// exactly "exponent:mantissa", so placing them at f32 bit 22 and adding the
/// bias difference (127 - 1) << 23 == 0x3f000000 yields the f32 encoding.
/// Zero and the subnormal 0.5 do not follow that rule and are patched in with
/// selects. The sign bit is moved from bit 3 to bit 31.
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);

    if (!llvm::isa<Float4E2M1FNType>(operandETy)) {
      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
    }

    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

    // NOTE(review): every constant below is kept inside the non-negative
    // `int` range (the sign mask is built with a shift rather than the
    // literal 0x80000000) on the assumption that createConst takes its value
    // as a plain `int` -- confirm against the helper's signature.
    Value cZero = createConst(loc, i32Ty, 0x0, rewriter);
    Value cOne = createConst(loc, i32Ty, 0x1, rewriter);
    Value cMagnitudeMask = createConst(loc, i32Ty, 0x7, rewriter);
    Value cSignMask = createConst(loc, i32Ty, 0x8, rewriter);
    Value cSignShift = createConst(loc, i32Ty, 28, rewriter);
    Value cMagnitudeShift = createConst(loc, i32Ty, 22, rewriter);
    // (127 - 1) << 23: f32 exponent bias minus f4 exponent bias. This is also
    // exactly the f32 encoding of 0.5.
    Value cBiasDelta = createConst(loc, i32Ty, 0x3f000000, rewriter);

    // Widen the raw 4-bit encoding so all arithmetic happens in i32.
    Value f4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
    Value wideBits = b.create<arith::ExtUIOp>(i32Ty, f4Bits);

    // Sign: f4 bit 3 -> f32 bit 31.
    Value signBit = b.create<arith::AndIOp>(wideBits, cSignMask);
    signBit = b.create<arith::ShLIOp>(signBit, cSignShift);

    // Magnitude, normal case: rebias "exponent:mantissa" into f32 position.
    // AddIOp (not OrIOp) is required because the shifted magnitude can carry
    // into the bits already set in the bias constant.
    Value magnitude = b.create<arith::AndIOp>(wideBits, cMagnitudeMask);
    Value f32Magnitude = b.create<arith::ShLIOp>(magnitude, cMagnitudeShift);
    f32Magnitude = b.create<arith::AddIOp>(f32Magnitude, cBiasDelta);

    // Magnitude, special cases: the subnormal 0.5 and zero.
    Value isHalf =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, magnitude, cOne);
    f32Magnitude = b.create<arith::SelectOp>(isHalf, cBiasDelta, f32Magnitude);
    Value isZero =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, magnitude, cZero);
    f32Magnitude = b.create<arith::SelectOp>(isZero, cZero, f32Magnitude);

    Value f32Bits = b.create<arith::OrIOp>(signBit, f32Magnitude);
    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);

    // The expansion always goes through f32; adjust to the requested width.
    unsigned resultWidth = resultETy.getIntOrFloatBitWidth();
    if (resultWidth < 32) {
      result = b.create<arith::TruncFOp>(resultTy, result);
    } else if (resultWidth > 32) {
      result = b.create<arith::ExtFOp>(resultTy, result);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
375+
325376
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
326377
using OpRewritePattern::OpRewritePattern;
327378
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -365,6 +416,161 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
365416
return success();
366417
}
367418
};
419+
/*
420+
Conversion from F32 to F4E2M1 according to the OCP Spec:
421+
www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
422+
423+
The spec requiers us to perform Round to Nearest, Ties to Even.
424+
425+
This means that after rounding, we should break ties by choosing the option
426+
which results in a mantissa of 0 in the least significant digit.
427+
428+
Table of representable values in F4E2M1:
429+
430+
Note: x is sign bit
431+
| Binary | Value ( + / - )
432+
| x000 | 0.0
433+
| x001 | 0.5
434+
| x010 | 1.0
435+
| x011 | 1.5
436+
| x100 | 2.0
437+
| x101 | 3.0
438+
| x110 | 4.0
439+
| x111 | 6.0
440+
441+
Conversion procedure:
442+
Let M_0 = f32 mantissa, M_1 = f4 mantissa, Let E_0 = f32 exp, let E_1 = f4 exp
443+
Create bias adjusted exponent, E_1 <- E_0 - 126
444+
If E_0 <= 0111 1110
445+
M_1 <- 0, E_1 <- 00
446+
end
447+
if E_1 == 00 (special case for almost subnormal)
448+
if we must round up (M_0 >= 10000000000000000000000)
449+
M_1 <- 0
450+
E_1 <- 01
451+
else
452+
M_1 <- 1
453+
end
454+
Else if E_1 > 00
455+
roundToEven <- M_0 <= 01000000000000000000000 || M_0 >= 11000000000000000000000
456+
if roundToEven
457+
M_1 <- 0
458+
else
459+
M_1 <- 1
460+
If M_0 >= 11000000000000000000000
461+
increment E_1
462+
If E_1 > 11 (saturate if beyond range)
463+
M_1 <- 1, E_1 <- 11
464+
end
465+
*/
466+
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
467+
using OpRewritePattern::OpRewritePattern;
468+
LogicalResult matchAndRewrite(arith::TruncFOp op,
469+
PatternRewriter &rewriter) const final {
470+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
471+
Value operand = op.getOperand();
472+
Type operandTy = operand.getType();
473+
Type resultTy = op.getType();
474+
Type operandETy = getElementTypeOrSelf(operandTy);
475+
Type resultETy = getElementTypeOrSelf(resultTy);
476+
477+
if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
478+
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
479+
}
480+
481+
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
482+
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
483+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
484+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
485+
486+
// Constants
487+
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
488+
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
489+
Value c0x7e = createConst(op.getLoc(), i8Ty, 0x7e, rewriter);
490+
Value c0x00000009 = createConst(op->getLoc(), i32Ty, 9, rewriter);
491+
Value c0x00000017 = createConst(op->getLoc(), i32Ty, 23, rewriter);
492+
Value c0x00200000 = createConst(op.getLoc(), i32Ty, 0x200000, rewriter);
493+
Value c0x00400000 = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
494+
Value c0x00600000 = createConst(op.getLoc(), i32Ty, 0x600000, rewriter);
495+
Value c0x007fffff = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
496+
497+
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
498+
499+
Value cF32MantissaWidth = c0x00000017; // 23
500+
Value cF4MantissaWidth = c0x1; // 1
501+
Value cF32SignExpWidth = c0x00000009; // 9
502+
Value cF32MantissaMask = c0x007fffff;
503+
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
504+
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
505+
Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
506+
507+
Value cSubnormalExp = c0x7e; // 126
508+
509+
// Regular case
510+
Value biasAdjustment = c0x7e; // 126
511+
Value cRoundUp = c0x00600000; // 110 0000...
512+
Value cRoundDown = c0x00200000; // 010 0000...
513+
Value biasAdjustedExp = b.create<arith::SubIOp>(exp8Bits, biasAdjustment);
514+
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
515+
Value f4ExpRounded = b.create<arith::AddIOp>(f4Exp, c0x1);
516+
// If we round up or down to even, set mantissa to 0
517+
Value shouldRoundUp =
518+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man23Bits, cRoundUp);
519+
Value shouldRoundDown = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule,
520+
man23Bits, cRoundDown);
521+
// dont need to worry about saturation this way
522+
f4Exp = b.create<arith::SelectOp>(shouldRoundUp, f4ExpRounded, f4Exp);
523+
Value f4BitsMan0 = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
524+
Value f4Bits = b.create<arith::AddIOp>(f4BitsMan0, c0x1);
525+
f4Bits = b.create<arith::SelectOp>(shouldRoundUp, f4BitsMan0, f4Bits);
526+
f4Bits = b.create<arith::SelectOp>(shouldRoundDown, f4BitsMan0, f4Bits);
527+
528+
// Bordering subnormal
529+
Value cSubnormalRoundUp =
530+
createConst(op.getLoc(), i32Ty, 0x4fffff, rewriter);
531+
Value f4Edge = createConst(op.getLoc(), i4Ty, 0x1, rewriter);
532+
Value f4EdgeRounded = createConst(op.getLoc(), i4Ty, 0x2, rewriter);
533+
Value isEdgeRounded = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
534+
man23Bits, cSubnormalRoundUp);
535+
f4EdgeRounded =
536+
b.create<arith::SelectOp>(isEdgeRounded, f4EdgeRounded, f4Edge);
537+
Value isEdge = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, exp8Bits,
538+
cSubnormalExp);
539+
540+
// Subnormal
541+
Value f4Zero = createConst(op.getLoc(), i4Ty, 0x0, rewriter);
542+
Value isZero = b.create<arith::CmpIOp>(arith::CmpIPredicate::ule, exp8Bits,
543+
cSubnormalExp);
544+
545+
// create constants to extract mantissa / exponent
546+
Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
547+
Value cF4SignAndExpWidth = createConst(op->getLoc(), i32Ty, 3, rewriter);
548+
// Value cF4MantissaWidth = createConst(op->getLoc(), i32Ty, 1, rewriter);
549+
550+
// create constants for NaNs
551+
Value cF4NaN = createConst(op.getLoc(), i4Ty, 0xf, rewriter);
552+
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
553+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
554+
555+
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
556+
Value f32Exponent = b.create<arith::ShLIOp>(exti, cF4MantissaWidth);
557+
Value f32Mantissa = b.create<arith::ShRUIOp>(exti, cF4SignAndExpWidth);
558+
Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
559+
560+
Value isNan =
561+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
562+
// select for NaNs
563+
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
564+
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
565+
if (resultETy.getIntOrFloatBitWidth() < 32) {
566+
result = b.create<arith::TruncFOp>(resultTy, result);
567+
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
568+
result = b.create<arith::ExtFOp>(resultTy, result);
569+
}
570+
rewriter.replaceOp(op, result);
571+
return success();
572+
}
573+
};
368574

369575
/*
370576
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type

0 commit comments

Comments
 (0)