Skip to content

Commit 7b5d13c

Browse files
authored
Add pattern for elementwise op lowering from XeTile to XeGPU (#724)
Squashed commit of the following: commit 67bea40e3ff9a04cd09165c7a8ba5addc8db3e7c Author: Nishant Patel <[email protected]> Date: Thu Apr 18 09:23:30 2024 -0700 PR feedback commit d00dfe00572bb02a7824ae8379ab29a7fa16702d Merge: 5fadde9b 255c28e Author: Nishant Patel <[email protected]> Date: Wed Apr 17 00:05:05 2024 -0700 Merge branch 'main' into nishant_ElementWiseOpLowering commit 5fadde9b498dac7a944af89c1e897c67674bd741 Author: Nishant Patel <[email protected]> Date: Tue Apr 16 22:47:08 2024 -0700 unify the pattern and add test case commit 1352deb5b0ce7dc2855767828a91618981be0b3f Author: Nishant Patel <[email protected]> Date: Thu Apr 11 11:36:35 2024 -0700 Add more tests commit 1f1c2ae40cf50f72007fd1a2267c8d8b0e27b06d Author: Nishant Patel <[email protected]> Date: Wed Apr 3 18:03:59 2024 -0700 Add pattern for elementwise op lowering from XeTile to XeGPU
1 parent b160c49 commit 7b5d13c

File tree

5 files changed

+442
-0
lines changed

5 files changed

+442
-0
lines changed

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
//===----------------------------------------------------------------------===//
1515

1616
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>
17+
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h>
1718

1819
#include "ArithOpConversion.h"
1920
#include "SCFOpConversion.h"
2021
#include "XeTileOpConversion.h"
2122
#include "imex/Utils/XeArch.h"
2223
#include "mlir/IR/BuiltinAttributes.h"
24+
#include <mlir/Dialect/Math/IR/Math.h>
2325

2426
namespace imex {
2527

@@ -625,6 +627,79 @@ struct SgUpdateTileOffsetOpPattern
625627
}
626628
};
627629

630+
bool isLegalElementWiseOp(mlir::Operation *op) {
631+
auto res = op->getResult(0);
632+
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
633+
if (resType.getRank() != 2)
634+
return false;
635+
return true;
636+
}
637+
638+
template <typename Op, int numOperands>
639+
Op createOp(XeGPUOneToNPatterRewriter &rewriter, mlir::Location loc,
640+
llvm::SmallVector<llvm::SmallVector<mlir::Value>> operands, int i) {
641+
static_assert(numOperands >= 1 && numOperands <= 3,
642+
"Unsupported number of operands");
643+
644+
if constexpr (numOperands == 1) {
645+
return rewriter.create<Op>(loc, operands[0][i]);
646+
} else if constexpr (numOperands == 2) {
647+
return rewriter.create<Op>(loc, operands[0][i], operands[1][i]);
648+
} else if constexpr (numOperands == 3) {
649+
return rewriter.create<Op>(loc, operands[0][i], operands[1][i],
650+
operands[2][i]);
651+
}
652+
}
653+
654+
template <typename Op, int numOperands>
655+
struct ElementWiseOpPattern : public SgXeTileToXeGPUConversion<Op> {
656+
657+
using SgXeTileToXeGPUConversion<Op>::SgXeTileToXeGPUConversion;
658+
using RangeT = llvm::ArrayRef<mlir::ValueRange>;
659+
using OpAdaptor = typename Op::template GenericAdaptor<RangeT>;
660+
661+
mlir::LogicalResult
662+
matchAndRewrite(Op op, OpAdaptor adaptor,
663+
XeGPUOneToNPatterRewriter &rewriter) const override {
664+
665+
auto res = op.getResult();
666+
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
667+
if (!resType || resType.getRank() != 4) {
668+
op.emitOpError() << "type is not 4D vector";
669+
return mlir::failure();
670+
}
671+
672+
auto shape = resType.getShape();
673+
auto newTy =
674+
mlir::VectorType::get({shape[2], shape[3]}, resType.getElementType());
675+
676+
// Get all the slices of Operands
677+
auto operands = adaptor.getOperands();
678+
679+
llvm::SmallVector<llvm::SmallVector<mlir::Value>> operand;
680+
if (numOperands == 1)
681+
operand.push_back(operands[0]);
682+
else if (numOperands == 2) {
683+
operand.push_back(operands[0]);
684+
operand.push_back(operands[1]);
685+
} else {
686+
operand.push_back(operands[0]);
687+
operand.push_back(operands[1]);
688+
operand.push_back(operands[2]);
689+
}
690+
691+
llvm::SmallVector<mlir::Value> newOps;
692+
for (int i = 0; i < shape[0] * shape[1]; i++) {
693+
auto newOp = createOp<Op, numOperands>(rewriter, op.getLoc(), operand, i);
694+
newOp->getResult(0).setType(newTy);
695+
newOps.push_back(newOp);
696+
}
697+
698+
rewriter.replaceOp(op, newOps);
699+
return mlir::success();
700+
}
701+
};
702+
628703
void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter,
629704
mlir::RewritePatternSet &patterns,
630705
TileUsageAnalysis &analysis) {
@@ -633,6 +708,26 @@ void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter,
633708
SgStoreTileOpPattern, SgTileMMAOpPattern,
634709
SgUpdateTileOffsetOpPattern>(patterns.getContext(), converter,
635710
analysis);
711+
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
712+
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
713+
ElementWiseOpPattern<mlir::math::SinOp, 1>,
714+
ElementWiseOpPattern<mlir::math::CosOp, 1>,
715+
ElementWiseOpPattern<mlir::math::SqrtOp, 1>,
716+
ElementWiseOpPattern<mlir::math::TanhOp, 1>,
717+
ElementWiseOpPattern<mlir::math::LogOp, 1>,
718+
ElementWiseOpPattern<mlir::math::RsqrtOp, 1>,
719+
ElementWiseOpPattern<mlir::math::ErfOp, 1>,
720+
ElementWiseOpPattern<mlir::arith::AddFOp, 2>,
721+
ElementWiseOpPattern<mlir::arith::RemFOp, 2>,
722+
ElementWiseOpPattern<mlir::arith::DivFOp, 2>,
723+
ElementWiseOpPattern<mlir::arith::MulFOp, 2>,
724+
ElementWiseOpPattern<mlir::arith::MaximumFOp, 2>,
725+
ElementWiseOpPattern<mlir::arith::MinimumFOp, 2>,
726+
ElementWiseOpPattern<mlir::arith::SubFOp, 2>,
727+
ElementWiseOpPattern<mlir::arith::XOrIOp, 2>,
728+
ElementWiseOpPattern<mlir::math::PowFOp, 2>,
729+
ElementWiseOpPattern<mlir::arith::SelectOp, 3>>(
730+
patterns.getContext(), converter, analysis);
636731
}
637732

