Skip to content

Commit 607d6b1

Browse files
Adding lookup implementation for arith.extf + formatting fixes
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent b3c4977 commit 607d6b1

File tree

2 files changed

+93
-39
lines changed

2 files changed

+93
-39
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 90 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/ImplicitLocOpBuilder.h"
1414
#include "mlir/IR/TypeUtilities.h"
1515
#include "mlir/Transforms/DialectConversion.h"
16+
#include "llvm/ADT/SmallVectorExtras.h"
1617

1718
namespace mlir {
1819
namespace arith {
@@ -240,9 +241,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
240241
Type operandETy = getElementTypeOrSelf(operandTy);
241242
Type resultETy = getElementTypeOrSelf(resultTy);
242243

243-
if (!operandETy.isBF16() || !resultETy.isF32()) {
244+
if (!operandETy.isBF16() || !resultETy.isF32())
244245
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
245-
}
246246

247247
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
248248
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -270,9 +270,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
270270
Type operandETy = getElementTypeOrSelf(operandTy);
271271
Type resultETy = getElementTypeOrSelf(resultTy);
272272

273-
if (!operandETy.isF32() || !resultETy.isBF16()) {
273+
if (!operandETy.isF32() || !resultETy.isBF16())
274274
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
275-
}
276275

277276
if (op.getRoundingmodeAttr()) {
278277
return rewriter.notifyMatchFailure(
@@ -336,6 +335,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
336335

337336
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
338337
using OpRewritePattern::OpRewritePattern;
338+
F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
339+
: OpRewritePattern<arith::ExtFOp>(context, benefit) {}
339340
LogicalResult matchAndRewrite(arith::ExtFOp op,
340341
PatternRewriter &rewriter) const final {
341342
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -402,6 +403,71 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
402403
}
403404
};
404405

406+
struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
407+
using OpRewritePattern::OpRewritePattern;
408+
ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
409+
: OpRewritePattern<arith::ExtFOp>(context, benefit) {}
410+
LogicalResult matchAndRewrite(arith::ExtFOp op,
411+
PatternRewriter &rewriter) const final {
412+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
413+
Value operand = op.getOperand();
414+
Type operandTy = operand.getType();
415+
Type resultTy = op.getType();
416+
Type operandETy = getElementTypeOrSelf(operandTy);
417+
Type resultETy = getElementTypeOrSelf(resultTy);
418+
419+
if (isa<ShapedType>(operandTy))
420+
return failure();
421+
422+
if (!isa<Float4E2M1FNType>(operandETy))
423+
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
424+
425+
SmallVector<int> values = {
426+
0x00000000, // 0.0
427+
0x3f000000, // 0.5
428+
0x3f800000, // 1.0
429+
0x3fc00000, // 1.5
430+
0x40000000, // 2.0
431+
0x40400000, // 3.0
432+
0x40800000, // 4.0
433+
0x40c00000 // 6.0
434+
};
435+
// auto type = RankedTensorType::get({8}, b.getI32Type());
436+
VectorType type = VectorType::get({8}, b.getI32Type());
437+
SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
438+
values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
439+
Value lookupTable = b.create<arith::ConstantOp>(
440+
DenseIntElementsAttr::get(type, lookupTableAttr));
441+
442+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
443+
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
444+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
445+
Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
446+
447+
Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
448+
449+
Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
450+
Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
451+
Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
452+
Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
453+
454+
Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
455+
Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
456+
Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
457+
Value signBitI32 = b.create<arith::ExtUIOp>(i32Ty, signBitI4);
458+
signBitI32 = b.create<arith::ShLIOp>(signBitI32, c0x0000001c);
459+
460+
Value unsignedBits = b.create<vector::ExtractOp>(lookupTable, index);
461+
Value f32Bits = b.create<arith::OrIOp>(signBitI32, unsignedBits);
462+
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
463+
if (!isa<Float32Type>(resultETy))
464+
result = b.create<arith::TruncFOp>(resultETy, operand);
465+
466+
rewriter.replaceOp(op, result);
467+
return success();
468+
}
469+
};
470+
405471
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
406472
using OpRewritePattern::OpRewritePattern;
407473
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -413,9 +479,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
413479
Type operandETy = getElementTypeOrSelf(operandTy);
414480
Type resultETy = getElementTypeOrSelf(resultTy);
415481

416-
if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
482+
if (!llvm::isa<Float8E8M0FNUType>(operandETy))
417483
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
418-
}
419484

420485
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
421486
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -485,12 +550,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
485550
Type operandETy = getElementTypeOrSelf(operandTy);
486551
Type resultETy = getElementTypeOrSelf(resultTy);
487552

488-
if (!isa<Float32Type>(operandETy)) {
553+
if (!isa<Float32Type>(operandETy))
489554
operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
490-
}
491-
if (!isa<Float4E2M1FNType>(resultETy)) {
555+
if (!isa<Float4E2M1FNType>(resultETy))
492556
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
493-
}
494557

495558
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
496559
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
@@ -509,8 +572,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
509572
createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
510573
Value cLowerBound =
511574
createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
512-
Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
513-
operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operandClamped);
575+
Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
576+
operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
514577
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
515578

