Merged
38 commits
1ed7462
Make it elementwise op
umangyadav May 28, 2025
91bb889
Add flushing logic
umangyadav May 28, 2025
8eebbea
Fix build issues
umangyadav May 28, 2025
acc6658
clamping on exponent
umangyadav May 29, 2025
6797446
propagate rounding mode and fast math attrs
umangyadav May 29, 2025
3ad83bd
Add some more notes
umangyadav May 29, 2025
9f755c2
Merge branch 'main' into scaling_cvt
umangyadav May 29, 2025
5e49a72
add scaling_extf tests
umangyadav May 29, 2025
682573e
Fix some issues
umangyadav May 29, 2025
de4497b
add test for scaling_truncf
umangyadav May 29, 2025
e239157
add some more tests
umangyadav May 29, 2025
646465c
Fix Formatting
umangyadav May 29, 2025
20b0928
Merge branch 'main' into scaling_cvt
umangyadav May 29, 2025
80c080f
Remove TODO
umangyadav May 29, 2025
b5df100
Update mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
umangyadav May 29, 2025
b6589ae
Update mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
umangyadav May 29, 2025
12c52a6
Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
umangyadav May 29, 2025
b3cadf2
Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
umangyadav May 29, 2025
5558b03
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav May 29, 2025
fc90780
Allow implicit truncf to f8E8M0FN type to extract exponent bits
umangyadav May 29, 2025
8f91e28
USe floating point to normalize scales
umangyadav May 30, 2025
dc7b67f
Rewrite description
umangyadav May 30, 2025
109ddc5
change error message
umangyadav May 30, 2025
f3d9865
some nits
umangyadav May 30, 2025
95a7558
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav May 30, 2025
3ccb208
Formatting
umangyadav May 30, 2025
d154341
Change comment
umangyadav May 30, 2025
d8a76fa
Update mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
umangyadav May 31, 2025
a0aa490
Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
umangyadav May 31, 2025
ff66dad
address some review comments
umangyadav May 31, 2025
10a1bc3
Merge branch 'main' into scaling_cvt
umangyadav May 31, 2025
3c7980d
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav Jun 2, 2025
f7c1b79
Fix docs
umangyadav Jun 2, 2025
229f6b8
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav Jun 6, 2025
45e7dba
Simplify arith.scaling_truncf to just do division and trunction. Deno…
umangyadav Jun 6, 2025
80061d6
address review comments and add tests
umangyadav Jun 6, 2025
a38ac5e
Formatting
umangyadav Jun 6, 2025
8151fc7
Merge branch 'main' into scaling_cvt
umangyadav Jun 9, 2025
81 changes: 81 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1215,6 +1215,44 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>
attr-dict `:` type($in) `to` type($out) }];
}

//===----------------------------------------------------------------------===//
// Scaling ExtFOp
//===----------------------------------------------------------------------===//
def Arith_ScalingExtFOp
: Arith_Op<
"scaling_extf", [Pure, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in, FloatLike:$scale,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary =
"cast from floating-point to larger floating-point using provided scales";
let description = [{
Implements the micro-scaling floating-point ExtF op. It expects the scale and input operands to have the same shape.
The scale operand is expected to be of type f8E8M0, though that restriction may be relaxed in the future.
Scales are usually calculated per block: if the input originally has shape <dim1 x dim2 x dim3 ... x dimN>, then given a blockSize it can be reshaped to <dim1 x dim2 x ... (dimN/blockSize) x blockSize>.
Scales are calculated along the block axis, so the scale has shape <dim1 x dim2 x dim3 ... (dimN/blockSize) x 1>.
Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to the same shape as the input, making `arith.scaling_extf` an elementwise op.
In the above example, scales should be broadcast to shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>.
Contributor:

I understand from the description that it doesn't need to be broadcast; you could use a non-broadcast tensor of shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>?

If that's the case, I don't think it's useful to explain all of these details; broadcasting is just a use case. If I understood it correctly.

Contributor:

I think the description needs to be updated - this arith op is set up to do things elementwise because arith ops in general are elementwise and the broadcast scale thing is a special case that gets pattern-matched in a future ArithToAMDGPU

Contributor Author:

I tried to rewrite the documentation. Please check again and let me know if it is clearer now.

```
resultTy = get_type(result)
scaleTy = get_type(scale)
inputTy = get_type(input)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.bcast = broadcast_to_same_shape_as(result)
scale.extf = arith.extf(scale.bcast) : f8E8M0 to resultTy
input.extf = arith.extf(input) : inputTy to resultTy
result = arith.mulf(scale.extf, input.extf)
```
It propagates NaN values: if either the scale or the input operand element value is NaN, the corresponding output element value will also be NaN.
}];
let hasVerifier = 1;
let assemblyFormat =
[{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
}
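
A minimal usage sketch (the concrete element and container types here are illustrative, not mandated by the op):

```
// Extend MXFP4 values to f32 using per-element (already broadcast) E8M0 scales.
%result = arith.scaling_extf %input, %scale : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xf32>
```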

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1280,6 +1318,49 @@ def Arith_TruncFOp :
attr-dict `:` type($in) `to` type($out) }];
}

//===----------------------------------------------------------------------===//
// Scaling TruncFOp
//===----------------------------------------------------------------------===//

def Arith_ScalingTruncFOp
: Arith_Op<"scaling_truncf",
[Pure, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in, FloatLike:$scale,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary =
"cast from floating-point to narrower floating-point with scales";
let description = [{
This operation implements micro-scaling (OCP MXFP) quantization of the input using the provided scale values.
This quantization usually happens over a block of values, and all values in a block share the same scale value for quantization purposes.
Therefore an original input of shape `<dim1 x dim2 ... dimN>` can be thought of as having shape `<dim1 x dim2 x ... (dimN / blockSize) x blockSize>`,
assuming the quantization axis is the last axis.
The original scale values would therefore be of shape `<dim1 x dim2 x ... x dimN-1 x (dimN/blockSize)>`.
The `arith.scaling_truncf` operation is elementwise. Therefore, before calling into `arith.scaling_truncf`, if `blockSize != 1`,
the scales must be broadcast appropriately to the same shape as the input operand.
Internally, arith.scaling_truncf does the following:
```
scaleETy = get_type(scale)
inputETy = get_type(input)
resultETy = get_type(result)
scale.bcast = broadcast_to_same_shape_as(input)
scale.exponent = arith.truncf(scale.bcast) : scaleETy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputETy
result = arith.divf(input, scale.extf)
result.cast = arith.truncf(result, resultETy)
```
The OCP MXFP spec flushes denormal input values before quantization. NaNs are propagated.

}];
let hasVerifier = 1;
let assemblyFormat =
[{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($in) `,` type($scale) `to` type($out)}];
}
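
A minimal usage sketch (types illustrative):

```
// Quantize f32 values to MXFP4 using per-element (already broadcast) E8M0 scales.
%result = arith.scaling_truncf %input, %scale : vector<32xf32>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
```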

//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);

/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);

/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);

1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
Attribute metadata = Attribute());

// Types.
FloatType getF8E8M0Type();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }

//===----------------------------------------------------------------------===//
// ScalingExtFOp
//===----------------------------------------------------------------------===//

bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}

LogicalResult arith::ScalingExtFOp::verify() {
return verifyExtOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// ScalingTruncFOp
//===----------------------------------------------------------------------===//

bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}

LogicalResult arith::ScalingTruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
137 changes: 132 additions & 5 deletions mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -6,13 +6,16 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include <cstdint>

namespace mlir {
namespace arith {
@@ -23,6 +26,16 @@

using namespace mlir;

static Value createFloatConst(Location loc, Type type, float value,
PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
return rewriter.create<arith::ConstantOp>(
Member:

We can update the attr here: attr = DenseElementsAttr::get(shapedTy, attr). It will return the right thing. (Both are fine to me).

loc, DenseElementsAttr::get(shapedTy, attr));
}
return rewriter.create<arith::ConstantOp>(loc, attr);
}

/// Create an integer or index constant.
static Value createConst(Location loc, Type type, int value,
PatternRewriter &rewriter) {
@@ -31,7 +44,6 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(shapedTy, attr));
}

return rewriter.create<arith::ConstantOp>(loc, attr);
}

@@ -409,6 +421,112 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};

struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto inputOperand = op.getIn();
auto scaleOperand = op.getScale();
if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
return rewriter.notifyMatchFailure(
op, "scaling extf is not using scale operand of type f8E8M0FNU");
}
Type resultTy = op.getType();
// extf on scale will essentially create f32 number that is 2^scale and will
Contributor:

why f32? can't resultTy be any float type?

Contributor:

should we check if resultTy >= Float8E8M0FNU and >= inputType

Contributor:

In principle, scaled truncation from f32 to f32 is a really weird way to spell division, but we might want to verify it away

Contributor Author (umangyadav, May 30, 2025):

> why f32? can't resultTy be any float type

Changed the comment to better reflect what it's doing.

> should we check if resultTy >= Float8E8M0FNU and >= inputType

As part of verification, it checks that the output dtype is of larger width compared to the input.
https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1460

// also propagate NaNs
Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
rewriter.replaceOp(op, result);
return success();
}
};
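
For reference, a sketch of the IR this pattern emits for a scalar f8E4M3FN input extended to f32 (value names and types are illustrative):

```
// %in : f8E4M3FN, %scale : f8E8M0FNU
%scale_f32 = arith.extf %scale : f8E8M0FNU to f32  // yields 2^exponent; 0xFF becomes NaN
%in_f32 = arith.extf %in : f8E4M3FN to f32
%res = arith.mulf %in_f32, %scale_f32 : f32
```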

struct ScalingTruncFOpConverter
: public OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto inputOperand = op.getIn();
auto scaleOperand = op.getScale();
if (!llvm::isa<Float8E8M0FNUType>(getElementTypeOrSelf(scaleOperand))) {
return rewriter.notifyMatchFailure(
op, "scaling truncf is not using scale operand of type f8E8M0FNU");
}
auto scaleTy = scaleOperand.getType();
Contributor:

Type


Type resultTy = op.getType();
Type resultETy = getElementTypeOrSelf(op.getOut());

Type inputTy = inputOperand.getType();
Type inputETy = getElementTypeOrSelf(inputOperand);

Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
Type f8Ty = cloneToShapedType(resultTy, b.getF8E8M0Type());

if (inputETy.getIntOrFloatBitWidth() < 32) {
inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
} else if (inputETy.getIntOrFloatBitWidth() > 32) {
inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
}
inputTy = inputOperand.getType();
Contributor:

We could update these to f32Type in the if statements above, but it doesn't matter

inputETy = getElementTypeOrSelf(inputOperand);

// normalize scale by exponent of the max normal value in result type as per
// the OCP MXFP spec
// https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
const llvm::fltSemantics &resultFltSemantics =
llvm::cast<FloatType>(resultETy).getFloatSemantics();
int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
Value cMaxNormalExponent =
Contributor:

Skip all this if we're in f32 or higher?

Contributor Author:

Rewrote using f32.

createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
Contributor:

This should be an extui. But also, there's no need to go i32 here

Contributor Author:

I first need to calculate the unbiased scale value. I can do that while staying in i8.

But then I also need to subtract emax (the max exponent of the largest normal number in the resultant quantized dtype).
That subtraction could underflow or overflow, and that needs to be checked and clamped later on. Therefore I require i32.

Contributor Author (umangyadav, May 29, 2025):

> This should be an extui.

Thanks. Good catch.

Contributor:

Ok, so, my bigger complaint is that you can simplify the generated code substantially if you just switch on what kind of type you're extending to.

That is, f32 requires nothing: that's already a +-127 situation.

Types shorter than f32 will need the subtraction.

... Also, I'm going to re-read the code, but I'm not convinced this should be subtracting the max normalized exponent. Are we sure it isn't "clamp to the exponent range of the type"?

Contributor:

... Ah, we're subtracting the max exponent of the result type

Which can't lead to overflow

This could be substantially simplified if we just use usub_sat (which we'd need an MLIR Arith op for, but that's fairly trivial)

Contributor:

... But also, the code you linked is for quantization

I think it's reasonable to assume that someone implementing quantization will already have done the scale-biasing thing and so we don't need to do it here

Unless we have evidence that the hardware implementations perform the subtraction described here? (We'll probably want to go find the AMD behavior)

Contributor:

... and if you're doing usub_sat, you don't need to unbias the exponent.

But also, I'd make sure this is something that other implementors of scaling_truncf implement so we don't get conflicting lowerings
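
To make the bias bookkeeping in this thread concrete, here is a hypothetical worked example mirroring the integer ops emitted below (the helper and its sample values are illustrative, not part of the patch):

```
#include <algorithm>

// Normalize an E8M0-encoded scale by the result type's max normal exponent,
// clamping to [-127, 127] before re-biasing, e.g. for f4E2M1FN (emax = 2).
int normalizeScaleBits(int scaleBits /*e.g. 132, i.e. 2^5*/, int maxExponent /*2*/) {
  int unbiased = scaleBits - 127;                  // 132 - 127 = 5
  int normalized = unbiased - maxExponent;         // 5 - 2 = 3
  int clamped = std::clamp(normalized, -127, 127); // stays 3
  return clamped + 127;                            // 130, i.e. 2^3 in E8M0
}
```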

Value unbiasedScale = b.create<arith::SubIOp>(scaleI32, c127);
Value normalizedUnbiasedScale =
b.create<arith::SubIOp>(unbiasedScale, cMaxNormalExponent);
// clamp scale exponent as per spec
// https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
// upper clamp limit of 127 will be mapped to biased value of 255 and will
// be bitcasted to 0xFF in F8E8M0 which will be converted to Float32 NaN
// using arith.extf
Value clampUpperCond = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sgt, normalizedUnbiasedScale, c127);
Value clampLowerCond = b.create<arith::CmpIOp>(
arith::CmpIPredicate::slt, normalizedUnbiasedScale, cNeg127);
Value clampedScale = b.create<arith::SelectOp>(
clampUpperCond, c127,
b.create<arith::SelectOp>(clampLowerCond, cNeg127,
normalizedUnbiasedScale));
Value biasedScale = b.create<arith::AddIOp>(clampedScale, c127);
Value biasedScaleI8 = b.create<arith::TruncIOp>(i8Ty, biasedScale);
Value biasedScaleF8 = b.create<arith::BitcastOp>(f8Ty, biasedScaleI8);
Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, biasedScaleF8);
// flush denorms by checking if exponent part of input operand is zero
// or not.
Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
inputExponentU8);
Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
Value flushedInput =
Contributor:

This all seems overcomplicated?

This could just be extending the scale to f32?

Contributor Author:

Rewrote using f32. It does simplify things a bit. Thanks

b.create<arith::SelectOp>(cmpCond, inputTyZero, inputOperand);
Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
// propagate rounding mode and fast math attributes
Value resultCast = b.create<arith::TruncFOp>(
resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
Contributor:

there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter

Contributor:

should we check resultTy <= f32?

Contributor Author:

> should we check resultTy <= f32?

Verify() checks that the output width is smaller compared to the input.

https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1587

> there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter

No, the other arith.truncf ops are mainly for scale dtype conversion, which just operates on the exponent and is not really affected by rounding mode or fast math.

Contributor:

Yes, verify checks that the output width is smaller than the input width. But I understand the output of this function is always f32. Then, I wonder if somebody can do input, scale -> f128, result -> f64. It's true that output width < input width, and we are still trying to truncate "result", which is f32, into f64. Not sure if I misunderstood something?

Contributor Author:

In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.
Also, the Verify() check is strict, therefore output_bit_width < input_bit_width.
So this would never really be truncating to an f32 resultTy in practice.

> But I understand the output of this function is always f32

No, why do you think so? The output dtype will be whatever the user has specified.

Contributor (dhernandez0, Jun 2, 2025):

> No, why do you think so? The output dtype will be whatever the user has specified.

I mean the result of the function before truncation. result.dtype = f32, right?

> In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.

I think the arith dialect is not supposed to be hardware-specific, so even though it's not expected for us, I'd prefer to enforce or check the assumption somehow. But it seems OK to me anyway, whatever you decide.

rewriter.replaceOp(op, resultCast);
return success();
}
};
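
In rough outline, for an f32 input this pattern flushes denormal inputs to zero, divides by the clamped scale re-materialized as f32, and truncates. A condensed sketch of the tail of the emitted IR (assuming an f32 -> f4E2M1FN truncation; the exponent clamping above is elided, and the value names are illustrative):

```
// %flushed = denorm-flushed input, %scale_f32 = arith.extf of the re-biased E8M0 scale
%div = arith.divf %flushed, %scale_f32 : f32
%res = arith.truncf %div : f32 to f4E2M1FN
```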

struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +550,9 @@ struct ArithExpandOpsPass
arith::MaximumFOp,
arith::MinimumFOp,
arith::MaxNumFOp,
arith::MinNumFOp
arith::MinNumFOp,
arith::ScalingExtFOp,
arith::ScalingTruncFOp
>();

if (includeBf16) {
@@ -492,8 +612,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}

void mlir::arith::populateExpandScalingExtTruncPatterns(
RewritePatternSet &patterns) {
patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
patterns.getContext());
}
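
A sketch of how a downstream pass might pull in just these expansions, mirroring how ArithExpandOpsPass marks the ops illegal (an assumed setup with pass boilerplate elided, not code from the patch):

```
// Inside a pass's runOnOperation(): expand only the scaling casts,
// marking them illegal so the conversion must rewrite them.
RewritePatternSet patterns(&getContext());
arith::populateExpandScalingExtTruncPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect>();
target.addIllegalOp<arith::ScalingExtFOp, arith::ScalingTruncFOp>();
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
  signalPassFailure();
```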

void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
populateExpandScalingExtTruncPatterns(patterns);
// clang-format off
patterns.add<
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +630,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
>(patterns.getContext());
// clang-format on
}
2 changes: 2 additions & 0 deletions mlir/lib/IR/Builders.cpp
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//

FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }

FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }

FloatType Builder::getF16Type() { return Float16Type::get(context); }