Commit cf9ebd1

[BACKEND] Make sure we lower load to async_cp only when supported (#7176)
We ran into cases where we accidentally created unsupported async_cp ops.
1 parent 7ce287d commit cf9ebd1
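
In short (a sketch of the rule, not literal code from the commit — the actual helper is added in PipeliningUtility.cpp below): a tt.LoadOp may only be lowered to cp.async when the contiguous access per thread is wide enough, because cp.async supports cp-sizes of 4, 8, or 16 bytes only.

    // Sketch of the eligibility rule this commit enforces; elementBitWidth is a
    // placeholder for the pointee bit width queried in the real helper.
    unsigned vec = axisInfoAnalysis.getContiguity(loadOp.getPtr());
    unsigned width = vec * elementBitWidth; // contiguous bits per thread
    bool convertible = width >= 32;         // at least 4 bytes, cp.async's minimum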

File tree

5 files changed: +97 −61 lines


include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 1 deletion
@@ -20,7 +20,7 @@ static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
 class CoarseSchedule;
-
+class ModuleAxisInfoAnalysis;
 //===----------------------------------------------------------------------===//
 // Hoisting Utilities
 //===----------------------------------------------------------------------===//
@@ -87,6 +87,9 @@ std::pair<Operation *, int64_t> getDefiningOpAndDistance(scf::ForOp forOp,
 int getCopyVecBytes(RankedTensorType registerTy,
                     gpu::SharedEncodingTrait sharedEnc);
 
+bool canBeConvertedToAsyncLoad(
+    triton::LoadOp loadOp, triton::ModuleAxisInfoAnalysis &axisInfoAnalysis);
+
 // Serialize the latencies of the operations in the loops into the latency
 // attribute.
 void serializeLatencies(ModuleOp module, DenseMap<Operation *, int> &opLatency);
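
The declaration lives in the shared pipelining utility header so both the latency-assignment pass and the loop-lowering pass can use the same check. A minimal call-site sketch, assuming a moduleOp and a loadOp are in scope (the real call sites are in the two .cpp files below):

    triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
    if (triton::canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
      // Safe to lower this tt.LoadOp to an async copy (cp.async).
    }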

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 1 addition & 25 deletions
@@ -106,34 +106,10 @@ class AssignLoadLatencies {
     return !incompatible;
   }
 
-  bool isSmallLoad(tt::LoadOp loadOp,
-                   tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
-    assert(!isLoadFromTensorPtr(loadOp) &&
-           "Block ptr should have been lowered before this pass.");
-    auto ptr = loadOp.getPtr();
-    unsigned vec = axisInfoAnalysis.getContiguity(ptr);
-    if (auto mask = loadOp.getMask())
-      vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
-
-    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
-    if (!tensorTy)
-      return true;
-    auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
-    unsigned width = vec * ty.getIntOrFloatBitWidth();
-
-    // We do not pipeline all loads for the following reasons:
-    // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16.
-    // 2. It's likely that pipelining small loads won't offer much performance
-    //    improvement and may even hurt performance by increasing register
-    //    pressure.
-    LDBG("Load " << *loadOp << " has width " << width);
-    return width < 32;
-  }
-
   bool isPipeliningBeneficial(Operation *op, Operation *finalUser,
                               tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
     if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
-      if (isSmallLoad(loadOp, axisInfoAnalysis)) {
+      if (!canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
         LDBG("Load " << *loadOp << " is too small for pipelining");
         return false;
       }
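
The removed isSmallLoad is folded into the shared canBeConvertedToAsyncLoad (defined in PipeliningUtility.cpp below) with its polarity inverted, so the guard in isPipeliningBeneficial keeps the same behavior; paraphrased side by side:

    // Before: reject loads whose contiguous width is below 32 bits (or scalar ptrs).
    if (isSmallLoad(loadOp, axisInfoAnalysis)) return false;
    // After: the same cases, phrased as "cannot be converted to an async load".
    if (!canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) return false;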

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 13 additions & 5 deletions
@@ -1,6 +1,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
@@ -441,7 +442,8 @@ bool loadRequiresAdditionalBuffer(Operation *loadOp) {
   return false;
 }
 
-scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
+scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
+                      triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
   llvm::MapVector<Operation *, AsyncLoad> asyncLoads;
   llvm::MapVector<int, LoadGroupInfo> loadGroups;
   // Only visit the top level ops, we do not support pipelining conditional
@@ -457,9 +459,13 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
     SharedEncodingTrait sharedEncoding = getSharedEncoding(&op);
     // Do not create async loads for small loads (cp.async requires at least 4
     // bytes)
+    bool canUseAsyncCp =
+        isa<tt::LoadOp>(op) &&
+        canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
     int copyVecBytes = getCopyVecBytes(
         cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
-    if (copyVecBytes >= 4 || isTMALoad(&op)) {
+    canUseAsyncCp &= copyVecBytes >= 4;
+    if (canUseAsyncCp || isTMALoad(&op)) {
       if (loadRequiresAdditionalBuffer(&op)) {
         // Allocate additional buffer required by the wgmma pipelining.
         stageDiff += 1;
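
The net effect of this hunk, written as a single predicate (a sketch using the names above, not a literal line from the file): a non-TMA load gets an async copy only if it is a tt.LoadOp, the axis-info check passes, and the shared-memory copy vector is at least 4 bytes; TMA loads keep their own path via isTMALoad.

    bool lowerToAsyncCp =
        isa<tt::LoadOp>(op) &&
        canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis) &&
        copyVecBytes >= 4;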
@@ -1008,26 +1014,28 @@ scf::ForOp lowerMMAs(scf::ForOp forOp, CoarseSchedule &schedule) {
 // LOWER LOOP
 /////////////////////////////
 
-void lowerLoop(scf::ForOp forOp) {
+void lowerLoop(scf::ForOp forOp,
+               triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
   CoarseSchedule schedule;
   if (failed(schedule.deSerialize(forOp))) {
     return;
   }
   scf::ForOp newForOp = lowerMMAs(forOp, schedule);
-  newForOp = lowerLoads(newForOp, schedule);
+  newForOp = lowerLoads(newForOp, schedule, axisInfoAnalysis);
   newForOp = lowerTMADescriptors(newForOp, schedule);
   schedule.serialize(newForOp);
 }
 
 } // namespace
 
 void lowerLoops(ModuleOp moduleOp) {
+  triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
   SmallVector<scf::ForOp> loops;
   moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
   if (loops.empty())
     return;
   for (auto forOp : loops) {
-    lowerLoop(forOp);
+    lowerLoop(forOp, axisInfoAnalysis);
   }
 }
 

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 30 additions & 0 deletions
@@ -6,6 +6,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -14,8 +15,13 @@
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
 #include <queue>
 
+#define DEBUG_TYPE "triton-loop-pipeline"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
@@ -313,6 +319,30 @@ int mlir::triton::getCopyVecBytes(RankedTensorType registerTy,
   return vecElems * registerTy.getElementTypeBitWidth() / 8;
 }
 
+bool mlir::triton::canBeConvertedToAsyncLoad(
+    tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  assert(!isLoadFromTensorPtr(loadOp) &&
+         "Block ptr should have been lowered before this pass.");
+  auto ptr = loadOp.getPtr();
+  unsigned vec = axisInfoAnalysis.getContiguity(ptr);
+  if (auto mask = loadOp.getMask())
+    vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
+
+  auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
+  if (!tensorTy)
+    return false;
+  auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
+  unsigned width = vec * ty.getIntOrFloatBitWidth();
+
+  // We do not pipeline all loads for the following reasons:
+  // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16.
+  // 2. It's likely that pipelining small loads won't offer much performance
+  //    improvement and may even hurt performance by increasing register
+  //    pressure.
+  LDBG("Load " << *loadOp << " has width " << width);
+  return width >= 32;
+}
+
 void mlir::triton::serializeLatencies(ModuleOp module,
                                       DenseMap<Operation *, int> &opLatency) {
   auto helper = TritonDialect::getLoaded(module)->getLatencyAttrHelper();
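
For intuition, with hypothetical numbers (not taken from the commit): a load with contiguity 8 over f16 elements gives width = 8 * 16 = 128 bits, which qualifies, while contiguity 1 over i8 gives 8 bits, which stays a plain load since cp.async cannot copy fewer than 4 bytes.

    // Illustrative only; mirrors the width computation in the helper above.
    unsigned vec = 8;                // contiguity reported by axis info (assumed)
    unsigned elemBits = 16;          // e.g. f16
    unsigned width = vec * elemBits; // 128 bits = 16 bytes -> convertible
    bool convertible = width >= 32;  // with vec = 1 and i8, width = 8 -> false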