516579
// Step 1: Set sign bit.
@@ -594,14 +657,12 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
594657
Type operandETy = getElementTypeOrSelf(operandTy);
595658
Type resultTy = op.getType();
596659
Type resultETy = getElementTypeOrSelf(resultTy);
597-
if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
660+
if (!llvm::isa<Float8E8M0FNUType>(resultETy))
598661
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
599-
}
600662

601-
if (op.getRoundingmodeAttr()) {
663+
if (op.getRoundingmodeAttr())
602664
return rewriter.notifyMatchFailure(
603665
op, "only applicable to default rounding mode.");
604-
}
605666

606667
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
607668
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -711,6 +772,8 @@ struct ArithExpandOpsPass
711772
arith::populateArithExpandOpsPatterns(patterns);
712773

713774
target.addLegalDialect<arith::ArithDialect>();
775+
target.addLegalDialect<vector::VectorDialect>();
776+
714777
// clang-format off
715778
target.addIllegalOp<
716779
arith::CeilDivSIOp,
@@ -728,30 +791,24 @@ struct ArithExpandOpsPass
728791
arith::ScalingTruncFOp
729792
>();
730793

731-
if (includeBf16) {
794+
if (includeBf16)
732795
arith::populateExpandBFloat16Patterns(patterns);
733-
}
734-
if (includeF8E8M0) {
796+
if (includeF8E8M0)
735797
arith::populateExpandF8E8M0Patterns(patterns);
736-
}
737-
if (includeF4E2M1) {
798+
if (includeF4E2M1)
738799
arith::populateExpandF4E2M1Patterns(patterns);
739-
}
740800

741801
target.addDynamicallyLegalOp<arith::ExtFOp>(
742802
[=](arith::ExtFOp op) {
743803
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
744804
Type outETy = getElementTypeOrSelf(op.getType());
745805
bool legalTypes = true;
746-
if (includeBf16) {
806+
if (includeBf16)
747807
legalTypes &= !(inETy.isBF16() && outETy.isF32());
748-
}
749-
if (includeF8E8M0) {
808+
if (includeF8E8M0)
750809
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
751-
}
752-
if (includeF4E2M1) {
810+
if (includeF4E2M1)
753811
legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
754-
}
755812
return legalTypes;
756813
});
757814

@@ -760,15 +817,12 @@ struct ArithExpandOpsPass
760817
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
761818
Type outETy = getElementTypeOrSelf(op.getType());
762819
bool legalTypes = true;
763-
if (includeBf16) {
820+
if (includeBf16)
764821
legalTypes &= !(inETy.isF32() && outETy.isBF16());
765-
}
766-
if (includeF8E8M0) {
822+
if (includeF8E8M0)
767823
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
768-
}
769-
if (includeF4E2M1) {
824+
if (includeF4E2M1)
770825
legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
771-
}
772826
return legalTypes;
773827
});
774828

@@ -794,8 +848,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
794848
}
795849

796850
void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
797-
patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
798-
patterns.getContext());
851+
patterns.add<F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
852+
F4E2M1TruncFOpConverter>(patterns.getContext());
799853
}
800854

801855
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {

mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// Check various edge cases for truncf/extf ops involving f32 and f4e2m1 types.
22

3-
// RUN: mlir-opt %s --convert-vector-to-llvm \
4-
// RUN: --convert-func-to-llvm \
3+
// RUN: mlir-opt %s --convert-func-to-llvm \
54
// RUN: --arith-expand="include-f4e2m1=true" \
6-
// RUN: --convert-arith-to-llvm -reconcile-unrealized-casts | \
5+
// RUN: --convert-arith-to-llvm --convert-vector-to-llvm \
6+
// RUN: --reconcile-unrealized-casts | \
77
// RUN: mlir-runner -e entry --entry-point-result=void \
88
// RUN: --shared-libs=%mlir_c_runner_utils | \
99
// RUN: FileCheck %s --match-full-lines

0 commit comments

Comments
 (0)