1313#include " mlir/IR/ImplicitLocOpBuilder.h"
1414#include " mlir/IR/TypeUtilities.h"
1515#include " mlir/Transforms/DialectConversion.h"
16- #include " llvm/ADT/APFloat.h"
17- #include < cstdint>
1816
1917namespace mlir {
2018namespace arith {
@@ -25,16 +23,6 @@ namespace arith {
2523
2624using namespace mlir ;
2725
28- static Value createFloatConst (Location loc, Type type, float value,
29- PatternRewriter &rewriter) {
30- auto attr = rewriter.getFloatAttr (getElementTypeOrSelf (type), value);
31- if (auto shapedTy = dyn_cast<ShapedType>(type)) {
32- return rewriter.create <arith::ConstantOp>(
33- loc, DenseElementsAttr::get (shapedTy, attr));
34- }
35- return rewriter.create <arith::ConstantOp>(loc, attr);
36- }
37-
3826/// Create an integer or index constant.
3927static Value createConst (Location loc, Type type, int value,
4028 PatternRewriter &rewriter) {
@@ -368,7 +356,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
368356 f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
369357 Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
370358 if (resultETy.getIntOrFloatBitWidth () < 32 ) {
371- result = b.create <arith::TruncFOp>(resultTy, result);
359+ result = b.create <arith::TruncFOp>(resultTy, result, nullptr ,
360+ op.getFastmathAttr ());
372361 } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
373362 result = b.create <arith::ExtFOp>(resultTy, result);
374363 }
@@ -406,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
406395 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
407396
408397 if (operandETy.getIntOrFloatBitWidth () < 32 ) {
409- operand = b.create <arith::ExtFOp>(f32Ty, operand);
398+ operand = b.create <arith::ExtFOp>(f32Ty, operand, op. getFastmathAttr () );
410399 } else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
411- operand = b.create <arith::TruncFOp>(f32Ty, operand);
400+ operand = b.create <arith::TruncFOp>(
401+ f32Ty, operand, op.getRoundingmodeAttr (), op.getFastmathAttr ());
412402 }
413403 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
414404 Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
@@ -431,7 +421,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
431421 // allow implicit exponent extraction from 16/32-bit floats
432422 if (scaleETy.getIntOrFloatBitWidth () >= 16 ) {
433423 scaleETy = b.getF8E8M0Type ();
434- scaleOperand = b.create <arith::TruncFOp>(scaleETy, scaleOperand);
424+ scaleOperand = b.create <arith::TruncFOp>(scaleETy, scaleOperand, nullptr ,
425+ op.getFastmathAttr ());
435426 }
436427 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
437428 return rewriter.notifyMatchFailure (
@@ -441,14 +432,22 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
441432 Type resultTy = op.getType ();
442433 // extf on scale will essentially create a floating point number
443434 // of type resultTy that is 2^scale and will also propagate NaNs
444- Value scaleExt = b.create <arith::ExtFOp>(resultTy, scaleOperand);
445- Value inputExt = b.create <arith::ExtFOp>(resultTy, inputOperand);
446- Value result = b.create <arith::MulFOp>(inputExt, scaleExt);
435+ Value scaleExt =
436+ b.create <arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr ());
437+ Value inputExt =
438+ b.create <arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr ());
439+ Value result =
440+ b.create <arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr ());
447441 rewriter.replaceOp (op, result);
448442 return success ();
449443 }
450444};
451445
446+ /*
447+ Expands arith.ScalingTruncFOp(in, scale) into
448+ scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
449+ result = arith.truncf(in / (2^scale))
450+ */
452451struct ScalingTruncFOpConverter
453452 : public OpRewritePattern<arith::ScalingTruncFOp> {
454453 using OpRewritePattern::OpRewritePattern;
@@ -470,68 +469,14 @@ struct ScalingTruncFOpConverter
470469 op, " scaling_truncf is using scales type which can not be converted "
471470 " to f8E8M0FNU" );
472471 }
473-
474472 Type resultTy = op.getType ();
475- Type resultETy = getElementTypeOrSelf (op.getOut ());
476-
477473 Type inputTy = inputOperand.getType ();
478- Type inputETy = getElementTypeOrSelf (inputOperand);
479-
480- Type i8Ty = cloneToShapedType (resultTy, b.getI8Type ());
481- Type i32Ty = cloneToShapedType (resultTy, b.getI32Type ());
482- Type f32Ty = cloneToShapedType (resultTy, b.getF32Type ());
483-
484- if (inputETy.getIntOrFloatBitWidth () < 32 ) {
485- inputOperand = b.create <arith::ExtFOp>(f32Ty, inputOperand);
486- } else if (inputETy.getIntOrFloatBitWidth () > 32 ) {
487- inputOperand = b.create <arith::TruncFOp>(f32Ty, inputOperand);
488- }
489- inputTy = inputOperand.getType ();
490- inputETy = getElementTypeOrSelf (inputOperand);
491-
492- // normalize scale by exponent of the max normal value (emax) in result type
493- // as per the OCP MXFP spec
494- // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
495- // here this normalization is carried in f32. Therefore instead of
496- // subtraction it does the DivFOp
497- const llvm::fltSemantics &resultFltSemantics =
498- llvm::cast<FloatType>(resultETy).getFloatSemantics ();
499- int maxExponent = APFloat::semanticsMaxExponent (resultFltSemantics);
500- Value cEmax = createConst (op->getLoc (), i32Ty, maxExponent, rewriter);
501- Value c1 = createConst (op->getLoc (), i32Ty, 1 , rewriter);
502- Value cPow2 = b.create <arith::ShLIOp>(c1, cEmax);
503- Value cPow2F32 = b.create <arith::SIToFPOp>(f32Ty, cPow2);
504- Value scaleF32 = b.create <arith::ExtFOp>(f32Ty, scaleOperand);
505- // note that spec also does the clamping but it should only be done for
506- // underflows because dividing by 2^emax will only make it smaller.
507- // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
508- Value scaleNormalizedF32 = b.create <arith::DivFOp>(scaleF32, cPow2F32);
509- // If it has underflown then scale will be a denorm FP32 number after
510- // division. Clamp underflows to 2^-127 as per the spec implementation
511- Value scaleNormalizedExponentF8 =
512- b.create <arith::TruncFOp>(scaleTy, scaleNormalizedF32);
513- Value scaleNormalizedExponentU8 =
514- b.create <arith::BitcastOp>(i8Ty, scaleNormalizedExponentF8);
515- Value cI8Zero = createConst (op.getLoc (), i8Ty, 0x00 , rewriter);
516- Value scaleClampCond = b.create <arith::CmpIOp>(
517- arith::CmpIPredicate::eq, cI8Zero, scaleNormalizedExponentU8);
518- // 5.8e-39 is 2^-127, it is a denorm value in f32
519- float clampValue = 5.87747e-39 ;
520- Value scaleClampValue =
521- createFloatConst (op.getLoc (), f32Ty, clampValue, rewriter);
522- Value clampedScale = b.create <arith::SelectOp>(
523- scaleClampCond, scaleClampValue, scaleNormalizedF32);
524- // flush denorms by checking if exponent part of input operand is zero
525- // or not.
526- Value inputExponent = b.create <arith::TruncFOp>(scaleTy, inputOperand);
527- Value inputExponentU8 = b.create <arith::BitcastOp>(i8Ty, inputExponent);
528- Value inputFlushCond = b.create <arith::CmpIOp>(arith::CmpIPredicate::eq,
529- cI8Zero, inputExponentU8);
530- Value inputTyZero = createFloatConst (op.getLoc (), inputTy, 0 , rewriter);
531- Value flushedInput =
532- b.create <arith::SelectOp>(inputFlushCond, inputTyZero, inputOperand);
533- Value result = b.create <arith::DivFOp>(flushedInput, clampedScale);
534- // propagate rounding mode and fast math attributes
474+ // this will create a floating point number of type
475+ // inputTy that is 2^scale and will also propagate NaNs
476+ scaleOperand =
477+ b.create <arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr ());
478+ Value result = b.create <arith::DivFOp>(inputOperand, scaleOperand,
479+ op.getFastmathAttr ());
535480 Value resultCast = b.create <arith::TruncFOp>(
536481 resultTy, result, op.getRoundingmodeAttr (), op.getFastmathAttr ());
537482 rewriter.replaceOp (op, resultCast);
0 commit comments