Skip to content

Commit 1f12637

Browse files
authored
[Hopper][WS] Update pipeline to get GEMM/FA working (#7136)
Builds and runs for GEMM with matmul_kernel_persistent_tma_ws_cooperative (i.e., two consumer groups doing computation + epilogue, and one producer group doing loads). Performance will be tuned in a follow-up diff. We set num_stages to 0 after WarpSpec to disable software pipelining (SWP). SWP is also updated to bail out if there is no loop with num_stages >= 1.
1 parent 53e3e6a commit 1f12637

File tree

13 files changed

+540
-76
lines changed

13 files changed

+540
-76
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ namespace gpu {
3333
#define GEN_PASS_DEF_TRITONGPUPIPELINE
3434
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
3535

36-
static void pipelineWgmma(ModuleOp moduleOp) {
36+
static void pipelineWgmma(ModuleOp moduleOp, unsigned numStages) {
3737
SmallVector<scf::ForOp> loops;
3838
moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
3939

4040
for (scf::ForOp forOp : loops) {
41-
mlir::triton::asyncLaunchDots(forOp);
41+
if (getNumStagesOrDefault(forOp, numStages) >= 1)
42+
mlir::triton::asyncLaunchDots(forOp);
4243
}
4344
}
4445

@@ -223,7 +224,6 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
223224

224225
void runOnOperation() override {
225226
ModuleOp moduleOp = getOperation();
226-
227227
// Transform the loop by introducing async operations to prepare it for
228228
// pipeline expansion.
229229
lowerLoops(moduleOp);
@@ -244,7 +244,7 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
244244
// Cleanup the IR from the pipeline attributes.
245245
removeAttributes(moduleOp);
246246

247-
pipelineWgmma(moduleOp);
247+
pipelineWgmma(moduleOp, numStages);
248248

249249
// schedule the waits
250250
mlir::triton::updateWaits(getOperation());

python/tutorials/09-persistent-matmul.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ def supports_tma():
4747
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
4848

4949

50+
def is_hopper():
51+
return torch.cuda.get_device_capability()[0] == 9
52+
53+
5054
def supports_ws():
51-
return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
55+
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
5256

5357

5458
def _matmul_launch_metadata(grid, kernel, args):
@@ -465,21 +469,31 @@ def grid(META):
465469
return c
466470

467471

468-
@triton.autotune(
469-
configs=matmul_tma_persistent_get_configs(),
470-
key=["M", "N", "K", "WARP_SPECIALIZE"],
471-
)
472+
def prune_invalid_configs(configs, named_args, **kwargs):
473+
FLATTEN = kwargs["FLATTEN"]
474+
# Filter out configs where EPILOGUE_SUBTILE is true and HOPPER is true
475+
return [conf for conf in configs if not (conf.kwargs.get("EPILOGUE_SUBTILE", True) and FLATTEN is False)]
476+
477+
478+
@triton.autotune(configs=matmul_tma_persistent_get_configs(), key=["M", "N", "K", "WARP_SPECIALIZE", "FLATTEN"],
479+
prune_configs_by={'early_config_prune': prune_invalid_configs})
472480
@triton.jit(launch_metadata=_matmul_launch_metadata)
473-
def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
474-
M, N, K, #
475-
BLOCK_SIZE_M: tl.constexpr, #
476-
BLOCK_SIZE_N: tl.constexpr, #
477-
BLOCK_SIZE_K: tl.constexpr, #
478-
GROUP_SIZE_M: tl.constexpr, #
479-
EPILOGUE_SUBTILE: tl.constexpr, #
480-
NUM_SMS: tl.constexpr, #
481-
WARP_SPECIALIZE: tl.constexpr, #
482-
):
481+
def matmul_kernel_descriptor_persistent(
482+
a_ptr,
483+
b_ptr,
484+
c_ptr, #
485+
M,
486+
N,
487+
K, #
488+
BLOCK_SIZE_M: tl.constexpr, #
489+
BLOCK_SIZE_N: tl.constexpr, #
490+
BLOCK_SIZE_K: tl.constexpr, #
491+
GROUP_SIZE_M: tl.constexpr, #
492+
EPILOGUE_SUBTILE: tl.constexpr, #
493+
NUM_SMS: tl.constexpr, #
494+
WARP_SPECIALIZE: tl.constexpr, #
495+
FLATTEN: tl.constexpr,
496+
):
483497
# Matmul using TMA and device-side descriptor creation
484498
dtype = c_ptr.dtype.element_ty
485499
start_pid = tl.program_id(axis=0)
@@ -512,7 +526,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
512526
tile_id_c = start_pid - NUM_SMS
513527
num_pid_in_group = GROUP_SIZE_M * num_pid_n
514528

515-
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
529+
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE):
516530
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
517531
offs_am = pid_m * BLOCK_SIZE_M
518532
offs_bn = pid_n * BLOCK_SIZE_N
@@ -560,12 +574,19 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
560574

561575
triton.set_allocator(alloc_fn)
562576

577+
# Hopper warpspec doesn't work with flatten
578+
flatten = False if (warp_specialize and is_hopper()) else True
563579
grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
564580
matmul_kernel_descriptor_persistent[grid](
565-
a, b, c, #
566-
M, N, K, #
581+
a,
582+
b,
583+
c, #
584+
M,
585+
N,
586+
K, #
567587
NUM_SMS=NUM_SMS, #
568588
WARP_SPECIALIZE=warp_specialize, #
589+
FLATTEN=flatten,
569590
)
570591
return c
571592

@@ -632,7 +653,8 @@ def bench(K, dtype, reps=10000, warmup_reps=10000):
632653
warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
633654
for ws in warp_specialize:
634655
ws_str = "_ws" if ws else ""
635-
if HAS_HOST_TENSOR_DESC:
656+
# disable on-host warpspec on Hopper
657+
if HAS_HOST_TENSOR_DESC and not (is_hopper() and ws):
636658
bench_fn(f"tma_persistent{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma_persistent(a, b, ws), a, b)
637659
bench_fn(f"tma{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma(a, b, ws), a, b)
638660
if HAS_TENSOR_DESC:
@@ -671,7 +693,9 @@ def validate(M, N, K, dtype):
671693

672694
for (kernel, label, enabled), warp_specialize in itertools.product(kernels, warp_specialize):
673695
label = f"{label} (warp_specialize={warp_specialize})"
674-
enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC)
696+
# skip if hopper and warp_specialize and not on-device
697+
skipped = is_hopper() and warp_specialize and kernel != matmul_descriptor_persistent
698+
enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC) and (not skipped)
675699
run_test(naive_result, lambda a, b: kernel(a, b, warp_specialize), a, b, label, enabled)
676700
print()
677701

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def make_ttgir(mod, metadata, opt, capability):
260260
passes.ttir.add_triton_licm(pm)
261261
passes.common.add_canonicalizer(pm)
262262
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
263+
nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
263264
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
264265
passes.ttgpuir.add_schedule_loops(pm)
265266
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)

third_party/nvidia/hopper/include/Transforms/Passes.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ def NVGPUWarpSpecialization : Pass<"nvgpu-warp-specialization", "mlir::ModuleOp"
1414

1515
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
1616
let options = [
17-
Option<"numWarpGroups", "num-warp-groups",
17+
Option<"numStages", "num-stages",
1818
"int32_t", /*default*/"0",
19-
"number of warp groups for warp specialization">
19+
"number of buffers for warp specialization">,
20+
Option<"dumpIntermediateSteps", "dump-intermediate-steps",
21+
"bool", /*default*/"false",
22+
"Dump intermediate steps">
2023
];
2124
}
2225

