Commit 98dca47

[NFC]: Clean up AccelerateMatmul.cpp (#2679)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 52da140 commit 98dca47

1 file changed, 45 insertions(+), 48 deletions(-)

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -4,7 +4,6 @@
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
-#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -14,9 +13,8 @@
 #define PVC_2D_LOAD_MAXIMUM_BYTES_OF_COLS 64
 
 using namespace mlir;
-using namespace mlir::triton;
-using namespace mlir::triton::gpu;
-using DPASAnalysis = intel::DPASAnalysis;
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
 
 namespace mlir::triton::gpu::intel {
 #define GEN_PASS_DEF_TRITONINTELGPUACCELERATEMATMUL
@@ -55,7 +53,7 @@ IntelDPASCapability getDPASCapability(unsigned minSGSize) {
   }
 }
 
-SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
+SmallVector<unsigned> getWarpsPerTile(tt::DotOp dotOp,
                                       struct IntelDPASCapability dpasCap,
                                       const ArrayRef<int64_t> shape,
                                       unsigned numWarps) {
@@ -66,7 +64,7 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
   SetVector<Operation *> slices = getSlice(dotOp, {filter});
   // TODO: revisit this in flash attention.
   for (Operation *op : slices)
-    if (isa<DotOp>(op) && (op != dotOp))
+    if (isa<tt::DotOp>(op) && (op != dotOp))
       return {numWarps, 1};
 
   size_t rank = shape.size();
@@ -108,41 +106,41 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
   return ret;
 }
 
-class BlockedToDPAS : public RewritePattern {
-  const DPASAnalysis &dpasAnalysis;
+class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
+  const ttg::intel::DPASAnalysis &dpasAnalysis;
 
 public:
-  BlockedToDPAS(MLIRContext *context, const DPASAnalysis &dpasAnalysis)
-      : RewritePattern(DotOp::getOperationName(), 2, context),
-        dpasAnalysis(dpasAnalysis) {}
+  BlockedToDPAS(MLIRContext *context,
+                const ttg::intel::DPASAnalysis &dpasAnalysis)
+      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    DotOp dotOp = cast<DotOp>(op);
-    RankedTensorType oldRetType =
-        cast<RankedTensorType>(dotOp.getResult().getType());
+    using TensorValue = TypedValue<RankedTensorType>;
+
+    RankedTensorType oldRetType = dotOp.getType();
     if (!oldRetType.getEncoding() ||
-        isa<intel::DpasEncodingAttr>(oldRetType.getEncoding()))
+        isa<ttg::intel::DpasEncodingAttr>(oldRetType.getEncoding()))
       return failure();
 
-    auto funcOp = op->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) != DPASAnalysis::Result::True)
+    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) !=
+        ttg::intel::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
    ArrayRef<int64_t> retShape = oldRetType.getShape();
-    size_t rank = retShape.size();
     ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
 
-    Value a = dotOp.getA();
-    Value b = dotOp.getB();
-    RankedTensorType oldAType = cast<RankedTensorType>(a.getType());
-    RankedTensorType oldBType = cast<RankedTensorType>(b.getType());
+    TensorValue a = dotOp.getA();
+    TensorValue b = dotOp.getB();
+    auto oldAType = cast<RankedTensorType>(a.getType());
+    auto oldBType = cast<RankedTensorType>(b.getType());
 
     unsigned minSGSize =
         mod->getAttrOfType<IntegerAttr>(
-                intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
+                ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
             .getInt();
     IntelDPASCapability dpasCap = getDPASCapability(minSGSize);
     unsigned dpasElemBitWidths =
@@ -156,10 +154,11 @@ class BlockedToDPAS : public RewritePattern {
     unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths;
     SmallVector<unsigned> warpsPerTile =
         getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+    size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
 
-    unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
-    auto dpasEnc = intel::DpasEncodingAttr::get(
+    unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+    auto dpasEnc = ttg::intel::DpasEncodingAttr::get(
         oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);
@@ -184,7 +183,7 @@ class BlockedToDPAS : public RewritePattern {
     repCluster[rank - 2] = repClusterDimM;
     repCluster[rank - 1] = repClusterDimN;
 
-    dpasEnc = intel::DpasEncodingAttr::get(
+    dpasEnc = ttg::intel::DpasEncodingAttr::get(
         oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);
@@ -194,28 +193,28 @@ class BlockedToDPAS : public RewritePattern {
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    Value oldAcc = dotOp.getC();
-    ConvertLayoutOp newAcc =
-        rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
+    TensorValue oldAcc = dotOp.getC();
+    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
+                                                        newRetType, oldAcc);
 
-    DotOperandEncodingAttr newAEncoding = DotOperandEncodingAttr::get(
+    auto newAEncoding = ttg::DotOperandEncodingAttr::get(
         oldAType.getContext(), 0, newRetType.getEncoding(), opsPerChan);
-    DotOperandEncodingAttr newBEncoding = DotOperandEncodingAttr::get(
+    auto newBEncoding = ttg::DotOperandEncodingAttr::get(
         oldBType.getContext(), 1, newRetType.getEncoding(), opsPerChan);
 
-    RankedTensorType newAType = RankedTensorType::get(
+    auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(), newAEncoding);
-    RankedTensorType newBType = RankedTensorType::get(
+    auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(), newBEncoding);
 
-    a = rewriter.create<ConvertLayoutOp>(a.getLoc(), newAType, a);
-    b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
-    DotOp newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b,
-                                          newAcc, dotOp.getInputPrecision(),
-                                          dotOp.getMaxNumImpreciseAcc());
+    a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
+    b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
+    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
+                                             newAcc, dotOp.getInputPrecision(),
+                                             dotOp.getMaxNumImpreciseAcc());
 
-    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, oldRetType,
-                                                 newDot.getResult());
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+                                                      newDot.getResult());
     return success();
   }
 };
@@ -230,7 +229,7 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
 
   return llvm::TypeSwitch<Type, Value>(elemType)
      .Case<FloatType>([&](auto) {
-        return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
+        return builder.create<tt::FpToFpOp>(loc, tensorPromotedType, operand);
       })
      .Case<IntegerType>([&](auto) {
        unsigned tgtBitWidth = elemType.getIntOrFloatBitWidth(),
@@ -248,12 +247,12 @@
 // promote operands of dot op if the existing combination is not natively
 // supported.
 static void decomposeMixedModeDotOp(ModuleOp mod) {
-  mod.walk([](DotOp dotOp) -> void {
+  mod.walk([](tt::DotOp dotOp) -> void {
     auto D = dotOp.getD();
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
@@ -289,15 +288,13 @@ class TritonIntelGPUAccelerateMatmulPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    DPASAnalysis &dpasAnalysis = getAnalysis<DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
 
     RewritePatternSet patterns(context);
     patterns.add<BlockedToDPAS>(context, dpasAnalysis);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
 
-    // now that we pick the scalar type decompose dot that are not natively
-    // supported.
     decomposeMixedModeDotOp(m);
   }
 };
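
Most of this diff is mechanical: the `using namespace mlir::triton` and `using namespace mlir::triton::gpu` directives are replaced by the `tt` and `ttg` aliases, and every use site is qualified accordingly. The one structural change is in `BlockedToDPAS`, which moves from a generic `RewritePattern` (matched by operation name, with a hand-written `cast<DotOp>(op)`) to a typed `OpRewritePattern<tt::DotOp>`, letting the pattern driver do both the filtering and the cast. A minimal sketch of that migration idiom, written against upstream MLIR on `arith::AddFOp` rather than this repository's ops (the surrounding pattern names are hypothetical, not from the commit):

// Sketch of the RewritePattern -> OpRewritePattern migration applied to
// BlockedToDPAS above, shown on the upstream arith::AddFOp for
// self-containedness.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Before: a generic pattern. It is registered by operation name, carries an
// explicit benefit, and must cast the untyped Operation* itself.
struct FoldSelfAddGeneric : public RewritePattern {
  FoldSelfAddGeneric(MLIRContext *ctx)
      : RewritePattern(arith::AddFOp::getOperationName(), /*benefit=*/2, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // The cast cannot fail: the driver only calls us on the registered name.
    auto addOp = cast<arith::AddFOp>(op);
    if (addOp.getLhs() != addOp.getRhs())
      return failure();
    // ... rewrite logic would go here ...
    return failure();
  }
};

// After: a typed pattern. The driver filters for arith::AddFOp, so the cast
// disappears and typed accessors like getLhs() are available directly.
struct FoldSelfAddTyped : public OpRewritePattern<arith::AddFOp> {
  using OpRewritePattern<arith::AddFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::AddFOp addOp,
                                PatternRewriter &rewriter) const override {
    if (addOp.getLhs() != addOp.getRhs())
      return failure();
    // ... rewrite logic would go here ...
    return failure();
  }
};

One side effect visible in the hunks: the old constructor passed an explicit benefit of 2, while `OpRewritePattern<tt::DotOp>(context)` defaults to a benefit of 1. That difference is only observable when several patterns compete for the same `tt::DotOp`, which is presumably why the change is still labeled NFC.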
