Skip to content

Commit 062a982

Browse files
committed
Add arith expansion of f8E8M0 type for extf/trunc ops
1 parent 83de1ef commit 062a982

File tree

6 files changed

+265
-11
lines changed

6 files changed

+265
-11
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
5959
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
6060
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
6161

62+
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
63+
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
64+
6265
/// Add patterns to expand Arith ops.
6366
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
6467

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
1717
let options = [
1818
Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
1919
"Enable the BF16 expansion patterns">,
20+
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
21+
"Enable the F8E8M0 expansion patterns">,
2022
];
2123
}
2224

mlir/include/mlir/IR/Types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class Type {
109109
// Convenience predicates. This is only for floating point types,
110110
// derived types should use isa/dyn_cast.
111111
bool isIndex() const;
112+
bool isF8E8M0FNU() const;
112113
bool isBF16() const;
113114
bool isF16() const;
114115
bool isTF32() const;

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 129 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
291291
// Constant used to make the rounding bias.
292292
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
293293
// Constant used to generate a quiet NaN.
294-
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
294+
Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
295295
// Small constants used to address bits.
296296
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
297297
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
313313
// Now that the rounding-bias has been added, truncating the low bits
314314
// yields the correctly rounded result.
315315
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
316-
Value normalCaseResult_i16 =
316+
Value normalCaseResultI16 =
317317
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
318318
// Select either the above-computed result, or a quiet NaN constant
319319
// if the input was NaN.
320320
Value select =
321-
b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
321+
b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
322322
Value result = b.create<arith::BitcastOp>(resultTy, select);
323323
rewriter.replaceOp(op, result);
324324
return success();
325325
}
326326
};
327327

328+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
329+
using OpRewritePattern::OpRewritePattern;
330+
LogicalResult matchAndRewrite(arith::ExtFOp op,
331+
PatternRewriter &rewriter) const final {
332+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
333+
auto operand = op.getOperand();
334+
Type operandTy = operand.getType();
335+
Type resultTy = op.getType();
336+
Type operandETy = getElementTypeOrSelf(operandTy);
337+
Type resultETy = getElementTypeOrSelf(resultTy);
338+
339+
if (!operandETy.isF8E8M0FNU()) {
340+
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
341+
}
342+
343+
if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
344+
return rewriter.notifyMatchFailure(
345+
op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
346+
}
347+
348+
Type i8Ty = b.getI8Type();
349+
Type i32Ty = b.getI32Type();
350+
Type f32Ty = b.getF32Type();
351+
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
352+
i8Ty = shapedTy.clone(i8Ty);
353+
i32Ty = shapedTy.clone(i32Ty);
354+
f32Ty = shapedTy.clone(f32Ty);
355+
}
356+
357+
Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
358+
// create constants for NaNs
359+
Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
360+
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
361+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
362+
363+
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
364+
Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
365+
366+
Value isNan =
367+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
368+
// select for NaNs
369+
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
370+
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
371+
if (resultETy.isBF16()) {
372+
result = b.create<arith::TruncFOp>(resultTy, result);
373+
} else if (resultETy.isF16()) {
374+
result = b.create<arith::TruncFOp>(resultTy, result);
375+
}
376+
rewriter.replaceOp(op, result);
377+
return success();
378+
}
379+
};
380+
381+
/*
382+
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
383+
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
384+
they all map to NaN in F8E8M0 Type.
385+
*/
386+
struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
387+
using OpRewritePattern::OpRewritePattern;
388+
LogicalResult matchAndRewrite(arith::TruncFOp op,
389+
PatternRewriter &rewriter) const final {
390+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
391+
auto operand = op.getOperand();
392+
Type operandTy = operand.getType();
393+
Type operandETy = getElementTypeOrSelf(operandTy);
394+
Type resultTy = op.getType();
395+
Type resultETy = getElementTypeOrSelf(resultTy);
396+
if (!resultETy.isF8E8M0FNU()) {
397+
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
398+
}
399+
if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
400+
return rewriter.notifyMatchFailure(
401+
op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
402+
}
403+
404+
if (op.getRoundingmodeAttr()) {
405+
return rewriter.notifyMatchFailure(
406+
op, "only applicable to default rounding mode.");
407+
}
408+
409+
Type i8Ty = b.getI8Type();
410+
Type i32Ty = b.getI32Type();
411+
Type f32Ty = b.getF32Type();
412+
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
413+
i8Ty = shapedTy.clone(i8Ty);
414+
i32Ty = shapedTy.clone(i32Ty);
415+
f32Ty = shapedTy.clone(f32Ty);
416+
}
417+
if (!operandETy.isF32()) {
418+
operand = b.create<arith::ExtFOp>(f32Ty, operand);
419+
}
420+
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
421+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
422+
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
423+
Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
424+
Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
425+
rewriter.replaceOp(op, result);
426+
return success();
427+
}
428+
};
429+
328430
struct ArithExpandOpsPass
329431
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
330432
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
351453
arith::MinNumFOp
352454
>();
353455