third_party/nvidia/hopper/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_triton_library(NVHopperTransforms
77
WarpSpecialization/WSCodePartition.cpp
88
WarpSpecialization/WSDataPartition.cpp
99
WarpSpecialization/WSLowerMem.cpp
10+
WarpSpecialization/WSLowerToken.cpp
1011
WarpSpecialization/WSSpecialize.cpp
1112
WarpSpecialization/WSTaskIdPropagate.cpp
1213
WarpSpecialization/WSTaskPartition.cpp

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Transforms/Passes.h"
44
#include "nvidia/hopper/include/Transforms/Passes.h"
55
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
67

78
#define DEBUG_TYPE "nvgpu-warp-specialization"
89
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -11,7 +12,10 @@
1112
namespace mlir {
1213

1314
void doTaskPartition(triton::FuncOp &funcOp, unsigned numWarpGroups);
15+
int doTaskIdPropagate(triton::FuncOp &funcOp);
1416
bool doDataPartition(triton::FuncOp &funcOp, unsigned numConsumerGroups);
17+
void doCodePartition(triton::FuncOp &funcOp, unsigned numBuffers);
18+
void doTokenLowering(triton::FuncOp &funcOp, unsigned numConsumerGroups);
1519

1620
#define GEN_PASS_DEF_NVGPUWARPSPECIALIZATION
1721
#include "nvidia/hopper/include/Transforms/Passes.h.inc"
@@ -23,15 +27,81 @@ class NVGPUWarpSpecializationPass
2327
NVGPUWarpSpecializationPass>::NVGPUWarpSpecializationBase;
2428

2529
void runOnFuncOp(triton::FuncOp funcOp) {
26-
if (numWarpGroups <= 1)
30+
SmallVector<scf::ForOp> loops;
31+
funcOp->walk([&](scf::ForOp forOp) {
32+
if (forOp->hasAttr(mlir::triton::kWarpSpecializeAttrName))
33+
loops.push_back(forOp);
34+
});
35+
if (loops.empty())
2736
return;
2837

29-
// Partition key ops into multiple async tasks.
30-
doTaskPartition(funcOp, numWarpGroups);
38+
int numWarps = mlir::triton::gpu::lookupNumWarps(funcOp);
39+
if (numWarps != 4)
40+
return;
41+
42+
// FIXME: skip warpspec if there is else block. Need to improve
43+
// CodePartitioning to correctly handle channels in else block.
44+
bool hasElse = false;
45+
funcOp->walk([&](scf::IfOp ifOp) {
46+
if (ifOp.elseBlock()) {
47+
for (Operation &op : ifOp.elseBlock()->getOperations()) {
48+
hasElse = true;
49+
}
50+
}
51+
});
52+
if (hasElse)
53+
return;
3154

32-
// Partition ops into parallel sub ops.
33-
if (!doDataPartition(funcOp, numWarpGroups - 1))
55+
OpBuilder builder(funcOp);
56+
auto moduleOp = funcOp->getParentOfType<ModuleOp>();
57+
unsigned numWarpGroups = 3;
58+
// FIXME: skip data partitioning with on-host TMA.
59+
bool success = false;
60+
for (; numWarpGroups >= 2; numWarpGroups--) {
61+
// Partition key ops into multiple async tasks.
62+
doTaskPartition(funcOp, numWarpGroups);
63+
if (dumpIntermediateSteps) {
64+
llvm::dbgs()
65+
<< "// -----// WarpSpec internal IR Dump After: doTaskPartition\n"
66+
<< moduleOp << "\n\n\n";
67+
}
68+
// Propagate taskId.
69+
int retCode = doTaskIdPropagate(funcOp);
70+
if (retCode == -1)
71+
continue;
72+
if (dumpIntermediateSteps) {
73+
llvm::dbgs()
74+
<< "// -----// WarpSpec internal IR Dump After: doTaskIdPropagate\n"
75+
<< moduleOp << "\n\n\n";
76+
}
77+
78+
// Partition ops into parallel sub ops.
79+
if (doDataPartition(funcOp, numWarpGroups - 1)) {
80+
if (dumpIntermediateSteps) {
81+
llvm::dbgs()
82+
<< "// -----// WarpSpec internal IR Dump After: doDataPartition\n"
83+
<< moduleOp << "\n\n\n";
84+
}
85+
success = true;
86+
break;
87+
}
88+
// Clear async_task.
89+
}
90+
if (!success)
3491
signalPassFailure();
92+
93+
doCodePartition(funcOp, numStages);
94+
if (dumpIntermediateSteps) {
95+
llvm::dbgs()
96+
<< "// -----// WarpSpec internal IR Dump After: doCodePartition\n"
97+
<< moduleOp << "\n\n\n";
98+
}
99+
doTokenLowering(funcOp, numWarpGroups - 1);
100+
// Clear num_stages to disable SWP.
101+
funcOp->walk([&](scf::ForOp forOp) {
102+
forOp->setAttr(mlir::triton::kNumStagesAttrName,
103+
builder.getI32IntegerAttr(0));
104+
});
35105
}
36106

37107
void runOnOperation() override {

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/TaskIdPropagation.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ void TaskIdBackwardPropagation::propagateToYield(
7575
}
7676
}
7777

78+
void TaskIdBackwardPropagation::propagateToTerminator(
79+
Operation *op, ArrayRef<const TaskIdLattice *> &lattices) {
80+
for (auto [lattice, terminatorOperand] :
81+
llvm::zip_equal(lattices, op->getOperands())) {
82+
auto terminatorLattice = getLatticeElement(terminatorOperand);
83+
ChangeResult changed = terminatorLattice->meet(lattice->getValue());
84+
propagateIfChanged(terminatorLattice, changed);
85+
}
86+
}
87+
7888
void TaskIdBackwardPropagation::propagateToParent(Operation *op,
7989
const TaskId &taskId) {
8090
auto parentOp = op->getParentOp();
@@ -93,7 +103,7 @@ void TaskIdBackwardPropagation::propagateToParent(Operation *op,
93103
ChangeResult changed = condLattice->meet(taskId);
94104
propagateIfChanged(condLattice, changed);
95105
} else {
96-
if (!isa<triton::FuncOp>(parentOp))
106+
if (!isa<triton::FuncOp, triton::ReduceOp>(parentOp))
97107
llvm_unreachable("Other parent ops are not supported.");
98108
}
99109
parentOp = parentOp->getParentOp();
@@ -115,6 +125,14 @@ LogicalResult TaskIdBackwardPropagation::visitOperation(
115125
}
116126
// Propagate to the parent ops such as control flows
117127
propagateToParent(op, annotated);
128+
129+
if (op->getNumRegions() == 1) {
130+
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
131+
propagateToTerminator(reduceOp.getCombineOp().front().getTerminator(),
132+
results);
133+
}
134+
}
135+
118136
return success();
119137
}
120138
// If it is not annotated by the user, propagate from results to the
@@ -129,6 +147,13 @@ LogicalResult TaskIdBackwardPropagation::visitOperation(
129147
for (const auto resultLattice : results)
130148
propagateToParent(op, resultLattice->getValue());
131149

150+
if (op->getNumRegions() == 1) {
151+
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
152+
propagateToTerminator(reduceOp.getCombineOp().front().getTerminator(),
153+
results);
154+
}
155+
}
156+
132157
return success();
133158
}
134159

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/TaskIdPropagation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ class TaskIdBackwardPropagation
9191

