Commit 2b61b09

[NFC]: Refactor accelerate matmul (#5607)
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent: e930c21

File tree

1 file changed: +41 -33 lines

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 41 additions & 33 deletions
@@ -34,9 +34,10 @@ namespace mlir::triton::gpu::intel {
 namespace {
 
 // FIXME: Remove once IGC can split large 2D block loads.
-static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
+static void setAttrOnBOperand(Operation *op, StringRef attrName,
                               Attribute attr) {
-  Operation *defOp = dotOp.getB().getDefiningOp();
+  assert(isa<tt::DotOp>(op) && "Unexpected operation type");
+  Operation *defOp = cast<tt::DotOp>(op).getB().getDefiningOp();
   while (auto convOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(defOp))
     defOp = convOp.getSrc().getDefiningOp();
   if (auto transOp = dyn_cast_or_null<tt::TransOp>(defOp))
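
This hunk widens setAttrOnBOperand from taking tt::DotOp to taking a generic Operation*, moving the type check inside the helper as an assert plus cast so callers no longer cast at every call site. A minimal standalone sketch of that idiom, using plain C++ RTTI in place of LLVM's isa<>/cast<>; Operation and DotOp below are toy stand-ins, not the Triton types:

    #include <cassert>

    // Toy stand-ins for mlir::Operation and tt::DotOp.
    struct Operation {
      virtual ~Operation() = default;
    };
    struct DotOp : Operation {
      int bAttr = 0; // models an attribute on the dot's B operand
    };

    // Accept the base pointer, assert the concrete type, then cast;
    // the one supported type is checked inside the helper.
    static void setAttrOnBOperand(Operation *op, int attr) {
      assert(dynamic_cast<DotOp *>(op) && "Unexpected operation type");
      static_cast<DotOp *>(op)->bAttr = attr; // cast<tt::DotOp>(op) in MLIR
    }

    int main() {
      DotOp dot;
      setAttrOnBOperand(&dot, 1); // no cast needed at the call site
    }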
@@ -46,7 +47,8 @@ static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
 }
 
 SmallVector<unsigned>
-getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
+getWarpsPerTile(Operation *dotOp,
+                ttgi::DpasEncodingAttr::DPASCapability dpasCap,
                 const ArrayRef<int64_t> shape, unsigned numWarps) {
   auto filter = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
@@ -60,8 +62,8 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
     MLIRContext *ctx = forOp->getContext();
     StringRef attrName =
         ttgi::TritonIntelGPUDialect::getOneMatrixPerLoadAttrName();
-    setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
-    setAttrOnBOperand(cast<tt::DotOp>(op), attrName, UnitAttr::get(ctx));
+    setAttrOnBOperand(op, attrName, UnitAttr::get(ctx));
+    setAttrOnBOperand(op, attrName, UnitAttr::get(ctx));
   }
   SmallVector<unsigned> ret(shape.size(), 1);
   ret[0] = numWarps;
@@ -108,42 +110,44 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
   return ret;
 }
 
-class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
-  const ttg::intel::DPASAnalysis &dpasAnalysis;
+template <class OpTy,
+          typename = std::enable_if_t<llvm::is_one_of<OpTy, tt::DotOp>::value>>
+class BlockedToDPAS : public OpRewritePattern<OpTy> {
+  const ttgi::DPASAnalysis &dpasAnalysis;
   using TensorValue = TypedValue<RankedTensorType>;
 
 public:
-  BlockedToDPAS(MLIRContext *context,
-                const ttg::intel::DPASAnalysis &dpasAnalysis)
-      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
+  BlockedToDPAS(MLIRContext *context, const ttgi::DPASAnalysis &dpasAnalysis,
+                int benefit)
+      : OpRewritePattern<OpTy>(context, benefit), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(tt::DotOp dotOp,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType oldRetType = dotOp.getType();
+    RankedTensorType oldRetType = op.getType();
     if (!oldRetType.getEncoding() ||
         isa<ttgi::DpasEncodingAttr>(oldRetType.getEncoding()))
      return failure();
 
-    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) !=
-        ttg::intel::DPASAnalysis::Result::True)
+    auto funcOp = op->template getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) != ttgi::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
     unsigned numWarps = ttg::lookupNumWarps(funcOp);
 
-    TensorValue a = dotOp.getA();
-    TensorValue b = dotOp.getB();
+    TensorValue a = op.getA();
+    TensorValue b = op.getB();
     auto oldAType = cast<RankedTensorType>(a.getType());
     auto oldBType = cast<RankedTensorType>(b.getType());
 
-    ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
+    ModuleOp mod = funcOp->template getParentOfType<ModuleOp>();
+    ttgi::DpasEncodingAttr::DPASCapability dpasCap =
+        ttgi::DpasEncodingAttr::getDPASCapability(mod);
     Type elemType = oldAType.getElementType();
     unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);
     SmallVector<unsigned> warpsPerTile =
-        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+        getWarpsPerTile(op, dpasCap, retShape, numWarps);
     size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);

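BlockedToDPAS is now a template over OpTy, constrained with std::enable_if_t plus llvm::is_one_of so that only an explicit allow-list (currently just tt::DotOp) can instantiate it. A standalone sketch of that constraint, with a hand-rolled trait in place of llvm::is_one_of and toy op types; DotScaledOp is a hypothetical future candidate, not something this commit adds:

    #include <type_traits>

    struct DotOp {};       // stand-in for tt::DotOp
    struct DotScaledOp {}; // hypothetical future allow-list entry

    // Equivalent of llvm::is_one_of: true iff T matches one of Ts.
    template <typename T, typename... Ts>
    struct is_one_of : std::disjunction<std::is_same<T, Ts>...> {};

    template <class OpTy,
              typename = std::enable_if_t<is_one_of<OpTy, DotOp>::value>>
    class BlockedToDPAS {};

    int main() {
      BlockedToDPAS<DotOp> ok; // compiles: DotOp is on the allow-list
      // BlockedToDPAS<DotScaledOp> no; // rejected at compile time
      (void)ok;
    }

Growing the allow-list later only means extending the is_one_of argument pack; the class body stays untouched.
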
@@ -156,6 +160,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
             : dpasCap.systolicDepth * 2; // A is packed to i16 or i32.
     unsigned minM = mlir::ceil<unsigned>(threadsPerWarp, numElemsPerRowForA);
     repeatCount = std::max(repeatCount, minM);
+
     auto dpasEnc = ttgi::DpasEncodingAttr::get(
         oldRetType.getContext(), repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
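
Templating the pattern is also why earlier hunks spell op->template getParentOfType<...>(): op's type depends on the template parameter, so C++ requires the template keyword to disambiguate the member-template call. A toy illustration with stand-in types, not the MLIR API:

    // Stand-in for an op class exposing a member template, as MLIR ops do.
    struct Op {
      template <typename T> T *getParentOfType() { return nullptr; }
    };

    template <class OpTy> struct Pattern {
      void run(OpTy *op) {
        // Without `template`, the `<` would parse as less-than and fail.
        auto *parent = op->template getParentOfType<int>();
        (void)parent;
      }
    };

    int main() {
      Op op;
      Pattern<Op>{}.run(&op);
    }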
@@ -194,11 +199,11 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
                                threadsPerWarp);
     }
 
-    RankedTensorType newRetType =
+    auto newRetType =
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    TensorValue oldAcc = dotOp.getC();
+    TensorValue oldAcc = op.getC();
     auto newAcc = ttg::ConvertLayoutOp::create(rewriter, oldAcc.getLoc(),
                                                newRetType, oldAcc);
     // opA are packed to i16 for scalar type < 16 bits. opB are packed to i32.
@@ -215,15 +220,17 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
 
     a = ttg::ConvertLayoutOp::create(rewriter, a.getLoc(), newAType, a);
     b = ttg::ConvertLayoutOp::create(rewriter, b.getLoc(), newBType, b);
-    auto newDot = tt::DotOp::create(rewriter, dotOp.getLoc(), newRetType, a, b,
-                                    newAcc, dotOp.getInputPrecision(),
-                                    dotOp.getMaxNumImpreciseAcc());
 
-    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+    auto newDot =
+        tt::DotOp::create(rewriter, op.getLoc(), newRetType, a, b, newAcc,
+                          op.getInputPrecision(), op.getMaxNumImpreciseAcc());
+
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
                                                       newDot.getResult());
     return success();
   }
 };
+
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
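
Throughout these hunks, matchAndRewrite keeps MLIR's usual contract: guard clauses return failure() to tell the driver the pattern does not apply (no encoding, already DPAS, or DPAS unavailable), and only a completed rewrite returns success(). A toy mirror of that control flow in plain C++, not the MLIR API:

    enum class Encoding { None, Blocked, Dpas };

    // Stand-ins for mlir::LogicalResult and its helpers.
    struct LogicalResult { bool ok; };
    static LogicalResult success() { return {true}; }
    static LogicalResult failure() { return {false}; }

    static LogicalResult matchAndRewrite(Encoding &enc) {
      if (enc == Encoding::None || enc == Encoding::Dpas)
        return failure(); // not eligible, or already converted
      enc = Encoding::Dpas; // the "rewrite": retag with the DPAS layout
      return success();
    }

    int main() {
      Encoding e = Encoding::Blocked;
      return matchAndRewrite(e).ok ? 0 : 1;
    }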
@@ -258,13 +265,13 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttgi::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
       bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
-      // fp8 is not natively supported by the the DPAS instruction, promote it
-      // to fp16.
+      // fp8 is not always natively supported by the DPAS instruction,
+      // promote it to fp16 when necessary.
       if (!isNativeFP8)
         return;
       promoteType = builder.getF16Type();
@@ -428,23 +435,24 @@ static void transposeDots(ModuleOp m) {
 }
 
 class TritonIntelGPUAccelerateMatmulPass
-    : public triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+    : public ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
           TritonIntelGPUAccelerateMatmulPass> {
 public:
-  using triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+  using ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
       TritonIntelGPUAccelerateMatmulPass>::TritonIntelGPUAccelerateMatmulBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttgi::DPASAnalysis>();
 
     // Transpose dotOp operations that have a scale on the RHS.
     transposeDots(m);
 
     RewritePatternSet patterns(context);
     constexpr int benefitDefault = 1;
-    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
+    patterns.add<BlockedToDPAS<tt::DotOp>>(context, dpasAnalysis,
+                                           benefitDefault + 1);
     ttgi::populateDecomposeScaledBlockedPatterns(patterns, benefitDefault);
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
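
The registration change passes an explicit benefit: BlockedToDPAS<tt::DotOp> gets benefitDefault + 1 while the decompose-scaled patterns keep benefitDefault, so the greedy driver prefers the DPAS rewrite wherever both could apply. A toy model of benefit-ordered pattern selection in plain C++, not MLIR's PatternBenefit machinery:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct Pattern {
      const char *name;
      int benefit; // higher wins, mirroring mlir::PatternBenefit
    };

    int main() {
      constexpr int benefitDefault = 1;
      std::vector<Pattern> patterns = {
          {"DecomposeScaledBlocked", benefitDefault},
          {"BlockedToDPAS<tt::DotOp>", benefitDefault + 1},
      };
      // The driver attempts applicable patterns in descending benefit order.
      std::stable_sort(patterns.begin(), patterns.end(),
                       [](const Pattern &a, const Pattern &b) {
                         return a.benefit > b.benefit;
                       });
      std::printf("tried first: %s\n", patterns.front().name);
    }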
