Introduce arith.scaling_extf and arith.scaling_truncf
#141965
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1215,6 +1215,59 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast | |
| attr-dict `:` type($in) `to` type($out) }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Scaling ExtFOp | ||
| //===----------------------------------------------------------------------===// | ||
| def Arith_ScalingExtFOp | ||
| : Arith_Op< | ||
| "scaling_extf", [Pure, SameInputOutputTensorDims, | ||
| DeclareOpInterfaceMethods<ArithFastMathInterface>, | ||
| DeclareOpInterfaceMethods<CastOpInterface>]>, | ||
| Arguments<(ins FloatLike:$in, FloatLike:$scale, | ||
| OptionalAttr<Arith_FastMathAttr>:$fastmath)>, | ||
| Results<(outs FloatLike:$out)> { | ||
| let summary = | ||
| "Upcasts quantized floats using provided scales values following OCP MXFP Spec"; | ||
| let description = [{ | ||
|
||
| This operation upcasts quantized floating-point values using provided scale | ||
| values. It expects both scales and the input operand to be of the same shape, | ||
| making the operation elementwise. Scales are usually calculated per block | ||
| following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. | ||
|
|
||
| If scales are calculated per block where blockSize != 1, then scales may | ||
| require broadcasting to make this operation elementwise. For example, let's | ||
| say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and | ||
| assuming quantization happens on the last axis, the input can be reshaped to | ||
| `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated | ||
| per block on the last axis. Therefore, scales will be of shape | ||
| `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other | ||
| shape as long as it is broadcast compatible with the input, e.g., | ||
| `<1 x 1 x ... (dimN/blockSize) x 1>`. | ||
|
|
||
| In this example, before calling into `arith.scaling_extf`, scales must be | ||
| broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note | ||
| that there could be multiple quantization axes. Internally, | ||
| `arith.scaling_extf` would perform the following: | ||
|
|
||
| ``` | ||
| resultTy = get_type(result) | ||
| scaleTy = get_type(scale) | ||
| inputTy = get_type(input) | ||
| assert(scaleTy.shape() == inputTy.shape() == resultTy.shape()) | ||
| scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 | ||
|
||
| scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy | ||
| input.extf = arith.extf(input) : inputTy to resultTy | ||
| result = arith.mulf(scale.extf, input.extf) | ||
| ``` | ||
| It propagates NaN values. Therefore, if either the scale or the input element | ||
| is NaN, then the corresponding output element is also NaN. | ||
| }]; | ||
| let hasVerifier = 1; | ||
| let assemblyFormat = | ||
| [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` | ||
| type($in) `,` type($scale) `to` type($out)}]; | ||
| } | ||
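
A minimal usage sketch of the textual form implied by the assembly format above; the SSA value names and the concrete types (`f4E2M1FN` input, `f8E8M0FNU` scales, `f32` result, a `vector<4x32x...>` shape) are illustrative assumptions, not part of this patch:

```
// scalar form: one input element, one (already broadcast) scale element
%r = arith.scaling_extf %in, %scale : f4E2M1FN, f8E8M0FNU to f32

// vector form: scales were computed per block of 32 elements and then
// broadcast so that the operation is elementwise over the whole shape
%v = arith.scaling_extf %vin, %vscale : vector<4x32xf4E2M1FN>, vector<4x32xf8E8M0FNU> to vector<4x32xf32>
```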
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // TruncIOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -1280,6 +1333,66 @@ def Arith_TruncFOp : | |
| attr-dict `:` type($in) `to` type($out) }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Scaling TruncFOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def Arith_ScalingTruncFOp | ||
| : Arith_Op<"scaling_truncf", | ||
| [Pure, SameInputOutputTensorDims, | ||
| DeclareOpInterfaceMethods<ArithRoundingModeInterface>, | ||
| DeclareOpInterfaceMethods<ArithFastMathInterface>, | ||
| DeclareOpInterfaceMethods<CastOpInterface>]>, | ||
| Arguments<(ins FloatLike:$in, FloatLike:$scale, | ||
| OptionalAttr<Arith_RoundingModeAttr>:$roundingmode, | ||
| OptionalAttr<Arith_FastMathAttr>:$fastmath)>, | ||
| Results<(outs FloatLike:$out)> { | ||
| let summary = | ||
| "Downcasts input floating point values using provided scales values following OCP MXFP Spec"; | ||
| let description = [{ | ||
| This operation quantizes input using the provided scale values. It expects | ||
|
||
| both scales and the input operand to be of the same shape, making the | ||
| operation elementwise. Scales are usually calculated per block | ||
| following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. | ||
|
|
||
| If scales are calculated per block where blockSize != 1, scales may require | ||
| broadcasting to make this operation elementwise. For example, let's say the | ||
| input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and | ||
| assuming quantization happens on the last axis, the input can be reshaped to | ||
| `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated | ||
| per block on the last axis. Therefore, scales will be of shape | ||
| `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other | ||
| shape as long as it is broadcast compatible with the input, e.g., | ||
| `<1 x 1 x ... (dimN/blockSize) x 1>`. | ||
|
|
||
| In this example, before calling into `arith.scaling_truncf`, scales must be | ||
| broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note | ||
| that there could be multiple quantization axes. Internally, | ||
| `arith.scaling_truncf` would perform the following: | ||
|
|
||
| ``` | ||
| scaleETy = get_type(scale) | ||
| inputETy = get_type(input) | ||
| resultETy = get_type(result) | ||
| // prepare Scale values with normalization and clamping | ||
| scale.exponent = arith.truncf(scale) : scaleETy to f8E8M0 | ||
| scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputETy | ||
| // emax is the exponent of the largest normal value in the quantized (result) type. | ||
| scale.normalize = arith.divf(scale.extf, 2^emax) | ||
| scale.clamped = clamp(scale.normalize) // clamp underflows | ||
| input.flushed = flush_denorms(input) | ||
|
||
| result = arith.divf(input.flushed, scale.clamped) | ||
| result.cast = arith.truncf(result, resultETy) | ||
| ``` | ||
| Flushing denorms in the input and normalizing the scale by emax are done as per | ||
| the OCP MXFP spec. | ||
| }]; | ||
| let hasVerifier = 1; | ||
| let assemblyFormat = | ||
| [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` | ||
| type($in) `,` type($scale) `to` type($out)}]; | ||
| } | ||
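
As with the ext op, a hedged usage sketch; the concrete types, value names, and the optional `to_nearest_even` rounding-mode attribute shown here are illustrative assumptions based on the assembly format above:

```
// downcast f32 values to f4E2M1FN using per-block scales that were already
// broadcast to the input shape
%q = arith.scaling_truncf %vin, %vscale : vector<4x32xf32>, vector<4x32xf8E8M0FNU> to vector<4x32xf4E2M1FN>

// with an explicit rounding mode
%q2 = arith.scaling_truncf %vin, %vscale to_nearest_even : vector<4x32xf32>, vector<4x32xf8E8M0FNU> to vector<4x32xf4E2M1FN>
```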
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // UIToFPOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,13 +6,15 @@ | |
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
|
|
||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
| #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
| #include "mlir/IR/BuiltinTypeInterfaces.h" | ||
| #include "mlir/IR/ImplicitLocOpBuilder.h" | ||
| #include "mlir/IR/TypeUtilities.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include "llvm/ADT/APFloat.h" | ||
| #include <cstdint> | ||
|
|
||
| namespace mlir { | ||
| namespace arith { | ||
|
|
@@ -23,6 +25,16 @@ namespace arith { | |
|
|
||
| using namespace mlir; | ||
|
|
||
| static Value createFloatConst(Location loc, Type type, float value, | ||
| PatternRewriter &rewriter) { | ||
| auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); | ||
| if (auto shapedTy = dyn_cast<ShapedType>(type)) { | ||
| return rewriter.create<arith::ConstantOp>( | ||
|
||
| loc, DenseElementsAttr::get(shapedTy, attr)); | ||
| } | ||
| return rewriter.create<arith::ConstantOp>(loc, attr); | ||
| } | ||
|
|
||
| /// Create an integer or index constant. | ||
| static Value createConst(Location loc, Type type, int value, | ||
| PatternRewriter &rewriter) { | ||
|
|
@@ -31,7 +43,6 @@ static Value createConst(Location loc, Type type, int value, | |
| return rewriter.create<arith::ConstantOp>( | ||
| loc, DenseElementsAttr::get(shapedTy, attr)); | ||
| } | ||
|
|
||
| return rewriter.create<arith::ConstantOp>(loc, attr); | ||
| } | ||
|
|
||
|
|
@@ -409,6 +420,125 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { | |
| } | ||
| }; | ||
|
|
||
| struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { | ||
| using OpRewritePattern::OpRewritePattern; | ||
| LogicalResult matchAndRewrite(arith::ScalingExtFOp op, | ||
| PatternRewriter &rewriter) const final { | ||
| ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
| Value inputOperand = op.getIn(); | ||
| Value scaleOperand = op.getScale(); | ||
| Type scaleETy = getElementTypeOrSelf(scaleOperand); | ||
| // allow implicit exponent extraction from 16/32 bits floats | ||
| if (scaleETy.getIntOrFloatBitWidth() >= 16) { | ||
| scaleETy = b.getF8E8M0Type(); | ||
| scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand); | ||
| } | ||
| if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { | ||
| return rewriter.notifyMatchFailure( | ||
| op, "scaling_extf is using scales of type which can not be converted " | ||
| "to f8E8M0FNU"); | ||
| } | ||
| Type resultTy = op.getType(); | ||
| // extf on the scale will essentially create a floating-point number | ||
| // of type resultTy that is 2^scale and will also propagate NaNs | ||
| Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand); | ||
| Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand); | ||
| Value result = b.create<arith::MulFOp>(inputExt, scaleExt); | ||
| rewriter.replaceOp(op, result); | ||
| return success(); | ||
| } | ||
| }; | ||
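
For illustration, a sketch of what this pattern produces for a scalar op; the types and value names are assumed, and an `f32` scale is first truncated to `f8E8M0FNU` because its width is >= 16 bits:

```
// before expansion
%r = arith.scaling_extf %in, %scale : f4E2M1FN, f32 to f32

// after expansion (roughly)
%s_e8m0 = arith.truncf %scale : f32 to f8E8M0FNU
%s_ext  = arith.extf %s_e8m0 : f8E8M0FNU to f32
%in_ext = arith.extf %in : f4E2M1FN to f32
%r2     = arith.mulf %in_ext, %s_ext : f32
```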
|
|
||
| struct ScalingTruncFOpConverter | ||
| : public OpRewritePattern<arith::ScalingTruncFOp> { | ||
| using OpRewritePattern::OpRewritePattern; | ||
| LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, | ||
| PatternRewriter &rewriter) const final { | ||
| ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
| Value inputOperand = op.getIn(); | ||
| Value scaleOperand = op.getScale(); | ||
| Type scaleTy = scaleOperand.getType(); | ||
| Type scaleETy = getElementTypeOrSelf(scaleOperand); | ||
| // allow implicit exponent extraction from 16/32 bits floats | ||
| if (scaleETy.getIntOrFloatBitWidth() >= 16) { | ||
| scaleETy = b.getF8E8M0Type(); | ||
| scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand); | ||
| scaleTy = scaleOperand.getType(); | ||
| } | ||
| if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { | ||
| return rewriter.notifyMatchFailure( | ||
| op, "scaling_truncf is using scales type which can not be converted " | ||
| "to f8E8M0FNU"); | ||
| } | ||
|
|
||
| Type resultTy = op.getType(); | ||
| Type resultETy = getElementTypeOrSelf(op.getOut()); | ||
|
|
||
| Type inputTy = inputOperand.getType(); | ||
| Type inputETy = getElementTypeOrSelf(inputOperand); | ||
|
|
||
| Type i8Ty = cloneToShapedType(resultTy, b.getI8Type()); | ||
| Type i32Ty = cloneToShapedType(resultTy, b.getI32Type()); | ||
| Type f32Ty = cloneToShapedType(resultTy, b.getF32Type()); | ||
|
|
||
| if (inputETy.getIntOrFloatBitWidth() < 32) { | ||
| inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand); | ||
| } else if (inputETy.getIntOrFloatBitWidth() > 32) { | ||
| inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand); | ||
| } | ||
| inputTy = inputOperand.getType(); | ||
|
||
| inputETy = getElementTypeOrSelf(inputOperand); | ||
|
|
||
| // normalize scale by exponent of the max normal value (emax) in result type | ||
| // as per the OCP MXFP spec | ||
| // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277 | ||
| // here this normalization is carried out in f32. Therefore, instead of | ||
| // subtracting exponents, it uses DivFOp. | ||
| const llvm::fltSemantics &resultFltSemantics = | ||
| llvm::cast<FloatType>(resultETy).getFloatSemantics(); | ||
| int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics); | ||
| Value cEmax = createConst(op->getLoc(), i32Ty, maxExponent, rewriter); | ||
| Value c1 = createConst(op->getLoc(), i32Ty, 1, rewriter); | ||
| Value cPow2 = b.create<arith::ShLIOp>(c1, cEmax); | ||
| Value cPow2F32 = b.create<arith::SIToFPOp>(f32Ty, cPow2); | ||
| Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, scaleOperand); | ||
| // note that the spec also does the clamping, but it should only be done for | ||
| // underflows because dividing by 2^emax can only make the scale smaller. | ||
| // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282 | ||
| Value scaleNormalizedF32 = b.create<arith::DivFOp>(scaleF32, cPow2F32); | ||
| // If it has underflowed, then the scale will be a denormal FP32 number after | ||
| // division. Clamp underflows to 2^-127 as per the spec implementation. | ||
| Value scaleNormalizedExponentF8 = | ||
| b.create<arith::TruncFOp>(scaleTy, scaleNormalizedF32); | ||
| Value scaleNormalizedExponentU8 = | ||
| b.create<arith::BitcastOp>(i8Ty, scaleNormalizedExponentF8); | ||
| Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter); | ||
| Value scaleClampCond = b.create<arith::CmpIOp>( | ||
| arith::CmpIPredicate::eq, cI8Zero, scaleNormalizedExponentU8); | ||
| // 5.8e-39 is 2^-127, it is a denorm value in f32 | ||
| float clampValue = 5.87747e-39; | ||
| Value scaleClampValue = | ||
| createFloatConst(op.getLoc(), f32Ty, clampValue, rewriter); | ||
| Value clampedScale = b.create<arith::SelectOp>( | ||
| scaleClampCond, scaleClampValue, scaleNormalizedF32); | ||
| // flush denorms by checking if exponent part of input operand is zero | ||
| // or not. | ||
| Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand); | ||
| Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent); | ||
| Value inputFlushCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, | ||
| cI8Zero, inputExponentU8); | ||
| Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter); | ||
| Value flushedInput = | ||
|
||
| b.create<arith::SelectOp>(inputFlushCond, inputTyZero, inputOperand); | ||
| Value result = b.create<arith::DivFOp>(flushedInput, clampedScale); | ||
| // propagate rounding mode and fast math attributes | ||
| Value resultCast = b.create<arith::TruncFOp>( | ||
| resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); | ||
|
Contributor
there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter
Contributor
should we check resultTy <= f32?
Contributor (Author)
Verify() checks that the output width is smaller than the input width.
No, other
Contributor
Yes, verify checks that the output width is smaller than the input width. But I understand the output of this function is always f32. Then I wonder if somebody can do input, scale -> f128, result -> f64. It is still true that output width < input width, and we would be trying to truncate "result", which is f32, into f64. Not sure if I misunderstood something?
Contributor (Author)
In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.
No, why do you think so? The output dtype will be whatever the user has specified.
Contributor
I mean the result of the function before truncation. result.dtype = f32, right?
I think the arith dialect is not supposed to be hardware specific, so even though it's not expected for us, I'd prefer to enforce or check the assumption somehow. But it seems ok to me anyway, whatever you decide.
||
| rewriter.replaceOp(op, resultCast); | ||
| return success(); | ||
| } | ||
| }; | ||
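
Again only as an illustration, a rough MLIR sketch of the sequence this pattern emits for an `f16` input, an `f8E8M0FNU` scale, and an `f4E2M1FN` result (whose emax is 2); the constants and value names are assumptions, and the generated f8E8M0 truncf/extf ops may themselves be expanded further by the F8E8M0 patterns:

```
%in_f32    = arith.extf %in : f16 to f32                // bring input to f32
%scale_f32 = arith.extf %scale : f8E8M0FNU to f32
// normalize the scale by 2^emax of the result type
%c1        = arith.constant 1 : i32
%emax      = arith.constant 2 : i32
%pow2      = arith.shli %c1, %emax : i32
%pow2_f32  = arith.sitofp %pow2 : i32 to f32
%norm      = arith.divf %scale_f32, %pow2_f32 : f32
// clamp underflowed scales to 2^-127 by testing the f8E8M0 exponent for zero
%norm_e8m0 = arith.truncf %norm : f32 to f8E8M0FNU
%norm_i8   = arith.bitcast %norm_e8m0 : f8E8M0FNU to i8
%zero_i8   = arith.constant 0 : i8
%is_uflow  = arith.cmpi eq, %zero_i8, %norm_i8 : i8
%min_scale = arith.constant 5.87747e-39 : f32
%scale_ok  = arith.select %is_uflow, %min_scale, %norm : f32
// flush denormal inputs to zero, then divide and truncate to the result type
%in_e8m0   = arith.truncf %in_f32 : f32 to f8E8M0FNU
%in_i8     = arith.bitcast %in_e8m0 : f8E8M0FNU to i8
%is_dnrm   = arith.cmpi eq, %zero_i8, %in_i8 : i8
%zero_f32  = arith.constant 0.0 : f32
%in_flush  = arith.select %is_dnrm, %zero_f32, %in_f32 : f32
%div       = arith.divf %in_flush, %scale_ok : f32
%res       = arith.truncf %div : f32 to f4E2M1FN
```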
|
|
||
| struct ArithExpandOpsPass | ||
| : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> { | ||
| using ArithExpandOpsPassBase::ArithExpandOpsPassBase; | ||
|
|
@@ -432,7 +562,9 @@ struct ArithExpandOpsPass | |
| arith::MaximumFOp, | ||
| arith::MinimumFOp, | ||
| arith::MaxNumFOp, | ||
| arith::MinNumFOp | ||
| arith::MinNumFOp, | ||
| arith::ScalingExtFOp, | ||
| arith::ScalingTruncFOp | ||
| >(); | ||
|
|
||
| if (includeBf16) { | ||
|
|
@@ -492,8 +624,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { | |
| patterns.getContext()); | ||
| } | ||
|
|
||
| void mlir::arith::populateExpandScalingExtTruncPatterns( | ||
| RewritePatternSet &patterns) { | ||
| patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>( | ||
| patterns.getContext()); | ||
| } | ||
|
|
||
| void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { | ||
| populateCeilFloorDivExpandOpsPatterns(patterns); | ||
| populateExpandScalingExtTruncPatterns(patterns); | ||
| // clang-format off | ||
| patterns.add< | ||
| MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, | ||
|
|
@@ -503,7 +642,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { | |
| MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>, | ||
| MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>, | ||
| MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>, | ||
| MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> | ||
| MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> | ||
| >(patterns.getContext()); | ||
| // clang-format on | ||
| } | ||