354-
if (includeBf16) {
456+
if(includeBf16) {
355457
arith::populateExpandBFloat16Patterns(patterns);
458+
}
459+
if(includeF8E8M0) {
460+
arith::populateExpandF8E8M0Patterns(patterns);
461+
}
462+
if (includeBf16 || includeF8E8M0) {
356463
target.addDynamicallyLegalOp<arith::ExtFOp>(
357-
[](arith::ExtFOp op) {
464+
[=](arith::ExtFOp op) {
358465
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
359466
Type outETy = getElementTypeOrSelf(op.getType());
360-
return !(inETy.isBF16() && outETy.isF32());
467+
if(includeBf16 && includeF8E8M0)
468+
return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
469+
if(includeBf16)
470+
return !(inETy.isBF16() && outETy.isF32());
471+
return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
361472
});
362473

363474
target.addDynamicallyLegalOp<arith::TruncFOp>(
364-
[](arith::TruncFOp op) {
475+
[=](arith::TruncFOp op) {
365476
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
366477
Type outETy = getElementTypeOrSelf(op.getType());
367-
return !(inETy.isF32() && outETy.isBF16());
478+
if(includeBf16 && includeF8E8M0)
479+
return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
480+
if(includeBf16)
481+
return !(inETy.isF32() && outETy.isBF16());
482+
return
483+
!(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
368484
});
369485
}
370-
371486
// clang-format on
372487
if (failed(applyPartialConversion(getOperation(), target,
373488
std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
389504
patterns.getContext());
390505
}
391506

507+
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
508+
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
509+
patterns.getContext());
510+
}
511+
392512
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
393513
populateCeilFloorDivExpandOpsPatterns(patterns);
394514
// clang-format off

mlir/lib/IR/Types.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
3333
//===----------------------------------------------------------------------===//
3434

3535
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
36-
36+
bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
3737
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
3838
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
3939
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
22

33
// Test ceil divide with signed integer
44
// CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
248248
// CHECK-LABEL: @truncf_vector_f32
249249
// CHECK-NOT: arith.truncf
250250

251+
// -----
252+
func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
253+
%0 = arith.truncf %arg0 : f32 to f8E8M0FNU
254+
return %0 : f8E8M0FNU
255+
}
256+
// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
257+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
258+
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
259+
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
260+
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
261+
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
262+
// CHECK: return %[[RESULT]]
263+
264+
// -----
265+
266+
func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
267+
%0 = arith.truncf %arg0 : f16 to f8E8M0FNU
268+
return %0 : f8E8M0FNU
269+
}
270+
// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
271+
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
272+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
273+
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
274+
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
275+
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
276+
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
277+
// CHECK: return %[[RESULT]]
278+
279+
// -----
280+
281+
func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
282+
%0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
283+
return %0 : vector<4xf8E8M0FNU>
284+
}
285+
286+
// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
287+
// CHECK-NOT: arith.truncf
288+
289+
// -----
290+
291+
func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
292+
%0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
293+
return %0 : vector<4xf8E8M0FNU>
294+
}
295+
296+
// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
297+
// CHECK-NOT: arith.truncf
298+
299+
// -----
300+
301+
func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
302+
%0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
303+
return %0 : vector<4xf8E8M0FNU>
304+
}
305+
306+
// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
307+
// CHECK-NOT: arith.truncf
308+
309+
310+
// -----
311+
func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
312+
%0 = arith.extf %arg0 : f8E8M0FNU to f32
313+
return %0 : f32
314+
}
315+
316+
// CHECK-LABLE: @extf_f8E8M0FNU_to_f32
317+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
318+
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
319+
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
320+
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
321+
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
322+
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
323+
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
324+
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
325+
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
326+
// CHECK: return %[[RESULT]]
327+
328+
// -----
329+
330+
func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
331+
%0 = arith.extf %arg0 : f8E8M0FNU to f16
332+
return %0 : f16
333+
}
334+
335+
// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
336+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
337+
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
338+
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
339+
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
340+
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
341+
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
342+
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
343+
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
344+
// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
345+
// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
346+
// CHECK: return %[[F16_RESULT]]
347+
348+
// -----
349+
350+
func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
351+
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
352+
return %0 : vector<4xf32>
353+
}
354+
355+
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
356+
// CHECK-NOT: arith.extf
357+
358+
// -----
359+
360+
func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
361+
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
362+
return %0 : vector<4xf16>
363+
}
364+
365+
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
366+
// CHECK-NOT: arith.extf
367+
368+
// -----
369+
370+
func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
371+
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
372+
return %0 : vector<4xbf16>
373+
}
374+
375+
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
376+
// CHECK-NOT: arith.extf
377+
378+
251379
// -----
252380

253381
func.func @maxsi(%a: i32, %b: i32) -> i32 {

0 commit comments

Comments
 (0)