1313#include " mlir/IR/ImplicitLocOpBuilder.h"
1414#include " mlir/IR/TypeUtilities.h"
1515#include " mlir/Transforms/DialectConversion.h"
16+ #include " llvm/ADT/SmallVectorExtras.h"
1617
1718namespace mlir {
1819namespace arith {
@@ -240,9 +241,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
240241 Type operandETy = getElementTypeOrSelf (operandTy);
241242 Type resultETy = getElementTypeOrSelf (resultTy);
242243
243- if (!operandETy.isBF16 () || !resultETy.isF32 ()) {
244+ if (!operandETy.isBF16 () || !resultETy.isF32 ())
244245 return rewriter.notifyMatchFailure (op, " not a ext of bf16 to f32." );
245- }
246246
247247 Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
248248 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
@@ -270,9 +270,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
270270 Type operandETy = getElementTypeOrSelf (operandTy);
271271 Type resultETy = getElementTypeOrSelf (resultTy);
272272
273- if (!operandETy.isF32 () || !resultETy.isBF16 ()) {
273+ if (!operandETy.isF32 () || !resultETy.isBF16 ())
274274 return rewriter.notifyMatchFailure (op, " not a trunc of f32 to bf16." );
275- }
276275
277276 if (op.getRoundingmodeAttr ()) {
278277 return rewriter.notifyMatchFailure (
@@ -336,6 +335,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
336335
337336struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
338337 using OpRewritePattern::OpRewritePattern;
338+ F4E2M1ExtFOpConverter (MLIRContext *context, PatternBenefit benefit = 1 )
339+ : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
339340 LogicalResult matchAndRewrite (arith::ExtFOp op,
340341 PatternRewriter &rewriter) const final {
341342 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
@@ -402,6 +403,71 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
402403 }
403404};
404405
406+ struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
407+ using OpRewritePattern::OpRewritePattern;
408+ ScalarF4E2M1ExtFOpConverter (MLIRContext *context, PatternBenefit benefit = 2 )
409+ : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
410+ LogicalResult matchAndRewrite (arith::ExtFOp op,
411+ PatternRewriter &rewriter) const final {
412+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
413+ Value operand = op.getOperand ();
414+ Type operandTy = operand.getType ();
415+ Type resultTy = op.getType ();
416+ Type operandETy = getElementTypeOrSelf (operandTy);
417+ Type resultETy = getElementTypeOrSelf (resultTy);
418+
419+ if (isa<ShapedType>(operandTy))
420+ return failure ();
421+
422+ if (!isa<Float4E2M1FNType>(operandETy))
423+ return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
424+
425+ SmallVector<int > values = {
426+ 0x00000000 , // 0.0
427+ 0x3f000000 , // 0.5
428+ 0x3f800000 , // 1.0
429+ 0x3fc00000 , // 1.5
430+ 0x40000000 , // 2.0
431+ 0x40400000 , // 3.0
432+ 0x40800000 , // 4.0
433+ 0x40c00000 // 6.0
434+ };
435+ // auto type = RankedTensorType::get({8}, b.getI32Type());
436+ VectorType type = VectorType::get ({8 }, b.getI32Type ());
437+ SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector (
438+ values, [&](int v) -> Attribute { return b.getI32IntegerAttr (v); });
439+ Value lookupTable = b.create <arith::ConstantOp>(
440+ DenseIntElementsAttr::get (type, lookupTableAttr));
441+
442+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
443+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
444+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
445+ Type i64Ty = cloneToShapedType (operandTy, b.getI64Type ());
446+
447+ Value i4Bits = b.create <arith::BitcastOp>(i4Ty, operand);
448+
449+ Value expManBitmask = createConst (op.getLoc (), i4Ty, 0x7 , rewriter);
450+ Value indexI4 = b.create <arith::AndIOp>(i4Bits, expManBitmask);
451+ Value indexI64 = b.create <arith::ExtUIOp>(i64Ty, indexI4);
452+ Value index = b.create <arith::IndexCastOp>(b.getIndexType (), indexI64);
453+
454+ Value c0x0000001c = createConst (op->getLoc (), i32Ty, 28 , rewriter);
455+ Value signBitmask = createConst (op.getLoc (), i4Ty, 0x8 , rewriter);
456+ Value signBitI4 = b.create <arith::AndIOp>(i4Bits, signBitmask);
457+ Value signBitI32 = b.create <arith::ExtUIOp>(i32Ty, signBitI4);
458+ signBitI32 = b.create <arith::ShLIOp>(signBitI32, c0x0000001c);
459+
460+ Value unsignedBits = b.create <vector::ExtractOp>(lookupTable, index);
461+ Value f32Bits = b.create <arith::OrIOp>(signBitI32, unsignedBits);
462+ Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
463+ if (!isa<Float32Type>(resultETy))
464+ result = b.create <arith::TruncFOp>(resultETy, operand);
465+
466+ rewriter.replaceOp (op, result);
467+ return success ();
468+ }
469+ };
470+
405471struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
406472 using OpRewritePattern::OpRewritePattern;
407473 LogicalResult matchAndRewrite (arith::ExtFOp op,
@@ -413,9 +479,8 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
413479 Type operandETy = getElementTypeOrSelf (operandTy);
414480 Type resultETy = getElementTypeOrSelf (resultTy);
415481
416- if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
482+ if (!llvm::isa<Float8E8M0FNUType>(operandETy))
417483 return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
418- }
419484
420485 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
421486 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
@@ -485,12 +550,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
485550 Type operandETy = getElementTypeOrSelf (operandTy);
486551 Type resultETy = getElementTypeOrSelf (resultTy);
487552
488- if (!isa<Float32Type>(operandETy)) {
553+ if (!isa<Float32Type>(operandETy))
489554 operand = b.create <arith::ExtFOp>(b.getF32Type (), operand);
490- }
491- if (!isa<Float4E2M1FNType>(resultETy)) {
555+ if (!isa<Float4E2M1FNType>(resultETy))
492556 return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
493- }
494557
495558 Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
496559 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
@@ -509,8 +572,8 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
509572 createFloatConst (op->getLoc (), f32Ty, APFloat (6 .0f ), rewriter);
510573 Value cLowerBound =
511574 createFloatConst (op->getLoc (), f32Ty, APFloat (-6 .0f ), rewriter);
512- Value operandClamped = b.create <arith::MinimumFOp >(cHigherBound, operand);
513- operandClamped = b.create <arith::MaximumFOp >(cLowerBound, operandClamped);
575+ Value operandClamped = b.create <arith::MinNumFOp >(cHigherBound, operand);
576+ operandClamped = b.create <arith::MaxNumFOp >(cLowerBound, operandClamped);
514577 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
515578
516579 // Step 1: Set sign bit.
@@ -594,14 +657,12 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
594657 Type operandETy = getElementTypeOrSelf (operandTy);
595658 Type resultTy = op.getType ();
596659 Type resultETy = getElementTypeOrSelf (resultTy);
597- if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
660+ if (!llvm::isa<Float8E8M0FNUType>(resultETy))
598661 return rewriter.notifyMatchFailure (op, " not a truncf to f8E8M0FNU" );
599- }
600662
601- if (op.getRoundingmodeAttr ()) {
663+ if (op.getRoundingmodeAttr ())
602664 return rewriter.notifyMatchFailure (
603665 op, " only applicable to default rounding mode." );
604- }
605666
606667 Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
607668 Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
@@ -711,6 +772,8 @@ struct ArithExpandOpsPass
711772 arith::populateArithExpandOpsPatterns (patterns);
712773
713774 target.addLegalDialect <arith::ArithDialect>();
775+ target.addLegalDialect <vector::VectorDialect>();
776+
714777 // clang-format off
715778 target.addIllegalOp <
716779 arith::CeilDivSIOp,
@@ -728,30 +791,24 @@ struct ArithExpandOpsPass
728791 arith::ScalingTruncFOp
729792 >();
730793
731- if (includeBf16) {
794+ if (includeBf16)
732795 arith::populateExpandBFloat16Patterns (patterns);
733- }
734- if (includeF8E8M0) {
796+ if (includeF8E8M0)
735797 arith::populateExpandF8E8M0Patterns (patterns);
736- }
737- if (includeF4E2M1) {
798+ if (includeF4E2M1)
738799 arith::populateExpandF4E2M1Patterns (patterns);
739- }
740800
741801 target.addDynamicallyLegalOp <arith::ExtFOp>(
742802 [=](arith::ExtFOp op) {
743803 Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
744804 Type outETy = getElementTypeOrSelf (op.getType ());
745805 bool legalTypes = true ;
746- if (includeBf16) {
806+ if (includeBf16)
747807 legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
748- }
749- if (includeF8E8M0) {
808+ if (includeF8E8M0)
750809 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
751- }
752- if (includeF4E2M1) {
810+ if (includeF4E2M1)
753811 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
754- }
755812 return legalTypes;
756813 });
757814
@@ -760,15 +817,12 @@ struct ArithExpandOpsPass
760817 Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
761818 Type outETy = getElementTypeOrSelf (op.getType ());
762819 bool legalTypes = true ;
763- if (includeBf16) {
820+ if (includeBf16)
764821 legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
765- }
766- if (includeF8E8M0) {
822+ if (includeF8E8M0)
767823 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
768- }
769- if (includeF4E2M1) {
824+ if (includeF4E2M1)
770825 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
771- }
772826 return legalTypes;
773827 });
774828
@@ -794,8 +848,8 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
794848}
795849
796850void mlir::arith::populateExpandF4E2M1Patterns (RewritePatternSet &patterns) {
797- patterns.add <F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
798- patterns.getContext ());
851+ patterns.add <F4E2M1ExtFOpConverter, ScalarF4E2M1ExtFOpConverter,
852+ F4E2M1TruncFOpConverter>( patterns.getContext ());
799853}
800854
801855void mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
0 commit comments