
Commit 44e830e

masahi and Mogball authored
[WS] Use aref for TMA load pipelining and lowering (triton-lang#7826)
A follow-up to triton-lang#7581 which actually starts to use aref for TMA load. It replaces one half of `LoadMMASpecialization`, but the code there is not removed to keep the pass self-contained and its lit tests functional. I verified that all tests pass if I remove the TMA code in `LoadMMASpecialization`. `LowerAref` is updated to add:

* Lowering for NVWS desc load ops
* Aref combining optimization, to coalesce barrier operations on MMA operands into one
* Aref multi-buffering, only enabled for arefs whose producer is TMA

---------

Co-authored-by: Jeff Niu <[email protected]>
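For readers new to the multi-buffering mentioned above, the idea is to rotate producers and consumers through `numStages` copies of a buffer and to flip the expected mbarrier parity each time the slot index wraps. Below is a minimal conceptual sketch of that bookkeeping in plain C++; it is illustrative only, and the names and structure are not taken from the pass.

#include <cstdint>

// Conceptual multi-buffering bookkeeping: advance to the next buffer slot and
// flip the barrier phase whenever the slot index wraps past numStages.
struct BufferState {
  int32_t index = 0; // which of the numStages buffer copies is in use
  int32_t phase = 0; // mbarrier parity expected by the next wait
};

inline BufferState advance(BufferState s, int32_t numStages) {
  s.index += 1;
  if (s.index == numStages) {
    s.index = 0;
    s.phase ^= 1; // parity flips once per full rotation through the buffers
  }
  return s;
}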
1 parent 98147ef commit 44e830e

File tree: 7 files changed, +680 -177 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 10 additions & 0 deletions
@@ -257,6 +257,16 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
     arriveBarrier.getPredMutable().assign(mask);
     return op;
   }
+  if (auto commit = dyn_cast<ttng::TCGen5CommitOp>(op)) {
+    rewriter.setInsertionPoint(commit);
+    Value mask = pred;
+    Value currentPred = commit.getPred();
+    if (currentPred) {
+      mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred);
+    }
+    commit.getPredMutable().assign(mask);
+    return op;
+  }
   if (auto storeOp = dyn_cast<tt::StoreOp>(op)) {
     rewriter.setInsertionPoint(storeOp);
     Value mask = getPredMask(rewriter, storeOp.getPtr().getType(),
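The new TCGen5CommitOp case follows the same shape as the arriveBarrier case above it: if the op already carries a predicate, the loop mask is folded into it rather than overwriting it. Below is a standalone sketch of that combine, assuming plain i1 predicates and using hypothetical helper names; it is not the getPredMask implementation from the pass.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical illustration: fold a new i1 mask into an op's existing
// predicate. If there is no existing predicate, use the mask directly;
// otherwise require both conditions to hold.
static Value combinePredicate(RewriterBase &rewriter, Location loc,
                              Value existingPred, Value mask) {
  if (!existingPred)
    return mask;
  return rewriter.create<arith::AndIOp>(loc, existingPred, mask);
}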

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 2 additions & 2 deletions
@@ -35,15 +35,15 @@ struct AutomaticWarpSpecialization
 void AutomaticWarpSpecialization::runOnOperation() {
   OpPassManager pm;
   pm.addPass(createTritonGPUPartitionScheduling());
+  pm.addPass(createNVWSInsertAref());
   pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
   pm.addPass(createTritonGPURewritePartitionDependencies());
   // `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
   // FIXME: Re-enable integer range analysis once it is fixed.
   // pm.addPass(arith::createIntRangeOptimizationsPass());
   pm.addPass(createSCCPPass());
   pm.addPass(createCSEPass());
-  pm.addPass(createNVWSAssignStagePhase());
-  pm.addPass(createNVWSLowerAref());
+  pm.addPass(createNVWSLowerAref({numStages}));
   pm.addPass(createTritonGPUPartitionLoops());
   pm.addPass(createNVWSLowerWarpGroup());
   if (failed(runPipeline(pm, getOperation())))

test/NVWS/lower_aref.mlir

Lines changed: 172 additions & 71 deletions
Large diffs are not rendered by default.

test/TritonGPU/load-mma-specialization.mlir

Lines changed: 12 additions & 12 deletions
@@ -1,6 +1,6 @@
 // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc | FileCheck %s --check-prefix=TMEM --check-prefix=FUNC
 // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics --tritongpu-hoist-tmem-alloc -tritongpu-partition-scheduling -tritongpu-load-mma-specialization -sccp -int-range-optimizations -canonicalize -cse -tritongpu-remove-layout-conversions | FileCheck %s
-// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics --tritongpu-hoist-tmem-alloc -tritongpu-automatic-warp-specialization | FileCheck %s --check-prefix=AWS --check-prefix=FUNC
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics --tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization | FileCheck %s --check-prefix=AWS --check-prefix=FUNC

 #acc_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #oper_layout = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
@@ -768,7 +768,7 @@ tt.func @matmul_scaled_rhs_scales_tma(
     %off_n: i32,
     %a_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>>,
     %b_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>>,
-    %b_scale_desc: !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>>>
+    %b_scale_desc: !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>>>
 ) {
   %true = arith.constant true
   %c0_i32 = arith.constant 0 : i32
@@ -791,7 +791,7 @@ tt.func @matmul_scaled_rhs_scales_tma(
   // CHECK-COUNT-3: async_tma_copy_global_to_local {{.*}} {ttg.partition = 2 : i32}
   %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>> -> tensor<128x64xf8E4M3FN, #oper_layout>
   %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #nvmma_smem>> -> tensor<128x64xf8E4M3FN, #oper_layout>
-  %b_scales_reg = tt.descriptor_load %b_scale_desc[%off_m, %c0_i32] : !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>>> -> tensor<128x8xi8, #scales>
+  %b_scales_reg = tt.descriptor_load %b_scale_desc[%off_m, %c0_i32] : !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>>> -> tensor<128x8xi8, #scales>

   %a_sh = ttg.local_alloc %a_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem>
   %b_sh_raw = ttg.local_alloc %b_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem>
@@ -1023,13 +1023,13 @@ tt.func @specialize_load_only(%desc: !tt.tensordesc<tensor<128x64xf16, #shared>>
   %c1_i32 = arith.constant 1 : i32
   // CHECK: local_alloc : () -> !ttg.memdesc<3x128x64xf16,
   scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
-    // CHECK: wait_barrier {{.*}} {ttg.partition = 0 : i32}
-    // CHECK-NEXT: local_load {{.*}} {ttg.partition = 0 : i32}
+    // CHECK: wait_barrier {{.*}} {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 0 : i32}
+    // CHECK-NEXT: local_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 0 : i32}
     // CHECK-NEXT: fence_async_shared {{.*}}partition = 0
-    // CHECK-NEXT: arrive_barrier {{.*}} {ttg.partition = 0 : i32}
-    %val = tt.descriptor_load %desc[%i, %i] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
-    "use"(%val) : (tensor<128x64xf16, #oper_layout>) -> ()
-  } {tt.warp_specialize}
+    // CHECK-NEXT: arrive_barrier {{.*}} {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 0 : i32}
+    %val = tt.descriptor_load %desc[%i, %i] {loop.cluster = 1 : i32, loop.stage = 0}: !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
+    "use"(%val) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (tensor<128x64xf16, #oper_layout>) -> ()
+  } {tt.num_stages = 3 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize}
   tt.return
 }

@@ -1041,9 +1041,9 @@ tt.func @fp4_padded_load(%desc: !tt.tensordesc<tensor<1x256x64xui8, #fp4_padded_
   scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
     // CHECK: [[IDX:%.*]] = arith.muli [[I]], %c2_i32 : i32
     // CHECK: async_tma_copy_global_to_local %arg{{[0-9]+}}[[[I]], [[IDX]]]
-    %val = tt.descriptor_load %desc[%i, %i] : !tt.tensordesc<tensor<1x256x64xui8, #fp4_padded_shared>> -> tensor<256x64xi8, #oper_layout>
-    "use"(%val) : (tensor<256x64xi8, #oper_layout>) -> ()
-  } {tt.warp_specialize}
+    %val = tt.descriptor_load %desc[%i, %i] {loop.cluster = 1 : i32, loop.stage = 0, ttg.partition = 2 : i32} : !tt.tensordesc<tensor<1x256x64xui8, #fp4_padded_shared>> -> tensor<256x64xi8, #oper_layout>
+    "use"(%val) {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 0 : i32} : (tensor<256x64xi8, #oper_layout>) -> ()
+  } {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize}
   tt.return
 }

third_party/nvidia/include/Dialect/NVWS/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
@@ -80,6 +80,11 @@ def NVWSLowerAref : Pass<"nvws-lower-aref", "mlir::ModuleOp"> {
     "mlir::triton::gpu::TritonGPUDialect",
     "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
   ];
+
+  let options = [
+    Option<"numStages", "num-stages", "int32_t", /*default*/"3",
+           "number of pipeline stages">
+  ];
 }

 def NVWSInsertAref: Pass<"nvws-insert-aref", "mlir::ModuleOp"> {
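Assuming the new option behaves like other tablegen-defined pass options, the lowering pass should be invokable in isolation with the standard MLIR pass-option syntax, for example in a lit RUN line such as the one below (an illustrative invocation, not taken from this commit's tests):

// RUN: triton-opt %s -split-input-file -nvws-lower-aref='num-stages=3' | FileCheck %s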

third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp

Lines changed: 8 additions & 3 deletions
@@ -95,7 +95,7 @@ ArefCreateOp createAref(OpBuilder &builder, ProducedValueInfo &producedValue) {
   };

   MemDescType memDescType;
-  if (isDescLoadAndAlloc<LocalAllocOp>(result)) {
+  if (result.getDefiningOp<LocalAllocOp>()) {
     memDescType = dyn_cast<MemDescType>(result.getType());
   } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(result)) {
     auto descLoadResult = opt->first.getSrc();
@@ -206,6 +206,10 @@ SmallVector<Operation *> createArefPut(PartitionBuilder &builder,
   } else if (isGlobalLoadAndAlloc<LocalAllocOp>(result) ||
              isGlobalLoadAndAlloc<TMEMAllocOp>(result)) {
     llvm_unreachable("cpasync not supported yet");
+  } else if (auto alloc = result.getDefiningOp<LocalAllocOp>()) {
+    builder.createInto<LocalStoreOp>(*producerPartition, stageCluster,
+                                     alloc.getSrc(), dataBuf);
+    staleOps.push_back(alloc);
   } else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType())) {
     if (auto descOp = result.getDefiningOp<triton::DescriptorOpInterface>()) {
       createNVWSDescriptorLoadOp(builder, descOp, dataBuf, producerPartition,
@@ -296,8 +300,7 @@ getEnterAndExitStageClustersOfUses(const SetVector<Value> &producedResults,
                                    scf::ForOp forOp) {
   CoarseSchedule coarseSchedule;
   if (failed(coarseSchedule.deSerialize(forOp))) {
-    llvm::report_fatal_error(
-        "Failed to deserialze stage and cluster annotations.");
+    return std::make_pair(std::nullopt, std::nullopt);
   }

   SmallVector<Operation *> ops;
@@ -485,6 +488,8 @@ class NVWSArefInsertion
         (allowDescLoadRegUse &&
          (isa<triton::DescriptorOpInterface>(op)))) {
       ops.push_back(op);
+    } else if (isa<LocalAllocOp>(op)) {
+      ops.push_back(op);
     }
     return WalkResult::advance();
   });
