-
Notifications
You must be signed in to change notification settings - Fork 15.2k
Introduce arith.scaling_extf and arith.scaling_truncf
#141965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
1ed7462
91bb889
8eebbea
acc6658
6797446
3ad83bd
9f755c2
5e49a72
682573e
de4497b
e239157
646465c
20b0928
80c080f
b5df100
b6589ae
12c52a6
b3cadf2
5558b03
fc90780
8f91e28
dc7b67f
109ddc5
f3d9865
95a7558
3ccb208
d154341
d8a76fa
a0aa490
ff66dad
10a1bc3
3c7980d
f7c1b79
229f6b8
45e7dba
80061d6
a38ac5e
8151fc7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1215,6 +1215,44 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast | |
| attr-dict `:` type($in) `to` type($out) }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Scaling ExtFOp | ||
| //===----------------------------------------------------------------------===// | ||
| def Arith_ScalingExtFOp | ||
| : Arith_Op< | ||
| "scaling_extf", [Pure, SameInputOutputTensorDims, | ||
| DeclareOpInterfaceMethods<ArithFastMathInterface>, | ||
| DeclareOpInterfaceMethods<CastOpInterface>]>, | ||
| Arguments<(ins FloatLike:$in, FloatLike:$scale, | ||
| OptionalAttr<Arith_FastMathAttr>:$fastmath)>, | ||
| Results<(outs FloatLike:$out)> { | ||
| let summary = | ||
| "cast from floating-point to larger floating-point using provided scales"; | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| let description = [{ | ||
umangyadav marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Implements micro-scaling floating point ExtF op. It expects both scales and input operand to be of same shape. | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Scale operand is expected to be of type f8E8M0. But that can be relaxed in future. | ||
| Scale is usually calculated per block. | ||
| Let's say originally input is shape <dim1 x dim2 x dim3 .. x dimN> then, given blockSize it can be reshaped to <dim1 x dim2 x ... (dimN/blockSize) x blockSize>. | ||
| Scales will be calculated on the block axis. Therefore scale will be of shape <dim1 x dim2 x dim3 ... (dimN/blockSize) x 1>. | ||
| Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to make them the same shape as the input, making `arith.scaling_extf` an elementwise op. | ||
| In the above example, scales should be broadcast to the shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>. | ||
|
||
| ``` | ||
| resultTy = get_type(result) | ||
| scaleTy = get_type(scale) | ||
| inputTy = get_type(input) | ||
| scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 | ||
umangyadav marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| scale.bcast = broadcast_to_same_shape_as(result) | ||
| scale.extf = arith.extf(scale.bcast) : f8E8M0 to resultTy | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| input.extf = arith.extf(input) : inputTy to resultTy | ||
| result = arith.mulf(scale.extf, input.extf) | ||
| ``` | ||
| It propagates NaN values. Therefore if either scale or input operand element value is a NaN then output element value will also be a NaN. | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| }]; | ||
| let hasVerifier = 1; | ||
| let assemblyFormat = | ||
| [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // TruncIOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -1280,6 +1318,49 @@ def Arith_TruncFOp : | |
| attr-dict `:` type($in) `to` type($out) }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Scaling TruncFOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def Arith_ScalingTruncFOp | ||
| : Arith_Op<"scaling_truncf", | ||
| [Pure, SameInputOutputTensorDims, | ||
| DeclareOpInterfaceMethods<ArithRoundingModeInterface>, | ||
| DeclareOpInterfaceMethods<ArithFastMathInterface>, | ||
| DeclareOpInterfaceMethods<CastOpInterface>]>, | ||
| Arguments<(ins FloatLike:$in, FloatLike:$scale, | ||
| OptionalAttr<Arith_RoundingModeAttr>:$roundingmode, | ||
| OptionalAttr<Arith_FastMathAttr>:$fastmath)>, | ||
| Results<(outs FloatLike:$out)> { | ||
| let summary = | ||
| "cast from floating-point to narrower floating-point with scales"; | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| let description = [{ | ||
| This operation implements micro-scaling (OCP MXFP) quantization of input using provided scale values. | ||
| This quantization usually happens over a block of values. All values in that block share same scale value for quantization purposes. | ||
| Therefore original input of shape `<dim1 x dim2 ... dimN>` can be thought of as of shape `<dim1 x dim2 x ... (dimN / blockSize) x blockSize>`, | ||
| assuming quantization axis is the last axis. | ||
| Original scales values therefore would be of shape `<dim1 x dim2 x ... x dimN-1 x (dimN/blockSize)>`. | ||
| `arith.scaling_truncf` operation is an elementwise operation. Therefore, before calling into `arith.scaling_truncf`, if `blockSize != 1` then | ||
| scales must be broadcast appropriately to make them the same shape as the input operand. | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Internally arith.scaling_truncf does the following: | ||
| ``` | ||
| scaleETy = get_type(scale) | ||
| inputETy = get_type(input) | ||
| resultETy = get_type(result) | ||
| scale.bcast = broadcast_to_same_shape_as(input) | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| scale.exponent = arith.truncf(scale.bcast) : scaleETy to f8E8M0 | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputETy | ||
| result = arith.divf(input, scale.extf) | ||
| result.cast = arith.truncf(result, resultETy) | ||
| ``` | ||
| OCP MXFP spec flushes denorm input value before quantization. NaNs are propagated. | ||
|
|
||
| }]; | ||
| let hasVerifier = 1; | ||
| let assemblyFormat = | ||
| [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // UIToFPOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,13 +6,16 @@ | |
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
|
|
||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
| #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
| #include "mlir/IR/BuiltinTypeInterfaces.h" | ||
| #include "mlir/IR/ImplicitLocOpBuilder.h" | ||
| #include "mlir/IR/PDLPatternMatch.h.inc" | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include "mlir/IR/TypeUtilities.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include "llvm/ADT/APFloat.h" | ||
| #include <cstdint> | ||
|
|
||
| namespace mlir { | ||
| namespace arith { | ||
|
|
@@ -23,6 +26,16 @@ namespace arith { | |
|
|
||
| using namespace mlir; | ||
|
|
||
| static Value createFloatConst(Location loc, Type type, float value, | ||
| PatternRewriter &rewriter) { | ||
| auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); | ||
| if (auto shapedTy = dyn_cast<ShapedType>(type)) { | ||
| return rewriter.create<arith::ConstantOp>( | ||
|
||
| loc, DenseElementsAttr::get(shapedTy, attr)); | ||
| } | ||
| return rewriter.create<arith::ConstantOp>(loc, attr); | ||
| } | ||
|
|
||
| /// Create an integer or index constant. | ||
| static Value createConst(Location loc, Type type, int value, | ||
| PatternRewriter &rewriter) { | ||
|
|
@@ -31,7 +44,6 @@ static Value createConst(Location loc, Type type, int value, | |
| return rewriter.create<arith::ConstantOp>( | ||
| loc, DenseElementsAttr::get(shapedTy, attr)); | ||
| } | ||
|
|
||
| return rewriter.create<arith::ConstantOp>(loc, attr); | ||
| } | ||
|
|
||
|
|
@@ -409,6 +421,112 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { | |
| } | ||
| }; | ||
|
|
||
| struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { | ||
| using OpRewritePattern::OpRewritePattern; | ||
| LogicalResult matchAndRewrite(arith::ScalingExtFOp op, | ||
| PatternRewriter &rewriter) const final { | ||
| ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
| auto inputOperand = op.getIn(); | ||
| auto scaleOperand = op.getScale(); | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) { | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return rewriter.notifyMatchFailure( | ||
| op, "scaling extf is not using scale operand of type f8E8M0FNU"); | ||
| } | ||
| Type resultTy = op.getType(); | ||
| // extf on scale will essentially create f32 number that is 2^scale and will | ||
|
||
| // also propagate NaNs | ||
| Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand); | ||
| Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand); | ||
| Value result = b.create<arith::MulFOp>(inputExt, scaleExt); | ||
| rewriter.replaceOp(op, result); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct ScalingTruncFOpConverter | ||
| : public OpRewritePattern<arith::ScalingTruncFOp> { | ||
| using OpRewritePattern::OpRewritePattern; | ||
| LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, | ||
| PatternRewriter &rewriter) const final { | ||
| ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
| auto inputOperand = op.getIn(); | ||
| auto scaleOperand = op.getScale(); | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) { | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return rewriter.notifyMatchFailure( | ||
| op, "scaling truncf is not using scale operand of type f8E8M0FNU"); | ||
umangyadav marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| auto scaleTy = scaleOperand.getType(); | ||
|
||
|
|
||
| Type resultTy = op.getType(); | ||
| Type resultETy = getElementTypeOrSelf(op.getOut()); | ||
|
|
||
| Type inputTy = inputOperand.getType(); | ||
| Type inputETy = getElementTypeOrSelf(inputOperand); | ||
|
|
||
| Type i8Ty = cloneToShapedType(resultTy, b.getI8Type()); | ||
| Type i32Ty = cloneToShapedType(resultTy, b.getI32Type()); | ||
| Type f32Ty = cloneToShapedType(resultTy, b.getF32Type()); | ||
| Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type()); | ||
|
|
||
| if (inputETy.getIntOrFloatBitWidth() < 32) { | ||
| inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand); | ||
| } else if (inputETy.getIntOrFloatBitWidth() > 32) { | ||
| inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand); | ||
| } | ||
| inputTy = inputOperand.getType(); | ||
|
||
| inputETy = getElementTypeOrSelf(inputOperand); | ||
|
|
||
| // normalize scale by exponent of the max normal value in result type as per | ||
| // the OCP MXFP spec | ||
| // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277 | ||
| const llvm::fltSemantics &resultFltSemantics = | ||
| llvm::cast<FloatType>(resultETy).getFloatSemantics(); | ||
| int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics); | ||
| Value cMaxNormalExponent = | ||
|
||
| createConst(op->getLoc(), i32Ty, maxExponent, rewriter); | ||
| Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter); | ||
| Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter); | ||
| Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand); | ||
| Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8); | ||
|
||
| Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127); | ||
| Value normalizedUnbiasedScale = | ||
| b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent); | ||
| // clamp scale exponent as per spec | ||
| // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282 | ||
| // upper clamp limit of 127 will be mapped to biased value of 255 and will | ||
| // be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN | ||
| // using arith.extf | ||
| Value clampUpperCond = b.create<arith::CmpIOp>( | ||
| arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127); | ||
| Value clampLowerCond = b.create<arith::CmpIOp>( | ||
| arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127); | ||
| Value clampedScale = b.create<arith::SelectOp>( | ||
| clampUpperCond, c127, | ||
| b.create<arith::SelectOp>(clampLowerCond, cNeg127, | ||
| normalizedUnbiasedScale)); | ||
| Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127); | ||
| Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale); | ||
| Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8); | ||
| Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8); | ||
| // flush denorms by checking if exponent part of input operand is zero | ||
| // or not. | ||
| Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand); | ||
| Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent); | ||
| Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter); | ||
| Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero, | ||
| inputExponentU8); | ||
| Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter); | ||
| Value flushedInput = | ||
|
||
| b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand); | ||
| Value result = b.create<arith::DivFOp>(flushedInput, scaleF32); | ||
| // propagate rounding mode and fast math attributes | ||
| Value resultCast = b.create<arith::TruncFOp>( | ||
| resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we check resultTy <= f32?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Verify() checks that output width is smaller compared to input.
No, other
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, verify checks that output width is smaller than input width. But I understand the output of this function is always f32. Then, I wonder if somebody can do input, scale -> f128, result -> f64. Then, it's true that output width < input width and we are still trying to truncate "result" which is f32 into f64. Not sure if I misunderstood something?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In practice, Float64/80/128 dtypes are something that is not expected. I think it is safe to assume F32 is the largest dtype that can appear on the input.
No, why do you think so ? Output dtype will be whatever user has specified.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I mean result of the function before truncation. result.dtype = f32, right?
I think arith dialect is not supposed to be hardware specific, so even though for us it's not expected. I'd prefer to enforce or check the assumption somehow. But it seems ok for me anyway, whatever you decide. |
||
| rewriter.replaceOp(op, resultCast); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct ArithExpandOpsPass | ||
| : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> { | ||
| using ArithExpandOpsPassBase::ArithExpandOpsPassBase; | ||
|
|
@@ -432,7 +550,9 @@ struct ArithExpandOpsPass | |
| arith::MaximumFOp, | ||
| arith::MinimumFOp, | ||
| arith::MaxNumFOp, | ||
| arith::MinNumFOp | ||
| arith::MinNumFOp, | ||
| arith::ScalingExtFOp, | ||
| arith::ScalingTruncFOp | ||
| >(); | ||
|
|
||
| if (includeBf16) { | ||
|
|
@@ -492,8 +612,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { | |
| patterns.getContext()); | ||
| } | ||
|
|
||
| void mlir::arith::populateExpandScalingExtTruncPatterns( | ||
| RewritePatternSet &patterns) { | ||
| patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>( | ||
| patterns.getContext()); | ||
| } | ||
|
|
||
| void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { | ||
| populateCeilFloorDivExpandOpsPatterns(patterns); | ||
| populateExpandScalingExtTruncPatterns(patterns); | ||
| // clang-format off | ||
| patterns.add< | ||
| MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, | ||
|
|
@@ -503,7 +630,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { | |
| MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>, | ||
| MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>, | ||
| MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>, | ||
| MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> | ||
| MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> | ||
| >(patterns.getContext()); | ||
| // clang-format on | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.