@@ -34,9 +34,10 @@ namespace mlir::triton::gpu::intel {
 namespace {
 
 // FIXME: Remove once IGC can split large 2D block loads.
-static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
+static void setAttrOnBOperand(Operation *op, StringRef attrName,
                               Attribute attr) {
-  Operation *defOp = dotOp.getB().getDefiningOp();
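+  // Widened to take an Operation* so the templated pattern below can call it;
+  // only tt::DotOp is expected here.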
+  assert(isa<tt::DotOp>(op) && "Unexpected operation type");
+  Operation *defOp = cast<tt::DotOp>(op).getB().getDefiningOp();
   while (auto convOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(defOp))
     defOp = convOp.getSrc().getDefiningOp();
   if (auto transOp = dyn_cast_or_null<tt::TransOp>(defOp))
@@ -46,7 +47,8 @@ static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
 }
 
 SmallVector<unsigned>
-getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
+getWarpsPerTile(Operation *dotOp,
+                ttgi::DpasEncodingAttr::DPASCapability dpasCap,
                 const ArrayRef<int64_t> shape, unsigned numWarps) {
   auto filter = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
@@ -60,8 +62,8 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
     MLIRContext *ctx = forOp->getContext();
     StringRef attrName =
         ttgi::TritonIntelGPUDialect::getOneMatrixPerLoadAttrName();
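+    // FIXME workaround: request one matrix per 2D block load on both dots'
+    // B operands (see setAttrOnBOperand above).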
-    setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
-    setAttrOnBOperand(cast<tt::DotOp>(op), attrName, UnitAttr::get(ctx));
+    setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
+    setAttrOnBOperand(op, attrName, UnitAttr::get(ctx));
   }
   SmallVector<unsigned> ret(shape.size(), 1);
   ret[0] = numWarps;
@@ -108,42 +110,44 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
   return ret;
 }
 
-class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
-  const ttg::intel::DPASAnalysis &dpasAnalysis;
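+// Templated over the dot op type to allow future op kinds; the enable_if
+// currently admits only tt::DotOp.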
+template <class OpTy,
+          typename = std::enable_if_t<llvm::is_one_of<OpTy, tt::DotOp>::value>>
+class BlockedToDPAS : public OpRewritePattern<OpTy> {
+  const ttgi::DPASAnalysis &dpasAnalysis;
   using TensorValue = TypedValue<RankedTensorType>;
 
 public:
-  BlockedToDPAS(MLIRContext *context,
-                const ttg::intel::DPASAnalysis &dpasAnalysis)
-      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
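+  // The pattern benefit is now a constructor argument so the pass can give
+  // this pattern priority when registering it.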
+  BlockedToDPAS(MLIRContext *context, const ttgi::DPASAnalysis &dpasAnalysis,
+                int benefit)
+      : OpRewritePattern<OpTy>(context, benefit), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(tt::DotOp dotOp,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType oldRetType = dotOp.getType();
+    RankedTensorType oldRetType = op.getType();
     if (!oldRetType.getEncoding() ||
         isa<ttgi::DpasEncodingAttr>(oldRetType.getEncoding()))
       return failure();
 
-    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) !=
-        ttg::intel::DPASAnalysis::Result::True)
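+    // The 'template' keyword disambiguates the dependent member call now that
+    // the class is a template.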
+    auto funcOp = op->template getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) != ttgi::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
     unsigned numWarps = ttg::lookupNumWarps(funcOp);
 
-    TensorValue a = dotOp.getA();
-    TensorValue b = dotOp.getB();
+    TensorValue a = op.getA();
+    TensorValue b = op.getB();
     auto oldAType = cast<RankedTensorType>(a.getType());
     auto oldBType = cast<RankedTensorType>(b.getType());
 
-    ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
+    ModuleOp mod = funcOp->template getParentOfType<ModuleOp>();
+    ttgi::DpasEncodingAttr::DPASCapability dpasCap =
+        ttgi::DpasEncodingAttr::getDPASCapability(mod);
     Type elemType = oldAType.getElementType();
     unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);
     SmallVector<unsigned> warpsPerTile =
-        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+        getWarpsPerTile(op, dpasCap, retShape, numWarps);
     size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
 
@@ -156,6 +160,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
                : dpasCap.systolicDepth * 2; // A is packed to i16 or i32.
     unsigned minM = mlir::ceil<unsigned>(threadsPerWarp, numElemsPerRowForA);
     repeatCount = std::max(repeatCount, minM);
+
     auto dpasEnc = ttgi::DpasEncodingAttr::get(
         oldRetType.getContext(), repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
@@ -194,11 +199,11 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
                                      threadsPerWarp);
     }
 
-    RankedTensorType newRetType =
+    auto newRetType =
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    TensorValue oldAcc = dotOp.getC();
+    TensorValue oldAcc = op.getC();
     auto newAcc = ttg::ConvertLayoutOp::create(rewriter, oldAcc.getLoc(),
                                                newRetType, oldAcc);
     // opA are packed to i16 for scalar type < 16 bits. opB are packed to i32.
@@ -215,15 +220,17 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
 
     a = ttg::ConvertLayoutOp::create(rewriter, a.getLoc(), newAType, a);
     b = ttg::ConvertLayoutOp::create(rewriter, b.getLoc(), newBType, b);
-    auto newDot = tt::DotOp::create(rewriter, dotOp.getLoc(), newRetType, a, b,
-                                    newAcc, dotOp.getInputPrecision(),
-                                    dotOp.getMaxNumImpreciseAcc());
 
-    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+    auto newDot =
+        tt::DotOp::create(rewriter, op.getLoc(), newRetType, a, b, newAcc,
+                          op.getInputPrecision(), op.getMaxNumImpreciseAcc());
+
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
                                                       newDot.getResult());
     return success();
   }
 };
+
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -258,13 +265,13 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttgi::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
       bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
-      // fp8 is not natively supported by the the DPAS instruction, promote it
-      // to fp16.
+      // fp8 is not always natively supported by the DPAS instruction;
+      // promote it to fp16 when necessary.
       if (!isNativeFP8)
         return;
       promoteType = builder.getF16Type();
@@ -428,23 +435,24 @@ static void transposeDots(ModuleOp m) {
 }
 
 class TritonIntelGPUAccelerateMatmulPass
-    : public triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+    : public ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
          TritonIntelGPUAccelerateMatmulPass> {
 public:
-  using triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+  using ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
       TritonIntelGPUAccelerateMatmulPass>::TritonIntelGPUAccelerateMatmulBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttgi::DPASAnalysis>();
 
     // Transpose dotOp operations that have a scale on the RHS.
     transposeDots(m);
 
     RewritePatternSet patterns(context);
     constexpr int benefitDefault = 1;
-    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
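+    // The higher benefit lets BlockedToDPAS match before the decompose
+    // patterns registered below.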
+    patterns.add<BlockedToDPAS<tt::DotOp>>(context, dpasAnalysis,
+                                           benefitDefault + 1);
     ttgi::populateDecomposeScaledBlockedPatterns(patterns, benefitDefault);
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();