638733
} // namespace imex

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "imex/Utils/XeArch.h"
2020
namespace imex {
2121

22+
bool isLegalElementWiseOp(mlir::Operation *op);
23+
2224
void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter,
2325
mlir::RewritePatternSet &patterns,
2426
TileUsageAnalysis &analysis);

lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <mlir/Dialect/Arith/IR/Arith.h>
1818
#include <mlir/Dialect/Func/IR/FuncOps.h>
1919
#include <mlir/Dialect/GPU/IR/GPUDialect.h>
20+
#include <mlir/Dialect/Math/IR/Math.h>
2021
#include <mlir/Dialect/Vector/IR/VectorOps.h>
2122
#include <mlir/IR/BuiltinOps.h>
2223
#include <mlir/Transforms/Passes.h>
@@ -77,6 +78,54 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
7778
return (uArchInterface &&
7879
mlir::succeeded(uArchInterface->isLegalPrefetch2dOp(op)));
7980
});
81+
82+
// Arith ops
83+
addDynamicallyLegalOp<mlir::arith::AddFOp>(
84+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
85+
addDynamicallyLegalOp<mlir::arith::DivFOp>(
86+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
87+
addDynamicallyLegalOp<mlir::arith::MulFOp>(
88+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
89+
addDynamicallyLegalOp<mlir::arith::CmpFOp>(
90+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
91+
addDynamicallyLegalOp<mlir::arith::CmpIOp>(
92+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
93+
addDynamicallyLegalOp<mlir::arith::XOrIOp>(
94+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
95+
addDynamicallyLegalOp<mlir::arith::SubFOp>(
96+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
97+
addDynamicallyLegalOp<mlir::arith::MaximumFOp>(
98+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
99+
addDynamicallyLegalOp<mlir::arith::RemFOp>(
100+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
101+
addDynamicallyLegalOp<mlir::arith::NegFOp>(
102+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
103+
addDynamicallyLegalOp<mlir::arith::MaximumFOp>(
104+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
105+
addDynamicallyLegalOp<mlir::arith::MinimumFOp>(
106+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
107+
addDynamicallyLegalOp<mlir::arith::SelectOp>(
108+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
109+
110+
// Math Ops
111+
addDynamicallyLegalOp<mlir::math::ExpOp>(
112+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
113+
addDynamicallyLegalOp<mlir::math::PowFOp>(
114+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
115+
addDynamicallyLegalOp<mlir::math::SqrtOp>(
116+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
117+
addDynamicallyLegalOp<mlir::math::LogOp>(
118+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
119+
addDynamicallyLegalOp<mlir::math::ErfOp>(
120+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
121+
addDynamicallyLegalOp<mlir::math::SinOp>(
122+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
123+
addDynamicallyLegalOp<mlir::math::CosOp>(
124+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
125+
addDynamicallyLegalOp<mlir::math::RsqrtOp>(
126+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
127+
addDynamicallyLegalOp<mlir::math::TanhOp>(
128+
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
80129
}
81130

82131
private:

0 commit comments

Comments
 (0)