
Commit 45e7dba

Simplify arith.scaling_truncf to just do division and truncation. Denorm flushing on the input should be carried out using the specified fastMath flag. Scales are assumed to be normalized and clamped.
1 parent 229f6b8 commit 45e7dba

3 files changed: +42 −132 lines changed


mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 13 additions & 15 deletions
@@ -1354,6 +1354,10 @@ def Arith_ScalingTruncFOp
     both scales and the input operand to be of the same shape and, therefore,
     makes the operation elementwise. Scales are usually calculated per block
     following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+    Users are required to normalize and clamp the scales as necessary before
+    passing them to this operation. The OCP MXFP spec also flushes denorms on
+    the input operand, which should be handled during lowering by passing an
+    appropriate fastMath flag to this operation.
 
     If scales are calculated per block where blockSize != 1, scales may require
     broadcasting to make this operation elementwise. For example, let's say the
@@ -1369,23 +1373,17 @@ def Arith_ScalingTruncFOp
     broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
     that there could be multiple quantization axes. Internally,
     `arith.scaling_truncf` would perform the following:
-
+
     ```
-    scaleETy = get_type(scale)
-    inputETy = get_type(input)
-    resultETy = get_type(result)
-    // prepare Scale values with normalization and clamping
-    scale.exponent = arith.truncf(scale) : scaleETy to f8E8M0
-    scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputETy
-    // emax is calculated as exponent of the largest normal value in quantized type.
-    scale.normalize = arith.divf(scale.extf, emax)
-    scale.clamped = clamp(scale.normalize) // clamp underflows
-    input.flused = flush_denorms(input)
-    result = arith.divf(input.flushed, scale.clamped)
-    result.cast = arith.truncf(result, resultETy)
+    scaleTy = get_type(scale)
+    inputTy = get_type(input)
+    resultTy = get_type(result)
+    assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
+    result = arith.divf(input, scale.extf)
+    result.cast = arith.truncf(result, resultTy)
     ```
-    Flushing of denorms in input and scale normalization with emax is added as per
-    the OCP MXFP spec.
   }];
   let hasVerifier = 1;
   let assemblyFormat =
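
For illustration, a minimal usage sketch of the op under the simplified semantics. The assembly syntax is assumed to match the existing expand-ops.mlir tests (the function and value names here are hypothetical); the scale is expected to be already normalized and clamped by the caller.

```mlir
// Hypothetical example: truncate an f32 value to f4E2M1FN using a
// pre-normalized, pre-clamped f8E8M0FNU scale. With the simplified
// semantics this is just result = truncf(input / 2^scale); denorm
// flushing, if desired, is requested via the fastMath flag at lowering.
func.func @example_scaling_truncf(%input : f32, %scale : f8E8M0FNU) -> f4E2M1FN {
  %0 = arith.scaling_truncf %input, %scale : f32, f8E8M0FNU to f4E2M1FN
  return %0 : f4E2M1FN
}
```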

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 24 additions & 79 deletions
@@ -13,8 +13,6 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/APFloat.h"
-#include <cstdint>
 
 namespace mlir {
 namespace arith {
@@ -25,16 +23,6 @@ namespace arith {
 
 using namespace mlir;
 
-static Value createFloatConst(Location loc, Type type, float value,
-                              PatternRewriter &rewriter) {
-  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
-  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
-    return rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(shapedTy, attr));
-  }
-  return rewriter.create<arith::ConstantOp>(loc, attr);
-}
-
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
@@ -368,7 +356,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
     Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
+      result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
+                                         op.getFastmathAttr());
     } else if (resultETy.getIntOrFloatBitWidth() > 32) {
       result = b.create<arith::ExtFOp>(resultTy, result);
     }
@@ -406,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     if (operandETy.getIntOrFloatBitWidth() < 32) {
-      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+      operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
     } else if (operandETy.getIntOrFloatBitWidth() > 32) {
-      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+      operand = b.create<arith::TruncFOp>(
+          f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
     }
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -431,7 +421,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     // allow implicit exponent extraction from 16/32 bits floats
     if (scaleETy.getIntOrFloatBitWidth() >= 16) {
       scaleETy = b.getF8E8M0Type();
-      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+      scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
     }
     if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
       return rewriter.notifyMatchFailure(
@@ -441,14 +432,22 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
     Type resultTy = op.getType();
     // extf on scale will essentially create floating point number
     // of type resulTy that is 2^scale and will also propagate NaNs
-    Value scaleExt = b.create<arith::ExtFOp>(resultTy, scaleOperand);
-    Value inputExt = b.create<arith::ExtFOp>(resultTy, inputOperand);
-    Value result = b.create<arith::MulFOp>(inputExt, scaleExt);
+    Value scaleExt =
+        b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+    Value inputExt =
+        b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+    Value result =
+        b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+/*
+Expands arith.ScalingTruncFOp(in, scale) into
+  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
+  result = arith.truncf(in / (2^scale))
+*/
 struct ScalingTruncFOpConverter
     : public OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -470,68 +469,14 @@ struct ScalingTruncFOpConverter
           op, "scaling_truncf is using scales type which can not be converted "
               "to f8E8M0FNU");
     }
-
     Type resultTy = op.getType();
-    Type resultETy = getElementTypeOrSelf(op.getOut());
-
     Type inputTy = inputOperand.getType();
-    Type inputETy = getElementTypeOrSelf(inputOperand);
-
-    Type i8Ty = cloneToShapedType(resultTy, b.getI8Type());
-    Type i32Ty = cloneToShapedType(resultTy, b.getI32Type());
-    Type f32Ty = cloneToShapedType(resultTy, b.getF32Type());
-
-    if (inputETy.getIntOrFloatBitWidth() < 32) {
-      inputOperand = b.create<arith::ExtFOp>(f32Ty, inputOperand);
-    } else if (inputETy.getIntOrFloatBitWidth() > 32) {
-      inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
-    }
-    inputTy = inputOperand.getType();
-    inputETy = getElementTypeOrSelf(inputOperand);
-
-    // normalize scale by exponent of the max normal value (emax) in result type
-    // as per the OCP MXFP spec
-    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L277
-    // here this normalization is carried in f32. Therefore instead of
-    // subtraction it does the DivFOp
-    const llvm::fltSemantics &resultFltSemantics =
-        llvm::cast<FloatType>(resultETy).getFloatSemantics();
-    int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
-    Value cEmax = createConst(op->getLoc(), i32Ty, maxExponent, rewriter);
-    Value c1 = createConst(op->getLoc(), i32Ty, 1, rewriter);
-    Value cPow2 = b.create<arith::ShLIOp>(c1, cEmax);
-    Value cPow2F32 = b.create<arith::SIToFPOp>(f32Ty, cPow2);
-    Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, scaleOperand);
-    // note that spec also does the clamping but it should only be done for
-    // underflows because dividing by 2^emax will only make it smaller.
-    // https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L282
-    Value scaleNormalizedF32 = b.create<arith::DivFOp>(scaleF32, cPow2F32);
-    // If it has underflown then scale will be a denorm FP32 number after
-    // division. Clamp underflows to 2^-127 as per the spec implementation
-    Value scaleNormalizedExponentF8 =
-        b.create<arith::TruncFOp>(scaleTy, scaleNormalizedF32);
-    Value scaleNormalizedExponentU8 =
-        b.create<arith::BitcastOp>(i8Ty, scaleNormalizedExponentF8);
-    Value cI8Zero = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
-    Value scaleClampCond = b.create<arith::CmpIOp>(
-        arith::CmpIPredicate::eq, cI8Zero, scaleNormalizedExponentU8);
-    // 5.8e-39 is 2^-127, it is a denorm value in f32
-    float clampValue = 5.87747e-39;
-    Value scaleClampValue =
-        createFloatConst(op.getLoc(), f32Ty, clampValue, rewriter);
-    Value clampedScale = b.create<arith::SelectOp>(
-        scaleClampCond, scaleClampValue, scaleNormalizedF32);
-    // flush denorms by checking if exponent part of input operand is zero
-    // or not.
-    Value inputExponent = b.create<arith::TruncFOp>(scaleTy, inputOperand);
-    Value inputExponentU8 = b.create<arith::BitcastOp>(i8Ty, inputExponent);
-    Value inputFlushCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
-                                                   cI8Zero, inputExponentU8);
-    Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
-    Value flushedInput =
-        b.create<arith::SelectOp>(inputFlushCond, inputTyZero, inputOperand);
-    Value result = b.create<arith::DivFOp>(flushedInput, clampedScale);
-    // propagate rounding mode and fast math attributes
+    // this will create a floating point number of type
+    // inputTy that is 2^scale and will also propagate NaNs
+    scaleOperand =
+        b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
+    Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
+                                           op.getFastmathAttr());
     Value resultCast = b.create<arith::TruncFOp>(
         resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
     rewriter.replaceOp(op, resultCast);
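
With the simplified converter, the expansion reduces to an extf of the scale to the input type, a divf, and a final truncf, all carrying the op's fastmath (and, for the truncf, rounding mode) attributes. A sketch of the lowered IR for the scalar f32 to f4E2M1FN case, consistent with the updated SCHECK expectations below (SSA names are illustrative):

```mlir
// Expansion of: %r = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
%scale_f32 = arith.extf %arg1 : f8E8M0FNU to f32        // reinterpret the E8M0 scale as 2^scale in f32
%quotient  = arith.divf %arg0, %scale_f32 : f32         // divide the input by the scale
%result    = arith.truncf %quotient : f32 to f4E2M1FN   // truncate to the target type
```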

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 5 additions & 38 deletions
@@ -316,24 +316,8 @@ func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2
 }
 
 // SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
-// SCHECK: %[[C2:.+]] = arith.constant 2 : i32
-// SCHECK: %[[C1:.+]] = arith.constant 1 : i32
-// SCHECK: %[[EMAX:.+]] = arith.shli %[[C1]], %[[C2]] : i32
-// SCHECK: %[[EMAXF32:.+]] = arith.sitofp %[[EMAX]] : i32 to f32
 // SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// SCHECK: %[[SCALEDIV:.+]] = arith.divf %[[SCALEF32]], %[[EMAXF32]] : f32
-// SCHECK: %[[SCALEDIVF8:.+]] = arith.truncf %[[SCALEDIV]] : f32 to f8E8M0FNU
-// SCHECK: %[[SCALEDIVI8:.+]] = arith.bitcast %[[SCALEDIVF8]] : f8E8M0FNU to i8
-// SCHECK: %[[C0:.+]] = arith.constant 0 : i8
-// SCHECK: %[[UFLOWCOND:.+]] = arith.cmpi eq, %[[C0]], %[[SCALEDIVI8]] : i8
-// SCHECK: %[[CLAMPVAL:.+]] = arith.constant 5.877470e-39 : f32
-// SCHECK: %[[CLAMP:.+]] = arith.select %[[UFLOWCOND]], %[[CLAMPVAL]], %[[SCALEDIV]] : f32
-// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %arg0 : f32 to f8E8M0FNU
-// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : f8E8M0FNU to i8
-// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : i8
-// SCHECK: %[[CF0:.+]] = arith.constant 0.000000e+00 : f32
-// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %arg0 : f32
-// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[CLAMP]] : f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
 // SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
 // SCHECK: return %[[RESULT]]
 
@@ -345,26 +329,9 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
 }
 
 // SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
-// SCHECK: %[[INPUTF32:.+]] = arith.extf %arg0 : vector<4xf16> to vector<4xf32>
-// SCHECK: %[[C2:.+]] = arith.constant dense<4> : vector<4xi32>
-// SCHECK: %[[C1:.+]] = arith.constant dense<1> : vector<4xi32>
-// SCHECK: %[[EMAX:.+]] = arith.shli %[[C1]], %[[C2]] : vector<4xi32>
-// SCHECK: %[[EMAXF32:.+]] = arith.sitofp %[[EMAX]] : vector<4xi32> to vector<4xf32>
-// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
-// SCHECK: %[[SCALEDIV:.+]] = arith.divf %[[SCALEF32]], %[[EMAXF32]] : vector<4xf32>
-// SCHECK: %[[SCALEDIVF8:.+]] = arith.truncf %[[SCALEDIV]] : vector<4xf32> to vector<4xf8E8M0FNU>
-// SCHECK: %[[SCALEDIVI8:.+]] = arith.bitcast %[[SCALEDIVF8]] : vector<4xf8E8M0FNU> to vector<4xi8>
-// SCHECK: %[[C0:.+]] = arith.constant dense<0> : vector<4xi8>
-// SCHECK: %[[UFLOWCOND:.+]] = arith.cmpi eq, %[[C0]], %[[SCALEDIVI8]] : vector<4xi8>
-// SCHECK: %[[CLAMPVAL:.+]] = arith.constant dense<5.877470e-39> : vector<4xf32>
-// SCHECK: %[[CLAMP:.+]] = arith.select %[[UFLOWCOND]], %[[CLAMPVAL]], %[[SCALEDIV]] : vector<4xi1>, vector<4xf32>
-// SCHECK: %[[INPUTEXP:.+]] = arith.truncf %[[INPUTF32]] : vector<4xf32> to vector<4xf8E8M0FNU>
-// SCHECK: %[[INPUTEXPI8:.+]] = arith.bitcast %[[INPUTEXP]] : vector<4xf8E8M0FNU> to vector<4xi8>
-// SCHECK: %[[FLUSHCOND:.+]] = arith.cmpi eq, %[[C0]], %[[INPUTEXPI8]] : vector<4xi8>
-// SCHECK: %[[CF0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// SCHECK: %[[FLUSHINPUT:.+]] = arith.select %[[FLUSHCOND]], %[[CF0]], %[[INPUTF32]] : vector<4xi1>, vector<4xf32>
-// SCHECK: %[[DIVF:.+]] = arith.divf %[[FLUSHINPUT]], %[[CLAMP]] : vector<4xf32>
-// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
 // SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
 
 // -----
@@ -374,7 +341,7 @@ func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1:
   return %0 : vector<4xf6E3M2FN>
 }
 // SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
-// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf32> to vector<4xf6E3M2FN>
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf16> to vector<4xf6E3M2FN>
 // SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
 
 // -----
