Skip to content

Commit d250c37

Browse files
manman-ren and meta-codesync[bot]
authored and committed
[autoWS] port support from ws-3.5 (#533)
Summary: including up to [autoWS] Reorder epilog ops to favor cooperative warp scheduling. (#461) also picking [autoWS] fix lit failures (#622) Fixed TritonGPU/loop-schedule.mlir (#631) [autoWS] Generalize passes to handle causal (#419) [autoWS] fix pytest failures (#639) Track lit failures in T243551722 T243551750 Guard OptimizePartitionWarps with tlx. With Meta's autoWS we run the pass outside of autoWS. Pull Request resolved: #533 Test Plan: pytest python/tutorials/fused-attention-ws.py -rs pytest python/tutorials/fused-attention-ws-device-tma.py -rs Reviewed By: minjang Differential Revision: D85833462 Pulled By: manman-ren fbshipit-source-id: 1149aa94e5470536536a522927414462dd950e87
1 parent bad736e commit d250c37

File tree

63 files changed

+9923
-812
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+9923
-812
lines changed

include/triton/Analysis/Allocation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ namespace mlir {
1313

1414
namespace triton {
1515
class AllocationAnalysis;
16+
class MemoryPlanner;
17+
class MemoryPlannerTmem;
1618

1719
/// Callback to allow backends to specify target-specific scratch sizes for
1820
/// some operations.
@@ -154,6 +156,15 @@ class Allocation {
154156
size_t alignment;
155157
size_t offset;
156158

159+
// For MemoryPlannerTmem
160+
bool isOwnerOfSpace;
161+
size_t rowOffset;
162+
size_t colOffset;
163+
size_t rowSize;
164+
size_t colSize;
165+
size_t reuseOffset; // when isOwnerOfSpace is true
166+
BufferT *reuseOwner; // when isOwnerOfSpace is false
167+
157168
bool operator==(const BufferT &other) const { return id == other.id; }
158169
bool operator<(const BufferT &other) const { return id < other.id; }
159170

@@ -208,6 +219,8 @@ class Allocation {
208219
size_t bufferIdCounter = 0;
209220

210221
friend class triton::AllocationAnalysis;
222+
friend class triton::MemoryPlanner;
223+
friend class triton::MemoryPlannerTmem;
211224
};
212225

213226
/// Static analysis that computes the allocation of shared memory buffers

include/triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ namespace mlir::triton::gpu {
1212
void attachAllocationSizeAndOffsetAttr(ModuleOp mod,
1313
ModuleAllocation &allocation);
1414

15+
/// Add shared memory access annotations to all operations that use shared
16+
/// memory. Only adds annotations when MLIR_ENABLE_DUMP=1 is set.
17+
void addSharedMemoryAnnotations(ModuleOp mod);
18+
1519
} // namespace mlir::triton::gpu
1620

1721
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
4545
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
4646
constexpr static char AttrTargetName[] = "ttg.target";
4747
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
48+
constexpr static char AttrMinRegAutoWSName[] = "ttg.min_reg_auto_ws";
49+
constexpr static char AttrMaxRegAutoWSName[] = "ttg.max_reg_auto_ws";
4850

4951
// Find the contextual number of warps on which this operation is executed.
5052
int lookupNumWarps(Operation *op);

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def TritonGPUScheduleLoops : Pass<"tritongpu-schedule-loops", "mlir::ModuleOp">
4747
The `tritongpu-schedule-loops` pass performs scheduling for loop pipelining
4848
for loops with latency ops.
4949
}];
50+
51+
let options = [
52+
Option<"numStages", "num-stages", "int32_t", /*default*/"3",
53+
"number of pipeline stages">
54+
];
5055
}
5156

5257
def TritonGPUHoistTMEMAlloc : Pass<"tritongpu-hoist-tmem-alloc", "mlir::ModuleOp"> {

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ void lowerLoops(ModuleOp moduleOp);
2020

2121
bool hasGpuBarriers(scf::ForOp forOp);
2222
bool isSafeToPipeline(scf::ForOp forOp);
23+
// Do any preprocessing on the loop information for a given module.
24+
void doLoopSchedulePreprocessing(ModuleOp moduleOp, Builder &builder);
25+
// TODO: Remove me and move to pass structure.
26+
void scheduleLoops(ModuleOp moduleOp, int defaultNumStages);
2327
llvm::MapVector<Operation *, std::pair<int, Operation *>>
2428
loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
2529
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
@@ -155,7 +159,7 @@ class CoarseSchedule {
155159
auto begin() const { return opToStageAndCluster.begin(); }
156160

157161
// Set <stage, cluster> based on CoarseSchedule.
158-
void serialize(scf::ForOp &forOp) const;
162+
void serialize(scf::ForOp &forOp, bool keepExistingMaxStage = true) const;
159163
// Create a CoarseSchedule based on forOp's <stage, cluster>.
160164
LogicalResult deSerialize(scf::ForOp &forOp);
161165

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,11 @@ LogicalResult getConvertBackwardSlice(
183183
nullptr);
184184

185185
// Populate pattern to remove dead cycles in ForOp.
186-
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
186+
// opsCanBeTriviallyDead specifies the operations of which the side effect can
187+
// be ignored.
188+
void populateForOpDeadArgumentElimination(
189+
RewritePatternSet &patterns,
190+
const DenseSet<Operation *> &opsCanBeTriviallyDead = {});
187191

188192
// Convert an \param index to a multi-dim coordinate given \param shape and
189193
// \param order.
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
22
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
33

4+
#include "triton/Analysis/Allocation.h"
45
#include "triton/Dialect/Triton/IR/Dialect.h"
56
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
67

78
namespace mlir::triton::nvidia_gpu {
89

910
LogicalResult verifyBarrierType(Operation *op,
1011
mlir::triton::gpu::MemDescType barrierType);
12+
int allocateTMemWithInterval(
13+
DenseMap<Operation *, Interval<int>> &allocToIntervals,
14+
SmallVector<Operation *> &allocOrder);
1115

12-
}
16+
} // namespace mlir::triton::nvidia_gpu
1317

1418
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3939
"TRITON_LLVM_DEBUG_ONLY",
4040
"TRITON_ENABLE_ASAN",
4141
"TRITON_OVERRIDE_ARCH",
42+
"TRITON_USE_OAI_WS",
4243
"USE_IR_LOC",
4344
"NVPTX_ENABLE_DUMP",
4445
"ALLOW_LHS_TMEM_LAYOUT_CONVERSION",

lib/Analysis/Allocation.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,16 @@ class AllocationAnalysis {
302302
continue;
303303
}
304304

305-
// Any scratch memory's live range is the current operation's live
306-
// range.
307-
bufferRange.insert(
308-
{buffer, Interval(operationId.at(op), operationId.at(op) + 1)});
305+
if (op && isa<mlir::triton::gpu::WarpSpecializeOp>(op)) {
306+
bufferRange.insert(
307+
{buffer, Interval((size_t)0, (size_t)operationId.size())});
308+
} else {
309+
310+
// Any scratch memory's live range is the current operation's live
311+
// range.
312+
bufferRange.insert(
313+
{buffer, Interval(operationId.at(op), operationId.at(op) + 1)});
314+
}
309315
LLVM_DEBUG({
310316
llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
311317
op->dump();
@@ -341,15 +347,23 @@ class AllocationAnalysis {
341347
// Analyze liveness of explicit buffers
342348
Liveness liveness(operation);
343349
auto getValueLivenessRange = [&](Value value) {
350+
Operation *defOp = value.getDefiningOp();
344351
auto liveOperations = liveness.resolveLiveness(value);
345352
auto minId = std::numeric_limits<size_t>::max();
346353
auto maxId = std::numeric_limits<size_t>::min();
347354
llvm::for_each(liveOperations, [&](Operation *liveOp) {
348-
if (operationId[liveOp] < minId) {
349-
minId = operationId[liveOp];
350-
}
351-
if ((operationId[liveOp] + 1) > maxId) {
352-
maxId = operationId[liveOp] + 1;
355+
if (liveOp && isa<mlir::triton::gpu::WarpSpecializeOp>(liveOp)) {
356+
minId = 0;
357+
if ((operationId[liveOp] + 1) > maxId) {
358+
maxId = operationId[liveOp] + 1;
359+
}
360+
} else {
361+
if (operationId[liveOp] < minId) {
362+
minId = operationId[liveOp];
363+
}
364+
if ((operationId[liveOp] + 1) > maxId) {
365+
maxId = operationId[liveOp] + 1;
366+
}
353367
}
354368
});
355369
return Interval(minId, maxId);

lib/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,91 @@
11
#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h"
2+
#include "triton/Analysis/Allocation.h"
3+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
4+
#include "triton/Tools/Sys/GetEnv.hpp"
5+
#include <cstdlib>
6+
#include <string>
27

38
namespace mlir::triton::gpu {
49

10+
// Helper function to compute allocation size from MemDescType
11+
inline size_t computeAllocationSize(MemDescType memdescTy) {
12+
auto elemTy = memdescTy.getElementType();
13+
auto shape = memdescTy.getShape();
14+
size_t elemSize = elemTy.getIntOrFloatBitWidth() / 8;
15+
size_t totalElements = 1;
16+
for (auto dim : shape) {
17+
totalElements *= dim;
18+
}
19+
return totalElements * elemSize;
20+
}
21+
22+
// Helper function to add allocation information as IR annotations
23+
void addAllocationAnnotations(Operation *op) {
24+
MLIRContext *ctx = op->getContext();
25+
IntegerAttr offsetAttr;
26+
MemDescType memdescTy;
27+
28+
// Try to get allocation.offset from the operation itself
29+
if (auto attr = op->getAttrOfType<IntegerAttr>("allocation.offset")) {
30+
offsetAttr = attr;
31+
// Find MemDescType from result or operands
32+
for (auto result : op->getResults()) {
33+
if (auto ty = dyn_cast<MemDescType>(result.getType())) {
34+
memdescTy = ty;
35+
break;
36+
}
37+
}
38+
if (!memdescTy) {
39+
for (auto operand : op->getOperands()) {
40+
if (auto ty = dyn_cast<MemDescType>(operand.getType())) {
41+
memdescTy = ty;
42+
break;
43+
}
44+
}
45+
}
46+
} else {
47+
// Try to find it through operands
48+
for (auto operand : op->getOperands()) {
49+
if (auto definingOp = operand.getDefiningOp()) {
50+
if (auto allocOp = dyn_cast<triton::gpu::LocalAllocOp>(definingOp)) {
51+
if (auto attr =
52+
allocOp->getAttrOfType<IntegerAttr>("allocation.offset")) {
53+
offsetAttr = attr;
54+
memdescTy = cast<MemDescType>(allocOp.getType());
55+
break;
56+
}
57+
}
58+
}
59+
}
60+
}
61+
62+
if (!offsetAttr || !memdescTy) {
63+
return;
64+
}
65+
66+
auto offset = offsetAttr.getInt();
67+
size_t totalSize = computeAllocationSize(memdescTy);
68+
op->setAttr("shared_memory.offset",
69+
IntegerAttr::get(IntegerType::get(ctx, 64), offset));
70+
op->setAttr("shared_memory.size_bytes",
71+
IntegerAttr::get(IntegerType::get(ctx, 64), totalSize));
72+
}
73+
74+
// Function to add shared memory access annotations to all operations that use
75+
// shared memory
76+
void addSharedMemoryAnnotations(ModuleOp mod) {
77+
if (!triton::tools::getBoolEnv("MLIR_ENABLE_DUMP")) {
78+
return;
79+
}
80+
81+
mod.walk([&](Operation *op) {
82+
if (isa<triton::gpu::LocalStoreOp, triton::gpu::LocalLoadOp,
83+
triton::gpu::MemDescSubsliceOp, triton::gpu::MemDescIndexOp>(op)) {
84+
addAllocationAnnotations(op);
85+
}
86+
});
87+
}
88+
589
void attachAllocationSizeAndOffsetAttr(ModuleOp mod,
690
ModuleAllocation &allocation) {
791
MLIRContext *ctx = mod.getContext();

0 commit comments

Comments
 (0)