14
14
// ===----------------------------------------------------------------------===//
15
15
16
16
#include < imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>
17
+ #include < imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h>
17
18
18
19
#include " ArithOpConversion.h"
19
20
#include " SCFOpConversion.h"
20
21
#include " XeTileOpConversion.h"
21
22
#include " imex/Utils/XeArch.h"
22
23
#include " mlir/IR/BuiltinAttributes.h"
24
+ #include < mlir/Dialect/Math/IR/Math.h>
23
25
24
26
namespace imex {
25
27
@@ -625,6 +627,79 @@ struct SgUpdateTileOffsetOpPattern
625
627
}
626
628
};
627
629
630
+ bool isLegalElementWiseOp (mlir::Operation *op) {
631
+ auto res = op->getResult (0 );
632
+ auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType ());
633
+ if (resType.getRank () != 2 )
634
+ return false ;
635
+ return true ;
636
+ }
637
+
638
+ template <typename Op, int numOperands>
639
+ Op createOp (XeGPUOneToNPatterRewriter &rewriter, mlir::Location loc,
640
+ llvm::SmallVector<llvm::SmallVector<mlir::Value>> operands, int i) {
641
+ static_assert (numOperands >= 1 && numOperands <= 3 ,
642
+ " Unsupported number of operands" );
643
+
644
+ if constexpr (numOperands == 1 ) {
645
+ return rewriter.create <Op>(loc, operands[0 ][i]);
646
+ } else if constexpr (numOperands == 2 ) {
647
+ return rewriter.create <Op>(loc, operands[0 ][i], operands[1 ][i]);
648
+ } else if constexpr (numOperands == 3 ) {
649
+ return rewriter.create <Op>(loc, operands[0 ][i], operands[1 ][i],
650
+ operands[2 ][i]);
651
+ }
652
+ }
653
+
654
+ template <typename Op, int numOperands>
655
+ struct ElementWiseOpPattern : public SgXeTileToXeGPUConversion <Op> {
656
+
657
+ using SgXeTileToXeGPUConversion<Op>::SgXeTileToXeGPUConversion;
658
+ using RangeT = llvm::ArrayRef<mlir::ValueRange>;
659
+ using OpAdaptor = typename Op::template GenericAdaptor<RangeT>;
660
+
661
+ mlir::LogicalResult
662
+ matchAndRewrite (Op op, OpAdaptor adaptor,
663
+ XeGPUOneToNPatterRewriter &rewriter) const override {
664
+
665
+ auto res = op.getResult ();
666
+ auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType ());
667
+ if (!resType || resType.getRank () != 4 ) {
668
+ op.emitOpError () << " type is not 4D vector" ;
669
+ return mlir::failure ();
670
+ }
671
+
672
+ auto shape = resType.getShape ();
673
+ auto newTy =
674
+ mlir::VectorType::get ({shape[2 ], shape[3 ]}, resType.getElementType ());
675
+
676
+ // Get all the slices of Operands
677
+ auto operands = adaptor.getOperands ();
678
+
679
+ llvm::SmallVector<llvm::SmallVector<mlir::Value>> operand;
680
+ if (numOperands == 1 )
681
+ operand.push_back (operands[0 ]);
682
+ else if (numOperands == 2 ) {
683
+ operand.push_back (operands[0 ]);
684
+ operand.push_back (operands[1 ]);
685
+ } else {
686
+ operand.push_back (operands[0 ]);
687
+ operand.push_back (operands[1 ]);
688
+ operand.push_back (operands[2 ]);
689
+ }
690
+
691
+ llvm::SmallVector<mlir::Value> newOps;
692
+ for (int i = 0 ; i < shape[0 ] * shape[1 ]; i++) {
693
+ auto newOp = createOp<Op, numOperands>(rewriter, op.getLoc (), operand, i);
694
+ newOp->getResult (0 ).setType (newTy);
695
+ newOps.push_back (newOp);
696
+ }
697
+
698
+ rewriter.replaceOp (op, newOps);
699
+ return mlir::success ();
700
+ }
701
+ };
702
+
628
703
void populateXeTileOpConversionPatterns (imex::XeGPUTypeConverter &converter,
629
704
mlir::RewritePatternSet &patterns,
630
705
TileUsageAnalysis &analysis) {
@@ -633,6 +708,26 @@ void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter,
633
708
SgStoreTileOpPattern, SgTileMMAOpPattern,
634
709
SgUpdateTileOffsetOpPattern>(patterns.getContext (), converter,
635
710
analysis);
711
+ patterns.insert <ElementWiseOpPattern<mlir::arith::NegFOp, 1 >,
712
+ ElementWiseOpPattern<mlir::math::ExpOp, 1 >,
713
+ ElementWiseOpPattern<mlir::math::SinOp, 1 >,
714
+ ElementWiseOpPattern<mlir::math::CosOp, 1 >,
715
+ ElementWiseOpPattern<mlir::math::SqrtOp, 1 >,
716
+ ElementWiseOpPattern<mlir::math::TanhOp, 1 >,
717
+ ElementWiseOpPattern<mlir::math::LogOp, 1 >,
718
+ ElementWiseOpPattern<mlir::math::RsqrtOp, 1 >,
719
+ ElementWiseOpPattern<mlir::math::ErfOp, 1 >,
720
+ ElementWiseOpPattern<mlir::arith::AddFOp, 2 >,
721
+ ElementWiseOpPattern<mlir::arith::RemFOp, 2 >,
722
+ ElementWiseOpPattern<mlir::arith::DivFOp, 2 >,
723
+ ElementWiseOpPattern<mlir::arith::MulFOp, 2 >,
724
+ ElementWiseOpPattern<mlir::arith::MaximumFOp, 2 >,
725
+ ElementWiseOpPattern<mlir::arith::MinimumFOp, 2 >,
726
+ ElementWiseOpPattern<mlir::arith::SubFOp, 2 >,
727
+ ElementWiseOpPattern<mlir::arith::XOrIOp, 2 >,
728
+ ElementWiseOpPattern<mlir::math::PowFOp, 2 >,
729
+ ElementWiseOpPattern<mlir::arith::SelectOp, 3 >>(
730
+ patterns.getContext (), converter, analysis);
636
731
}
637
732
638
733
} // namespace imex
0 commit comments