9292
void propagateToYield(scf::YieldOp yieldOp, SmallVector<TaskId> &lattices);
9393

94+
void propagateToTerminator(Operation *op,
95+
ArrayRef<const TaskIdLattice *> &lattices);
96+
9497
void propagateToParent(Operation *op, const TaskId &taskId);
9598
};
9699

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,8 +1178,7 @@ void foldLocalLoads(triton::FuncOp funcOp) {
11781178
kv.getSecond());
11791179
}
11801180

1181-
void doCodePartition(triton::FuncOp &funcOp, unsigned numBuffers,
1182-
unsigned requestedRegisters) {
1181+
void doCodePartition(triton::FuncOp &funcOp, unsigned numBuffers) {
11831182
// Step 1: collect all communications between producers and consumers.
11841183
SmallVector<std::unique_ptr<Channel>> channelsOrigin;
11851184
collectAsyncChannels(channelsOrigin, funcOp, numBuffers);
@@ -1269,7 +1268,7 @@ void doCodePartition(triton::FuncOp &funcOp, unsigned numBuffers,
12691268
funcOp.dump();
12701269
});
12711270

1272-
specializeRegion(funcOp, requestedRegisters);
1271+
specializeRegion(funcOp, 0 /*requestedRegisters*/);
12731272
LLVM_DEBUG({
12741273
LDBG("\n\nwith specializeRegion");
12751274
funcOp.dump();
@@ -1288,7 +1287,7 @@ class NVGPUTestWSCodePartitionPass
12881287
void runOnFuncOp(triton::FuncOp funcOp) {
12891288
// Disable code partitioning when numBuffers is 0.
12901289
if (numBuffers > 0)
1291-
doCodePartition(funcOp, numBuffers, requestedRegisters);
1290+
doCodePartition(funcOp, numBuffers);
12921291
}
12931292
void runOnOperation() override {
12941293
getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); });

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ static void fixTaskId(triton::FuncOp &funcOp) {
4747
auto defTaskIds = getAsyncTaskIds(defOp);
4848
// Backward propagation: ensure def covers op's task IDs.
4949
if (!containsAll(defTaskIds, asyncTaskIds)) {
50+
// Skip control flow ops.
51+
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp>(op))
52+
continue;
5053
// Only propagate backward to arithmetic ops (e.g. constants).
5154
// Const ops with same value but different task ids can be folded.
5255
if (defOp->getDialect()->getNamespace() == "arith") {

0 commit comments

Comments
 (0)