Skip to content

Commit ee6abd9

Browse files
Mogball and csullivan
authored
[Pipeliner] Implement more sophisticated partitioning strategy for attention (#6660)
This PR implements a more sophisticated partitioning and scheduling strategy, aimed at generating a better schedule for forward attention. This PR enables composing the pipeliner with warp specialization by tweaking `scheduleLoops` to consider a `ttg.assigned_stage` attribute. Normally, the pipeline scheduler determines stages based on a DAG between latency ops, but this allows some higher level pass to directly inject an assigned stage into `scheduleKeyOps`. This allows, for example, warp specialization to place `load K` into stage 0 and `load V` into stage 2, and `MMA QK` into stage 0 and `MMA PV` into stage 2, in their respective partitions. This PR also reworks partition assignment to generate more than 1 vector partition. Previously, the partitioning strategy was generalized to take the body of the loop and "outline" async loads and async MMAs into their own partition, keeping the rest of the loop body in the default partition. Now, partition assignment also considers high latency synchronous operations, starting with `math.exp2` with a large number of elements. It does a first-order partition assignment, placing local users and dependencies of loads and MMAs into the same partition, and then clusters the remaining operations together. The clusters are then either: 1. Assigned entirely to the source partition 2. Assigned entirely to the sink partition 3. Assigned into a wholly new partition 4. Rematerialized into sink partitions along critical paths The idea is simply to simultaneously reduce the critical path between each latency operation. This strategy successfully automatically derives the same schedule CUTLASS uses for FMHA (correction partition, etc.). There is currently no cost model for deciding between rematerialization or sending intermediates over shared memory/placing them in their own partition, but we can probably reuse the one @apgoucher implemented for `remove-layout-conversions` in triton-lang/triton#6667. 
This also slightly tweaks the kernel code in `06-fused-attention.py` to reduce register pressure. This achieves close to 700 TFLOPS on DHEAD=64 and around 960-1080 TFLOPS on DHEAD=128. TODO: - [ ] Write lit tests for new scheduler - [ ] Write integration tests for FMHA --------- Co-authored-by: Chris Sullivan <[email protected]>
1 parent 3c8cdbf commit ee6abd9

File tree

15 files changed

+1241
-218
lines changed

15 files changed

+1241
-218
lines changed

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ class WarpSchedule {
6161

6262
// Create a new partition with a stage.
6363
Partition *addPartition(unsigned stage);
64-
// Give each partition a new index and order. The indices must be unique.
65-
void reorderPartitions(ArrayRef<unsigned> order);
6664
// Update the op to partition mapping.
6765
void updatePartitions();
6866

@@ -74,15 +72,25 @@ class WarpSchedule {
7472
Partition *getPartition(unsigned idx);
7573
// Get the partition at the index.
7674
const Partition *getPartition(unsigned idx) const;
75+
// Insert an operation into a partition.
76+
void insert(Partition *partition, Operation *op);
7777
// Return an iterator range over the partitions.
7878
auto getPartitions() { return llvm::make_pointee_range(partitions); }
7979
// Return an iterator range over the partitions.
8080
auto getPartitions() const { return llvm::make_pointee_range(partitions); }
81+
// Get the number of partitions.
82+
unsigned getNumPartitions() const { return partitions.size(); }
8183
// Get the root partition.
8284
Partition *getRootPartition() { return rootPartition.get(); }
8385
// Get the root partition.
8486
const Partition *getRootPartition() const { return rootPartition.get(); }
8587

88+
// Return true if an operation is assigned to a partition.
89+
bool isScheduled(Operation *op) const;
90+
// Schedule an operation to a partition if it is not already scheduled. Return
91+
// true if the operation was scheduled.
92+
bool trySchedule(Partition *partition, Operation *op);
93+
8694
// Deserialize a warp schedule from an `scf.for` op using the attributes
8795
// tagged on operations in its body.
8896
static FailureOr<WarpSchedule> deserialize(scf::ForOp loop);

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
1919
static const char *kLoopStageAttrName = "loop.stage";
2020
static const char *kLoopClusterAttrName = "loop.cluster";
2121
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
22+
static const char *kAssignedStageAttrName = "ttg.assigned_stage";
2223

2324
//===----------------------------------------------------------------------===//
2425
// Hoisting Utilities

lib/Analysis/Membar.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "triton/Analysis/Membar.h"
22
#include "triton/Analysis/Alias.h"
33
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
4+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
45

56
#include "mlir/Dialect/Func/IR/FuncOps.h"
67
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -216,6 +217,13 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
216217
}
217218
}
218219
}
220+
// If this op may be signalling other threads asynchronously, make sure
221+
// all shared memory transactions are complete beforehand.
222+
if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
223+
Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
224+
curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
225+
curBlockInfo.syncReadIntervals[allIntervals].insert(op);
226+
}
219227
scratchBufferId = allocation->getBufferId(op);
220228
}
221229

lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ struct TMEMLoadOpPattern : public OpConversionPattern<ttng::TMEMLoadOp> {
4444
RankedTensorType type = getTMEMTensorLayout(
4545
typeConverter, op.getType(), op.getSrc().getType(), lookupNumWarps(op));
4646
rewriter.modifyOpInPlace(op, [&] { op.getResult().setType(type); });
47+
Type resultType = getTypeConverter()->convertType(op.getType());
48+
rewriter.setInsertionPointAfter(op);
49+
auto cvt = rewriter.create<ConvertLayoutOp>(op.getLoc(), resultType,
50+
op.getResult());
51+
rewriter.replaceAllUsesExcept(op.getResult(), cvt, cvt);
4752
return success();
4853
}
4954
};

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,13 @@ void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
189189
for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
190190
setStageCluster(op, stage, *cluster);
191191
}
192+
193+
Builder b(forOp.getContext());
194+
int maxStages = numStages - 1;
195+
if (auto maxStageAttr = tryGetMaxStage(forOp))
196+
maxStages = std::max(maxStages, *maxStageAttr);
192197
forOp->setAttr(mlir::triton::kScheduledMaxStageAttrName,
193-
IntegerAttr::get(IntegerType::get(forOp.getContext(), 32),
194-
numStages - 1));
198+
b.getI32IntegerAttr(maxStages));
195199
}
196200

