
Commit 12f8ae7

[Blackwell] Move optimizeTMemLoad and tmem load subtiling into one pass (triton-lang#6715)
Follow-up to triton-lang#6694 that moves `optimizeTMemLoad` (which picks a split-M layout when the tmem_load result feeds a reduction) into the tmem load subtiling pass and renames that pass to `optimize-tmem-layouts`. This separates the optimization from the accelerate-matmul pass, so the relayout pass no longer needs to run it.
1 parent e32c3b1
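For context, a minimal Python sketch of the condition under which the moved optimization switches a tmem_load to the split-M layout. It paraphrases the checks in the C++ pattern shown in the diff below; the function name and signature are illustrative only, not part of this change.

```python
def wants_split_m_layout(num_warps: int, block_m: int, reduction_axis: int | None) -> bool:
    """Paraphrase of the checks in TMemLoadReducePattern (illustrative, not the real API)."""
    # With a single warpgroup the default layout is already reduction friendly,
    # so the pattern only fires for 8 warps (two warpgroups).
    if num_warps != 8:
        return False
    # The pattern is restricted to TMEM blocks with blockM == 128.
    if block_m != 128:
        return False
    # Only a reduction along N (axis == 1) reachable from the load triggers the
    # switch from the default split-N distribution to split-M.
    return reduction_axis == 1


assert wants_split_m_layout(8, 128, 1)       # reduce along N -> pick split-M
assert not wants_split_m_layout(4, 128, 1)   # single warpgroup -> keep default
```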

File tree

10 files changed, +103 -119 lines

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ std::unique_ptr<Pass> createTritonNvidiaGPURemoveTMEMTokensPass();
 
 std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
 
-std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemSubtilingPass();
+std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemLayoutsPass();
 
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 4 additions & 3 deletions
@@ -130,11 +130,12 @@ def TritonNvidiaGPUOptimizeDescriptorEncodingPass : Pass<"triton-nvidia-optimize
                            "mlir::triton::TritonDialect"];
 }
 
-def TritonNvidiaGPUOptimizeTMemSubtilingPass : Pass<"triton-nvidia-optimize-tmem-subtiling", "mlir::ModuleOp"> {
-  let summary = "Optimize subtiling.";
+def TritonNvidiaGPUOptimizeTMemLayoutsPass : Pass<"triton-nvidia-optimize-tmem-layouts", "mlir::ModuleOp"> {
+  let summary = "Optimize TMEM layouts.";
 
   let description = [{
-    Optimize subtiling by trying to split tmem_load when user splits a tensor.
+    Optimize TMEM layouts by selecting a layouts to enable better subtiling,
+    reduction performance, etc.
   }];
 
   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 0 additions & 60 deletions
@@ -807,63 +807,6 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
   });
 }
 
-// When there are multiple warpgroups tmem_load results can be distirbuted along
-// M or N across the warpgroups. By default distribute along N but when there is
-// a reduction along N dimension we want to distribute along M instead to avoid
-// having to reduce across warps.
-static void optimizeTMemLoad(ModuleOp mod) {
-  SmallVector<triton::nvidia_gpu::TMEMLoadOp> tmemLoads;
-  mod.walk([&](triton::nvidia_gpu::TMEMLoadOp tmemLoadOp) -> void {
-    tmemLoads.push_back(tmemLoadOp);
-  });
-  for (triton::nvidia_gpu::TMEMLoadOp tmemLoadOp : tmemLoads) {
-    int numWarps = lookupNumWarps(tmemLoadOp);
-    // If there is only 1 warpgroup there is nothing to optimize as the layout
-    // is already reduction friendly.
-    if (numWarps != 8)
-      return;
-    auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
-        tmemLoadOp.getSrc().getType().getEncoding());
-    if (!tmemEnc)
-      continue;
-    int M = tmemEnc.getBlockM();
-    int N = tmemEnc.getBlockN();
-    if (M != 128)
-      continue;
-    bool foundReductionAlongN = false;
-    auto filter = [&](Operation *op) {
-      if (isa<ConvertLayoutOp>(op) || op->hasTrait<OpTrait::Elementwise>())
-        return true;
-      if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
-        foundReductionAlongN = reduce.getAxis() == 1;
-      }
-      return false;
-    };
-    ForwardSliceOptions fwdOpt;
-    fwdOpt.filter = filter;
-    SetVector<mlir::Operation *> fwdSlices;
-    getForwardSlice(tmemLoadOp.getResult(), &fwdSlices, fwdOpt);
-    if (!foundReductionAlongN)
-      continue;
-    // Try to split along M dimension but follow the restrictions of TMEM:
-    // warp0 get M = 0, warp 1 gets M = 32, warp 2 gets M = 64, warp 3 gets
-    // M = 96 warp 4 gets M = 16, warp 5 gets M = 48, warp 6 gets M = 80,
-    // warp 7 gets M = 112
-    RankedTensorType oldType = tmemLoadOp.getType();
-    Attribute newLayout = triton::gpu::LinearEncodingAttr::get(
-        tmemLoadOp.getContext(),
-        getTmemLoadLayoutSplitLongM(M, N, oldType, numWarps));
-    auto newType = RankedTensorType::get(oldType.getShape(),
-                                         oldType.getElementType(), newLayout);
-    tmemLoadOp.getResult().setType(newType);
-    OpBuilder builder(tmemLoadOp);
-    builder.setInsertionPointAfter(tmemLoadOp);
-    auto cvt = builder.create<ConvertLayoutOp>(tmemLoadOp.getLoc(), oldType,
-                                               tmemLoadOp.getResult());
-    tmemLoadOp.getResult().replaceAllUsesExcept(cvt.getResult(), cvt);
-  }
-}
-
 // Transpose scaled_dot ops that have a scale on lhs.
 static void transposeDotOp(DotScaledOp dotOp) {
   OpBuilder builder(dotOp);
@@ -931,9 +874,6 @@ class TritonGPUAccelerateMatmulPass
     // Now that we have picked the mma type, decompose dot that are not natively
     // supported.
    decomposeMixedModeDotOp(m, computeCapability);
-
-    // Pick an optimized tmem load layout based on its users.
-    optimizeTMemLoad(m);
   }
 };

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ add_triton_library(TritonNvidiaGPUTransforms
   FenceInsertion.cpp
   MMALowering.cpp
   OptimizeDescriptorEncoding.cpp
-  OptimizeTMemSubtiling.cpp
+  OptimizeTMemLayouts.cpp
   PlanCTA.cpp
   PromoteLHSToTMem.cpp
   RemoveTMEMTokens.cpp

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemSubtiling.cpp renamed to lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp

Lines changed: 71 additions & 8 deletions
@@ -1,3 +1,4 @@
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -181,27 +182,89 @@ class TMemSplitLoadPattern : public OpRewritePattern<tt::SplitOp> {
   }
 };
 
-class TritonNvidiaGPUOptimizeTMemSubtilingPass
-    : public TritonNvidiaGPUOptimizeTMemSubtilingPassBase<
-          TritonNvidiaGPUOptimizeTMemSubtilingPass> {
+// Pick an optimized tmem load layout based on its users. When there are
+// multiple warpgroups tmem_load results can be distirbuted along M or N across
+// the warpgroups. By default distribute along N but when there is a reduction
+// along N dimension we want to distribute along M instead to avoid having to
+// reduce across warps.
+class TMemLoadReducePattern : public OpRewritePattern<ttng::TMEMLoadOp> {
 public:
-  using BaseT = TritonNvidiaGPUOptimizeTMemSubtilingPassBase<
-      TritonNvidiaGPUOptimizeTMemSubtilingPass>;
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttng::TMEMLoadOp tmemLoadOp,
+                                PatternRewriter &rewriter) const override {
+    int numWarps = ttg::lookupNumWarps(tmemLoadOp);
+    // If there is only 1 warpgroup there is nothing to optimize as the layout
+    // is already reduction friendly.
+    if (numWarps != 8)
+      return failure();
+    auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
+        tmemLoadOp.getSrc().getType().getEncoding());
+    if (!tmemEnc)
+      return failure();
+    int M = tmemEnc.getBlockM();
+    int N = tmemEnc.getBlockN();
+    if (M != 128)
+      return failure();
+    bool foundReductionAlongN = false;
+    auto filter = [&](Operation *op) {
+      if (isa<ttg::ConvertLayoutOp>(op) || op->hasTrait<OpTrait::Elementwise>())
+        return true;
+      if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
+        foundReductionAlongN = reduce.getAxis() == 1;
+      }
+      return false;
+    };
+    ForwardSliceOptions fwdOpt;
+    fwdOpt.filter = filter;
+    SetVector<mlir::Operation *> fwdSlices;
+    getForwardSlice(tmemLoadOp.getResult(), &fwdSlices, fwdOpt);
+    if (!foundReductionAlongN)
+      return failure();
+    // Try to split along M dimension but follow the restrictions of TMEM:
+    // warp0 get M = 0, warp 1 gets M = 32, warp 2 gets M = 64, warp 3 gets
+    // M = 96 warp 4 gets M = 16, warp 5 gets M = 48, warp 6 gets M = 80,
+    // warp 7 gets M = 112
+    RankedTensorType oldType = tmemLoadOp.getType();
+    Attribute newLayout = ttg::LinearEncodingAttr::get(
+        tmemLoadOp.getContext(),
+        ttg::getTmemLoadLayoutSplitLongM(M, N, oldType, numWarps));
+    if (newLayout == oldType.getEncoding())
+      return failure();
+
+    auto newType = RankedTensorType::get(oldType.getShape(),
+                                         oldType.getElementType(), newLayout);
+    tmemLoadOp.getResult().setType(newType);
+    OpBuilder builder(tmemLoadOp);
+    builder.setInsertionPointAfter(tmemLoadOp);
+    auto cvt = builder.create<ttg::ConvertLayoutOp>(
+        tmemLoadOp.getLoc(), oldType, tmemLoadOp.getResult());
+    tmemLoadOp.getResult().replaceAllUsesExcept(cvt.getResult(), cvt);
+    return success();
+  }
+};
+
+class TritonNvidiaGPUOptimizeTMemLayoutsPass
+    : public TritonNvidiaGPUOptimizeTMemLayoutsPassBase<
+          TritonNvidiaGPUOptimizeTMemLayoutsPass> {
+public:
+  using BaseT = TritonNvidiaGPUOptimizeTMemLayoutsPassBase<
+      TritonNvidiaGPUOptimizeTMemLayoutsPass>;
   using BaseT::BaseT;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
 
     mlir::RewritePatternSet patterns(context);
-    patterns.add<TMemSplitLoadPattern>(context);
+    patterns.add<TMemSplitLoadPattern, TMemLoadReducePattern>(context);
     if (failed(applyPatternsGreedily(m, std::move(patterns))))
      signalPassFailure();
   }
 };
 
 } // namespace
 
-std::unique_ptr<Pass> mlir::createTritonNvidiaGPUOptimizeTMemSubtilingPass() {
-  return std::make_unique<TritonNvidiaGPUOptimizeTMemSubtilingPass>();
+std::unique_ptr<Pass> mlir::createTritonNvidiaGPUOptimizeTMemLayoutsPass() {
+  return std::make_unique<TritonNvidiaGPUOptimizeTMemLayoutsPass>();
 }
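The TMEM restriction comment in the pattern above lists which M offset each warp owns in the split-M layout. As a sanity check, here is a small Python sketch that reproduces those offsets; the closed-form expression is inferred from the listed values and is not the actual getTmemLoadLayoutSplitLongM implementation.

```python
def split_m_warp_offsets(num_warps: int = 8) -> list[int]:
    """Illustrative reconstruction of the warp -> M offset mapping described in
    the comment above: warps 0-3 take rows 0/32/64/96, warps 4-7 take 16/48/80/112."""
    return [(w % 4) * 32 + (w // 4) * 16 for w in range(num_warps)]


# Matches the values spelled out in the TMEM restriction comment.
assert split_m_warp_offsets() == [0, 32, 64, 96, 16, 48, 80, 112]
```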

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 0 additions & 19 deletions
@@ -511,25 +511,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK{LITERALE}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}>
-  // CHECK-LABEL: dot_reduce
-  tt.func public @dot_reduce(%arg0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> {
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
-    %0 = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked>
-    // ttng.tmem_load %{{.*}} : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #linear>
-    %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({
-    ^bb0(%arg2: f32, %arg3: f32):
-      %2 = arith.addf %arg2, %arg3 : f32
-      tt.reduce.return %2 : f32
-    }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
-    tt.return %1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
-  }
-}
-
-// -----
-
 // CHECK-DAG: #[[$SHARED_A:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
 // CHECK-DAG: #[[$SHARED_B:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
test/TritonGPU/optimize-partition-warps.mlir

Lines changed: 0 additions & 23 deletions
@@ -163,27 +163,4 @@ tt.func @tmem_min_4_warps(%tensor_desc: !ttg.memdesc<64x64xf32, #tmem, #ttng.ten
   tt.return
 }
 
-tt.func @tmem_split_m_layout(%tensor_desc: !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>) {
-  ttg.warp_specialize(%tensor_desc)
-  default {
-    ttg.warp_yield
-  }
-  // CHECK: partition0{{.*}} num_warps(8)
-  partition0(%desc: !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>) num_warps(16) {
-    // CHECK: ttng.tmem_load {{.*}} -> tensor<128x64xf32, #linear>
-    %0 = ttng.tmem_load %desc : !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked2d_16>
-
-    %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({
-    ^bb0(%arg2: f32, %arg3: f32):
-      %2 = arith.addf %arg2, %arg3 : f32
-      tt.reduce.return %2 : f32
-    }) : (tensor<128x64xf32, #blocked2d_16>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2d_16}>>
-
-    %cst = arith.constant dense<0.0> : tensor<128x128xf32, #blocked2d_16>
-    "use"(%1, %cst) : (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2d_16}>>, tensor<128x128xf32, #blocked2d_16>) -> ()
-    ttg.warp_return
-  } : (!ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>) -> ()
-  tt.return
-}
-
 }

test/TritonNvidiaGPU/tmem_subtiling.mlir renamed to test/TritonNvidiaGPU/tmem_layouts.mlir

Lines changed: 23 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --triton-nvidia-optimize-tmem-subtiling --allow-unregistered-dialect | FileCheck %s
+// RUN: triton-opt %s -split-input-file --triton-nvidia-optimize-tmem-layouts --allow-unregistered-dialect | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
@@ -106,3 +106,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return %5, %6, %11 : tensor<128x64xf16, #blocked>, tensor<128x64xf16, #blocked>, tensor<128x64xf16, #blocked>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}>
+// CHECK-LABEL: tmem_load_reduce
+tt.func public @tmem_load_reduce(%arg0: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> {
+  %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory> -> tensor<128x64xf32, #blocked>
+  // CHECK: ttng.tmem_load %{{.*}} : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory> -> tensor<128x64xf32, #linear>
+  %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({
+  ^bb0(%arg2: f32, %arg3: f32):
+    %2 = arith.addf %arg2, %arg3 : f32
+    tt.reduce.return %2 : f32
+  }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  tt.return %1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+}
+
+}

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def make_ttgir(mod, metadata, opt, capability):
        passes.ttgpuir.add_WGMMAPrefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_coalesce_async_copy(pm)
-       nvidia.passes.ttnvgpuir.add_optimize_tmem_subtiling(pm)
+       nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)

third_party/nvidia/triton_nvidia.cc

Lines changed: 2 additions & 2 deletions
@@ -48,8 +48,8 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) {
                      mlir::createTritonNvidiaGPUMMALoweringPass);
   ADD_PASS_WRAPPER_0("add_optimize_descriptor_encoding",
                      mlir::createTritonNvidiaGPUOptimizeDescriptorEncodingPass);
-  ADD_PASS_WRAPPER_0("add_optimize_tmem_subtiling",
-                     mlir::createTritonNvidiaGPUOptimizeTMemSubtilingPass);
+  ADD_PASS_WRAPPER_0("add_optimize_tmem_layouts",
+                     mlir::createTritonNvidiaGPUOptimizeTMemLayoutsPass);
 }
 
 void init_triton_nvidia_passes_nvws(py::module &&m) {

0 commit comments
