@@ -4,7 +4,6 @@
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
-#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -14,9 +13,8 @@
 #define PVC_2D_LOAD_MAXIMUM_BYTES_OF_COLS 64
 
 using namespace mlir;
-using namespace mlir::triton;
-using namespace mlir::triton::gpu;
-using DPASAnalysis = intel::DPASAnalysis;
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
 
 namespace mlir::triton::gpu::intel {
 #define GEN_PASS_DEF_TRITONINTELGPUACCELERATEMATMUL
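The first two hunks replace the blanket `using namespace` directives with explicit dialect aliases. A minimal sketch of why this matters in this file, whose own definitions live inside `namespace mlir::triton::gpu::intel` (the helper function below is illustrative only, not part of the patch):

```cpp
#include "mlir/IR/Operation.h"
#include "triton/Dialect/Triton/IR/Dialect.h" // assumed to provide tt::DotOp

// Both the Triton and TritonGPU dialects are referenced side by side below;
// short aliases keep every op name explicit about which dialect it comes from.
namespace tt = mlir::triton;       // tt::DotOp, tt::FpToFpOp, ...
namespace ttg = mlir::triton::gpu; // ttg::ConvertLayoutOp, ttg::intel::...

namespace mlir::triton::gpu::intel {
// Illustrative helper: with the aliases, it is unambiguous that the matched
// op is the Triton-dialect dot, not a name from the enclosing namespace.
static bool isTritonDot(Operation *op) { return isa<tt::DotOp>(op); }
} // namespace mlir::triton::gpu::intel
```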
@@ -55,7 +53,7 @@ IntelDPASCapability getDPASCapability(unsigned minSGSize) {
   }
 }
 
-SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
+SmallVector<unsigned> getWarpsPerTile(tt::DotOp dotOp,
                                       struct IntelDPASCapability dpasCap,
                                       const ArrayRef<int64_t> shape,
                                       unsigned numWarps) {
@@ -66,7 +64,7 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
   SetVector<Operation *> slices = getSlice(dotOp, {filter});
   // TODO: revisit this in flash attention.
   for (Operation *op : slices)
-    if (isa<DotOp>(op) && (op != dotOp))
+    if (isa<tt::DotOp>(op) && (op != dotOp))
       return {numWarps, 1};
 
   size_t rank = shape.size();
@@ -108,41 +106,41 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
   return ret;
 }
 
-class BlockedToDPAS : public RewritePattern {
-  const DPASAnalysis &dpasAnalysis;
+class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
+  const ttg::intel::DPASAnalysis &dpasAnalysis;
 
 public:
-  BlockedToDPAS(MLIRContext *context, const DPASAnalysis &dpasAnalysis)
-      : RewritePattern(DotOp::getOperationName(), 2, context),
-        dpasAnalysis(dpasAnalysis) {}
+  BlockedToDPAS(MLIRContext *context,
+                const ttg::intel::DPASAnalysis &dpasAnalysis)
+      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    DotOp dotOp = cast<DotOp>(op);
-    RankedTensorType oldRetType =
-        cast<RankedTensorType>(dotOp.getResult().getType());
+    using TensorValue = TypedValue<RankedTensorType>;
+
+    RankedTensorType oldRetType = dotOp.getType();
     if (!oldRetType.getEncoding() ||
-        isa<intel::DpasEncodingAttr>(oldRetType.getEncoding()))
+        isa<ttg::intel::DpasEncodingAttr>(oldRetType.getEncoding()))
       return failure();
 
-    auto funcOp = op->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) != DPASAnalysis::Result::True)
+    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) !=
+        ttg::intel::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
-    size_t rank = retShape.size();
     ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
 
-    Value a = dotOp.getA();
-    Value b = dotOp.getB();
-    RankedTensorType oldAType = cast<RankedTensorType>(a.getType());
-    RankedTensorType oldBType = cast<RankedTensorType>(b.getType());
+    TensorValue a = dotOp.getA();
+    TensorValue b = dotOp.getB();
+    auto oldAType = cast<RankedTensorType>(a.getType());
+    auto oldBType = cast<RankedTensorType>(b.getType());
 
     unsigned minSGSize =
         mod->getAttrOfType<IntegerAttr>(
-            intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
+            ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
             .getInt();
     IntelDPASCapability dpasCap = getDPASCapability(minSGSize);
     unsigned dpasElemBitWidths =
@@ -156,10 +154,11 @@ class BlockedToDPAS : public RewritePattern {
     unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths;
     SmallVector<unsigned> warpsPerTile =
         getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+    size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
 
-    unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
-    auto dpasEnc = intel::DpasEncodingAttr::get(
+    unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+    auto dpasEnc = ttg::intel::DpasEncodingAttr::get(
         oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);
@@ -184,7 +183,7 @@ class BlockedToDPAS : public RewritePattern {
       repCluster[rank - 2] = repClusterDimM;
       repCluster[rank - 1] = repClusterDimN;
 
-      dpasEnc = intel::DpasEncodingAttr::get(
+      dpasEnc = ttg::intel::DpasEncodingAttr::get(
           oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
           dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
           threadsPerWarp);
@@ -194,28 +193,28 @@ class BlockedToDPAS : public RewritePattern {
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    Value oldAcc = dotOp.getC();
-    ConvertLayoutOp newAcc =
-        rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
+    TensorValue oldAcc = dotOp.getC();
+    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
+                                                        newRetType, oldAcc);
 
-    DotOperandEncodingAttr newAEncoding = DotOperandEncodingAttr::get(
+    auto newAEncoding = ttg::DotOperandEncodingAttr::get(
         oldAType.getContext(), 0, newRetType.getEncoding(), opsPerChan);
-    DotOperandEncodingAttr newBEncoding = DotOperandEncodingAttr::get(
+    auto newBEncoding = ttg::DotOperandEncodingAttr::get(
         oldBType.getContext(), 1, newRetType.getEncoding(), opsPerChan);
 
-    RankedTensorType newAType = RankedTensorType::get(
+    auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(), newAEncoding);
-    RankedTensorType newBType = RankedTensorType::get(
+    auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(), newBEncoding);
 
-    a = rewriter.create<ConvertLayoutOp>(a.getLoc(), newAType, a);
-    b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
-    DotOp newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b,
-                                          newAcc, dotOp.getInputPrecision(),
-                                          dotOp.getMaxNumImpreciseAcc());
+    a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
+    b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
+    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
+                                             newAcc, dotOp.getInputPrecision(),
+                                             dotOp.getMaxNumImpreciseAcc());
 
-    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, oldRetType,
-                                                 newDot.getResult());
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+                                                      newDot.getResult());
     return success();
   }
 };
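The substantive change in this class is the move from the generic `RewritePattern` base to the typed `OpRewritePattern<tt::DotOp>`: the framework now matches the root op kind itself, so `matchAndRewrite` receives an already-cast `tt::DotOp` and the old manual `cast<DotOp>(op)` disappears. A minimal sketch of the idiom, assuming the `tt` alias and Triton headers from above (the pattern name is illustrative):

```cpp
#include "mlir/IR/PatternMatch.h"

// Typed-pattern sketch: OpRewritePattern<OpT> performs the root-op match,
// so the callback gets the concrete op instead of a generic Operation *.
struct DotPatternSketch : public mlir::OpRewritePattern<tt::DotOp> {
  using mlir::OpRewritePattern<tt::DotOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(tt::DotOp dotOp,
                  mlir::PatternRewriter &rewriter) const override {
    // A real pattern would rewrite dotOp via `rewriter` and return success().
    return mlir::failure();
  }
};
```

One nuance worth noting: the old constructor registered the pattern with a benefit of 2 (`RewritePattern(DotOp::getOperationName(), 2, context)`), whereas the `OpRewritePattern<tt::DotOp>(context)` form uses the default benefit of 1.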
@@ -230,7 +229,7 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
 
   return llvm::TypeSwitch<Type, Value>(elemType)
       .Case<FloatType>([&](auto) {
-        return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
+        return builder.create<tt::FpToFpOp>(loc, tensorPromotedType, operand);
       })
       .Case<IntegerType>([&](auto) {
         unsigned tgtBitWidth = elemType.getIntOrFloatBitWidth(),
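For readers unfamiliar with the surrounding helper, `promoteOperand` dispatches on the element type with `llvm::TypeSwitch`; only the float case is touched here (the op is now spelled `tt::FpToFpOp`). A standalone sketch of the `TypeSwitch` idiom on a toy hierarchy with LLVM-style RTTI (MLIR's `Type` provides this machinery out of the box; names below are illustrative):

```cpp
#include "llvm/ADT/TypeSwitch.h"

// Toy hierarchy with the classof() hooks llvm::dyn_cast needs.
struct Shape {
  enum Kind { K_Circle, K_Square } kind;
  explicit Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(K_Circle) {}
  static bool classof(const Shape *s) { return s->kind == K_Circle; }
};
struct Square : Shape {
  Square() : Shape(K_Square) {}
  static bool classof(const Shape *s) { return s->kind == K_Square; }
};

// Each Case fires on the first matching subclass and yields the result,
// mirroring how promoteOperand picks a conversion op per element type.
int classify(Shape *s) {
  return llvm::TypeSwitch<Shape *, int>(s)
      .Case<Circle>([](Circle *) { return 1; })
      .Case<Square>([](Square *) { return 2; })
      .Default([](Shape *) { return 0; });
}
```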
@@ -248,12 +247,12 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
 // promote operands of dot op if the existing combination is not natively
 // supported.
 static void decomposeMixedModeDotOp(ModuleOp mod) {
-  mod.walk([](DotOp dotOp) -> void {
+  mod.walk([](tt::DotOp dotOp) -> void {
     auto D = dotOp.getD();
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
@@ -289,15 +288,13 @@ class TritonIntelGPUAccelerateMatmulPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    DPASAnalysis &dpasAnalysis = getAnalysis<DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
 
     RewritePatternSet patterns(context);
     patterns.add<BlockedToDPAS>(context, dpasAnalysis);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
 
-    // now that we pick the scalar type decompose dot that are not natively
-    // supported.
     decomposeMixedModeDotOp(m);
   }
 };
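A final ordering note: the comment deleted in the last hunk explained why `decomposeMixedModeDotOp` runs only after the greedy pattern application, namely that operand promotion must see the element types and layouts the rewrite settled on. A compact sketch of that post-pass walk (illustrative only, assuming the `tt` alias from above):

```cpp
// Post-pattern walk: visit every remaining tt::DotOp and decide per op
// whether its operands need promotion to a natively supported type.
static void walkDotsSketch(mlir::ModuleOp mod) {
  mod.walk([](tt::DotOp dotOp) {
    // Real code inspects dotOp's result encoding and operand element types.
    (void)dotOp;
  });
}
```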