197201
// Create a CoarseSchedule based on forOp's <stage, cluster>.

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ bool hasLatenciesAssigned(scf::ForOp forOp,
5959
for (auto &op : forOp.getBody()->without_terminator()) {
6060
if (opLatency.count(&op))
6161
return true;
62+
if (op.getAttr(kAssignedStageAttrName))
63+
return true;
6264
}
6365
return false;
6466
}
@@ -70,12 +72,15 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
7072
auto terminator = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
7173
// Determine all operations that have a non-zero latency
7274
SmallVector<Operation *> latOps;
75+
SmallVector<Operation *> stagedOps;
7376
for (auto &op : forOp.getBody()->without_terminator()) {
7477
if (opLatency.count(&op))
7578
latOps.push_back(&op);
79+
if (op.getAttr(kAssignedStageAttrName))
80+
stagedOps.push_back(&op);
7681
}
7782
// If no latency ops, nothing to schedule
78-
if (latOps.empty())
83+
if (latOps.empty() && stagedOps.empty())
7984
return CoarseSchedule(0);
8085

8186
DominanceInfo domInfo(forOp);
@@ -123,6 +128,11 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
123128
opToStage[op] = maxDistance - dist;
124129
}
125130

131+
for (Operation *op : stagedOps) {
132+
auto stageAttr = op->getAttrOfType<IntegerAttr>(kAssignedStageAttrName);
133+
opToStage[op] = stageAttr.getInt();
134+
}
135+
126136
auto stages = llvm::make_second_range(opToStage);
127137
int maxStage = *llvm::max_element(stages);
128138
CoarseSchedule schedule(maxStage + 1);

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ static void removeAttributes(ModuleOp moduleOp) {
7373
op->removeAttr(mlir::triton::kLoopStageAttrName);
7474
op->removeAttr(mlir::triton::kLoopClusterAttrName);
7575
op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
76+
op->removeAttr(mlir::triton::kAssignedStageAttrName);
7677
});
7778
}
7879

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,30 +1402,35 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
14021402
// Save the operand to replace / delete later (avoid iterator invalidation).
14031403
// TODO: can we use an early_inc iterator?
14041404
for (OpOperand &use : oldUse->getUses()) {
1405+
// Propagate through `ttg.warp_specialize`.
1406+
if (auto wsOp = dyn_cast<ttg::WarpSpecializeOp>(use.getOwner())) {
1407+
for (Region *region : wsOp.getPartitionRegions())
1408+
region->getArgument(use.getOperandNumber()).setType(val.getType());
1409+
}
1410+
14051411
// Non-subview/trans ops will be replaced by `val`.
1406-
if (!isa<triton::gpu::MemDescTransOp, triton::gpu::MemDescSubviewOp>(
1407-
use.getOwner())) {
1412+
if (!isa<ttg::MemDescTransOp, ttg::MemDescSubviewOp>(use.getOwner())) {
14081413
operandsToReplace.push_back(&use);
14091414
continue;
14101415
}
1416+
14111417
Operation *user = use.getOwner();
14121418
// `subview(old_op)` is replaced by a new `subview(val)`.
14131419
OpBuilder::InsertionGuard g(builder);
14141420
builder.setInsertionPoint(user);
14151421
Value newVal;
1416-
if (auto subview = dyn_cast<triton::gpu::MemDescSubviewOp>(user)) {
1417-
triton::gpu::MemDescType oldType = subview.getType();
1418-
bool isMutable =
1419-
cast<triton::gpu::MemDescType>(val.getType()).getMutableMemory();
1420-
Type newDstType = triton::gpu::MemDescType::get(
1422+
if (auto subview = dyn_cast<ttg::MemDescSubviewOp>(user)) {
1423+
ttg::MemDescType oldType = subview.getType();
1424+
bool isMutable = cast<ttg::MemDescType>(val.getType()).getMutableMemory();
1425+
Type newDstType = ttg::MemDescType::get(
14211426
oldType.getShape(), oldType.getElementType(), oldType.getEncoding(),
14221427
oldType.getMemorySpace(), isMutable);
1423-
newVal = builder.create<triton::gpu::MemDescSubviewOp>(
1428+
newVal = builder.create<ttg::MemDescSubviewOp>(
14241429
subview.getLoc(), newDstType, val, subview.getOffsets());
14251430
newVal.getDefiningOp()->setAttrs(user->getAttrs());
1426-
} else if (auto trans = dyn_cast<triton::gpu::MemDescTransOp>(user)) {
1427-
newVal = builder.create<triton::gpu::MemDescTransOp>(trans.getLoc(), val,
1428-
trans.getOrder());
1431+
} else if (auto trans = dyn_cast<ttg::MemDescTransOp>(user)) {
1432+
newVal = builder.create<ttg::MemDescTransOp>(trans.getLoc(), val,
1433+
trans.getOrder());
14291434
newVal.getDefiningOp()->setAttrs(user->getAttrs());
14301435
}
14311436
assert(newVal);

0 commit comments

Comments
 (0)