33#include " mlir/IR/BuiltinAttributes.h"
44#include " mlir/IR/Dominance.h"
55#include " mlir/IR/Verifier.h"
6- #include " mlir/Pass/Pass.h"
76#include " mlir/Pass/PassManager.h"
87#include " triton/Dialect/Triton/IR/Dialect.h"
98#include " triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -17,9 +16,23 @@ namespace ttg = mlir::triton::gpu;
1716// Utility functions
1817// ===----------------------------------------------------------------------===//
1918
20- // Return true if the given moduleOp contains a pure matmul problem; i.e.,
21- // single dot in the main loop.
22- static bool isPureMatmulProblem (triton::FuncOp funcOp) {
19+ static SmallVector<scf::ForOp> getLeafForOps (triton::FuncOp funcOp) {
20+ SmallVector<scf::ForOp> allOps;
21+ funcOp->walk ([&](scf::ForOp forOp) { allOps.push_back (forOp); });
22+
23+ SmallVector<scf::ForOp> leafOps;
24+ for (scf::ForOp forOp : allOps) {
25+ auto searchResult = forOp.getBody ()->walk (
26+ [](scf::ForOp) { return WalkResult::interrupt (); });
27+ if (!searchResult.wasInterrupted ())
28+ leafOps.push_back (forOp);
29+ }
30+ return leafOps;
31+ }
32+
33+ // Return true if the given funcOp is a pure matmul problem; i.e.,
34+ // a single main loop with a single dot.
35+ static bool isPureMatmulFunc (triton::FuncOp funcOp) {
2336 bool isMatmul = true ;
2437 bool foundLoop = false ;
2538 funcOp.walk ([&](scf::ForOp forOp) -> void {
@@ -31,6 +44,20 @@ static bool isPureMatmulProblem(triton::FuncOp funcOp) {
3144 return foundLoop && isMatmul;
3245}
3346
47+ // Return true if the given ForOp contains a pure matmul problem; i.e.,
48+ // single dot and at least 2 glboal loads in the main loop.
49+ static bool isPureMatmulLoop (scf::ForOp forOp) {
50+ int dotCounter = 0 ;
51+ int loadCounter = 0 ;
52+ forOp.walk ([&](Operation *op) {
53+ if (isa<triton::DotOp>(op))
54+ ++dotCounter;
55+ else if (isa<triton::LoadOp>(op))
56+ ++loadCounter;
57+ });
58+ return dotCounter == 1 && loadCounter >= 2 ;
59+ }
60+
3461// Search through block to find earliest insertion point for move op. This can
3562// be either an atomic op or last usage of source pointer. Search ends when move
3663// op is encountered.
@@ -214,14 +241,41 @@ static void moveUpTranspose(triton::FuncOp funcOp) {
214241}
215242
216243// Schedule global load and local store ops for better GEMM performance.
217- static void scheduleGlobalLoadLocalStore (triton::FuncOp funcOp ) {
244+ static void scheduleGlobalLoadLocalStore (Operation *parentOp ) {
218245 SmallVector<Operation *> moveOps;
219- // Move local_stores early if dependence distance greater than one iteration.
220- // Best perf on GEMM when these precede global loads.
221- funcOp.walk ([&](ttg::LocalStoreOp op) { moveOps.push_back (op); });
222- // Move global loads early to prefetch. This may increase register pressure
223- // but it enables issuing global loads early.
224- funcOp.walk ([&](triton::LoadOp op) { moveOps.push_back (op); });
246+
247+ // Search through the forOp initArgs to find global loads for a GEMM that
248+ // the pipeliner may have peeled into a loop prologue.
249+ if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
250+ SmallVector<Value> vals = forOp.getInitArgs ();
251+ while (!vals.empty ()) {
252+ SmallVector<Value> nextVals; // Next set of values to search via BFS.
253+ for (size_t i = 0 ; i < vals.size (); ++i) {
254+ Operation *defOp = vals[i].getDefiningOp ();
255+ if (isa_and_nonnull<triton::LoadOp>(defOp)) {
256+ moveOps.push_back (defOp);
257+ continue ;
258+ }
259+
260+ // Find uses of the op that are local_store
261+ for (Operation *op : vals[i].getUsers ()) {
262+ if (auto storeOp = dyn_cast<ttg::LocalStoreOp>(op)) {
263+ // Recurse on operands of the local_store (to find a global_load).
264+ nextVals.push_back (storeOp.getSrc ());
265+ }
266+ }
267+ }
268+ vals.swap (nextVals);
269+ }
270+ }
271+
272+ // Move local_store ops inside the loop early if dependence distance greater
273+ // than one iteration (i.e., num_stages > 2). For such case, better perf on
274+ // GEMM when local_store ops precede global loads.
275+ parentOp->walk ([&](ttg::LocalStoreOp op) { moveOps.push_back (op); });
276+ // Move global_load ops inside the loop early to prefetch. This may increase
277+ // register pressure but it enables issuing global loads early.
278+ parentOp->walk ([&](triton::LoadOp op) { moveOps.push_back (op); });
225279
226280 for (auto op : llvm::reverse (moveOps)) {
227281 // Gather use-def chain in block.
@@ -314,38 +368,36 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
314368// are experimenting how to better control instruction scheduling and enable
315369// such optimizations.
316370// ===-------------------------------------------------------------------===//
317- static void sinkSecondLoad (triton::FuncOp funcOp) {
318- funcOp.walk ([&](scf::ForOp forOp) -> void {
319- SetVector<triton::LoadOp> loadOps;
320- triton::DotOp dotOp;
321- for (Operation &op : forOp) {
322- if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
323- loadOps.insert (loadOp);
324- if (auto curOp = dyn_cast<triton::DotOp>(&op))
325- dotOp = curOp;
326- }
327- // Only apply the optimization when there are 2 load's in the loop
328- if (loadOps.size () != 2 )
329- return ;
330- // Only apply the optimization when tile size is large enough
331- // 1. nonKDim >= 128
332- // 2. kDim >= 64
333- auto ldAOp = loadOps[0 ];
334- auto tileAShape = cast<RankedTensorType>(ldAOp.getType ()).getShape ();
335- auto ldBOp = loadOps[1 ];
336- auto tileBShape = cast<RankedTensorType>(ldBOp.getType ()).getShape ();
337- if (!(tileAShape[0 ] >= 128 && tileAShape[1 ] >= 64 && tileBShape[1 ] >= 128 ))
338- return ;
339- // Only apply the optimization when the moving is legal
340- // 1. Make sure the 2nd loadOp is before the dot
341- // 2. Make sure the first user of the 2nd loadOp is after the dot.
342- bool isBeforeDotOp = ldBOp->isBeforeInBlock (dotOp);
343- auto firstUser = *ldBOp.getResult ().getUsers ().begin ();
344- bool firstUserAfterDotOp = dotOp->isBeforeInBlock (firstUser);
345- if (isBeforeDotOp && firstUserAfterDotOp)
346- // move ldBOp right before tt.dot
347- ldBOp->moveBefore (dotOp);
348- });
371+ static void sinkSecondLoad (scf::ForOp forOp) {
372+ SetVector<triton::LoadOp> loadOps;
373+ triton::DotOp dotOp;
374+ for (Operation &op : forOp) {
375+ if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
376+ loadOps.insert (loadOp);
377+ if (auto curOp = dyn_cast<triton::DotOp>(&op))
378+ dotOp = curOp;
379+ }
380+ // Only apply the optimization when there are 2 load's in the loop
381+ if (loadOps.size () != 2 )
382+ return ;
383+ // Only apply the optimization when tile size is large enough
384+ // 1. nonKDim >= 128
385+ // 2. kDim >= 64
386+ auto ldAOp = loadOps[0 ];
387+ auto tileAShape = cast<RankedTensorType>(ldAOp.getType ()).getShape ();
388+ auto ldBOp = loadOps[1 ];
389+ auto tileBShape = cast<RankedTensorType>(ldBOp.getType ()).getShape ();
390+ if (!(tileAShape[0 ] >= 128 && tileAShape[1 ] >= 64 && tileBShape[1 ] >= 128 ))
391+ return ;
392+ // Only apply the optimization when the moving is legal
393+ // 1. Make sure the 2nd loadOp is before the dot
394+ // 2. Make sure the first user of the 2nd loadOp is after the dot.
395+ bool isBeforeDotOp = ldBOp->isBeforeInBlock (dotOp);
396+ auto firstUser = *ldBOp.getResult ().getUsers ().begin ();
397+ bool firstUserAfterDotOp = dotOp->isBeforeInBlock (firstUser);
398+ if (isBeforeDotOp && firstUserAfterDotOp)
399+ // move ldBOp right before tt.dot
400+ ldBOp->moveBefore (dotOp);
349401}
350402
351403// ===----------------------------------------------------------------------===//
@@ -369,9 +421,17 @@ struct TritonAMDGPUReorderInstructionsPass
369421
370422 moveUpTranspose (funcOp);
371423
372- if (isPureMatmulProblem (funcOp)) {
424+ if (isPureMatmulFunc (funcOp)) {
373425 scheduleGlobalLoadLocalStore (funcOp);
374- sinkSecondLoad (funcOp);
426+ funcOp.walk ([&](scf::ForOp forOp) -> void { sinkSecondLoad (forOp); });
427+ } else {
428+ SmallVector<scf::ForOp> leafForOps = getLeafForOps (funcOp);
429+ for (auto forOp : leafForOps) {
430+ if (isPureMatmulLoop (forOp)) {
431+ scheduleGlobalLoadLocalStore (forOp);
432+ sinkSecondLoad (forOp);
433+ }
434+ }
375435 }
376436 }
377437 }
0 commit comments