
Commit b8a4b87

[NVIDIA][Backend] Add CoalesceAsyncCopy Pass for in-DotOpEnc Upcasting (#5222)
This is a follow-up to the dotOp hoisting optimization for WGMMA (MMAv3); see triton-lang/triton#5003 (comment).

In short, when upcasting operand A in registers prior to WGMMA and when pipelining is enabled, `AsyncCopyGlobalToLocal`'s src gmem blocked encoding will have `sizePerThread` > the smem view's `vec` (along the contiguous dimension). This results in multiple `cp.async` instructions being generated for a single contiguous global data segment, i.e. in uncoalesced loads. This was previously confirmed in ncu; see the comment linked above for an example.

I've added a generalized fix in a new pass that runs after the pipeliner. I've reused the logic in the LLVM lowering for `AsyncCopyGlobalToLocal` to calculate the max contiguous copy size, and I compare that to the blocked encoding's `sizePerThread` along the inner (contiguous) dimension. If the former is less than the latter, I set the latter to the former.

When A is k-major, I can verify a small perf improvement and that ncu no longer reports uncoalesced loads. When A is m-major, this pass is a no-op because `copy size == sizePerThread == 16`.

ptal, thanks @ThomasRaoux
1 parent 7b2beae commit b8a4b87
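
A rough numeric sketch of the arithmetic described above, for the int8 -> bf16 upcast case (the values come from the pass comments and the lit test in this commit; the variable names are mine):

# Illustrative only: why a 16-element-per-thread blocked encoding with an
# 8-element shared vec produces two cp.async instructions per thread.
elem_bytes = 1            # i8 elements in global memory
size_per_thread = 16      # blocked encoding, contiguous (k) dimension
shared_vec = 8            # smem view's vec along the same dimension

cp_async_bytes = min(size_per_thread, shared_vec) * elem_bytes       # 8B per cp.async
copies_per_thread = (size_per_thread * elem_bytes) // cp_async_bytes
print(cp_async_bytes, copies_per_thread)                             # 8, 2 -> uncoalesced

# After the pass clips sizePerThread to the max contiguous copy size:
clipped_size_per_thread = min(size_per_thread, shared_vec)
print((clipped_size_per_thread * elem_bytes) // cp_async_bytes)      # 1 -> coalesced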

File tree: 9 files changed (+228, -34 lines)

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
@@ -192,4 +192,16 @@ def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp">
            "number of pipeline stages">
   ];
 }
+
+def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
+  let summary = "Improve coalescing for async global to local copies";
+
+  let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
+                    "the blocked encoding's sizePerThread, this pass improves coalescing by clipping the "
+                    "sizePerThread value";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -202,6 +202,11 @@ enum class MMALoadType {
   // pipelining
 };
 MMALoadType getMMALoadType(Operation *loadOp);
+
+// Returns composed LinearLayout for register to shared copy
+std::optional<triton::LinearLayout>
+getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                     Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
 } // namespace mlir

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 12 additions & 34 deletions
@@ -4,6 +4,7 @@
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "llvm/ADT/STLExtras.h"

 namespace mlir {
@@ -174,41 +175,17 @@ bool emitTransferBetweenRegistersAndShared(
   StringAttr kLane = str_attr("lane");
   StringAttr kWarp = str_attr("warp");

-  std::optional<LinearLayout> regLayout =
-      triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
-  std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
-      shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
-  if (!regLayout.has_value() || !sharedLayout.has_value()) {
+  auto regToSharedLayout = getRegToSharedLayout(
+      ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(),
+      elemLlvmTy.getIntOrFloatBitWidth());
+  if (!regToSharedLayout.has_value())
     return false;
-  }
-  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
-
-  // sharedLayout's in-dims are currently (offset, block). Reshape to
-  // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
-  // shmem strides. (The offsetX's appear in minor-to-major order.)
-  auto sharedLegacy =
-      cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
-  SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
-  for (int i = 0; i < rank; i++) {
-    int dim = sharedOrder[i];
-    int64_t size = std::max(
-        int64_t{1},
-        shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
-    multiDimSharedSize.push_back(
-        {str_attr("offset" + std::to_string(dim)), size});
-  }
-  multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
-  sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);
-
-  // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
-  // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
-  LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);

   // TODO(jlebar): We don't currently support loading from shared memory in a
   // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
+  for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
        inBlock *= 2) {
-    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
+    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
         {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
     // offsetX1, ..., offsetXN must all be 0.
     if (!llvm::all_of(ArrayRef(idx).drop_back(1),
@@ -234,15 +211,15 @@ bool emitTransferBetweenRegistersAndShared(
   // which have known strides. This would allow us to vectorize across multiple
   // shmem out dimensions where possible.
   const int vecElems =
-      std::min(regToSharedLayout.getNumConsecutiveInOut(),
+      std::min(regToSharedLayout->getNumConsecutiveInOut(),
                maxVecElems.value_or(std::numeric_limits<int>::max()));

   Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
+  Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
   Value laneId = urem(threadId, threadsPerWarp);
   Value warpId = udiv(threadId, threadsPerWarp);

-  int numElems = regToSharedLayout.getInDimSize(kRegister);
+  int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
   auto ptrTy = shmemBase.getType();
   Value zero = i32_val(0);
@@ -253,14 +230,15 @@ bool emitTransferBetweenRegistersAndShared(
     // we drop_end to drop block, which we know from above will be 0.
     auto multiDimShmemOffset =
         llvm::to_vector(llvm::drop_end(llvm::make_second_range(
-            applyLinearLayout(loc, rewriter, regToSharedLayout,
+            applyLinearLayout(loc, rewriter, *regToSharedLayout,
                               {{kRegister, i32_val(i * vecElems)},
                                {kLane, laneId},
                                {kWarp, warpId},
                                {kBlock, zero}}))));

     // Reorder strides according to `order`. This way they match the
     // multi-dimensional offsets in regToSharedLayout.
+    auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
     Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
                             applyPermutation(shmemStrides, sharedOrder));
     auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ add_triton_library(TritonGPUTransforms
   Prefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp
+  CoalesceAsyncCopy.cpp
   Utility.cpp

   DEPENDS
lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/Passes.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+
+namespace mlir {
+namespace triton {
+namespace gpu {
+
+#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+// This pass currently only applies if the following are all true...
+// 1) Operand A for WGMMA is to be loaded in registers
+// 2) We upcast operand A in registers before the WGMMA
+//    (downcasting is not yet supported)
+// 3) Pipelining is enabled for loading A
+//
+// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding
+// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if
+// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread
+// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two
+// 8-byte-cp.async's for each contiguous 16B global data owned by each
+// thread. This breaks coalescing (i.e. results 2x the minimum required
+// transactions).
+//
+// This issue occurs for cp.async because it combines load and store into one
+// instruction. The fix is to clip each dim of sizePerThread by shared vec, so
+// that the vectorization of load and store are equal along the contiguous
+// dimension. In the above example, each thread will then only own 8B contiguous
+// global data.
+struct ClipAsyncCopySizePerThread
+    : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    Value src = copyOp.getSrc();
+    Value mask = copyOp.getMask();
+    Value other = copyOp.getOther();
+    auto srcTy = cast<RankedTensorType>(src.getType());
+    auto dstTy = cast<MemDescType>(copyOp.getResult().getType());
+    auto blockEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
+    if (!blockEnc)
+      return rewriter.notifyMatchFailure(copyOp,
+                                         "src must be of blocked encoding");
+    auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
+    auto sharedVec = sharedEnc.getVec();
+
+    // obtain max contiguous copy size
+    // Note this can be further optimized, as copyContigSize can be even
+    // smaller when lowering, depending on contiguity and mask alignment
+    // (see AsyncCopyGlobalToLocalOpConversion)
+    auto elemBitWidth = dstTy.getElementTypeBitWidth();
+    auto regToSharedLayout =
+        getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc,
+                             sharedEnc, elemBitWidth);
+    auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();
+
+    // obtain block sizePerThread along contig dim
+    auto sizePerThread = blockEnc.getSizePerThread();
+    auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];
+
+    if (blockContigSize <= copyContigSize)
+      return rewriter.notifyMatchFailure(
+          copyOp,
+          "blocked sizePerThread along contiguous dim must be greater than the "
+          "max contiguous copy size ");
+
+    sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;
+
+    // obtain new blockedEnc based on clipped sizePerThread
+    auto mod = copyOp->getParentOfType<ModuleOp>();
+    int numWarps = TritonGPUDialect::getNumWarps(mod);
+    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
+    auto newBlockEnc = BlockedEncodingAttr::get(
+        copyOp.getContext(), srcTy.getShape(), sizePerThread,
+        blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout());
+
+    // insert cvt's after src, mask, and other
+    auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
+      auto ty = cast<TensorType>(src.getType());
+      auto newTy =
+          RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
+      auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
+      return cvt.getResult();
+    };
+    src = convertBlockLayout(src, newBlockEnc);
+    if (mask)
+      mask = convertBlockLayout(mask, newBlockEnc);
+    if (other)
+      other = convertBlockLayout(other, newBlockEnc);
+
+    rewriter.modifyOpInPlace(copyOp, [&]() {
+      copyOp.getSrcMutable().assign(src);
+      if (mask)
+        copyOp.getMaskMutable().assign(mask);
+      if (other)
+        copyOp.getOtherMutable().assign(other);
+    });
+
+    return success();
+  }
+};
+
+class CoalesceAsyncCopyPass
+    : public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
+public:
+  void runOnOperation() override {
+    ModuleOp m = getOperation();
+    MLIRContext *context = &getContext();
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.add<ClipAsyncCopySizePerThread>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace gpu
+} // namespace triton
+} // namespace mlir
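
The rewrite above reduces to a simple clipping rule on the blocked encoding's sizePerThread. A minimal Python sketch of that rule, assuming copy_contig_size plays the role of getNumConsecutiveInOut() on the register-to-shared layout (the real pattern rebuilds a full BlockedEncodingAttr and inserts convert_layout ops rather than returning a list):

# Schematic restatement of ClipAsyncCopySizePerThread's core decision.
def clip_size_per_thread(size_per_thread, order, copy_contig_size):
    """Clip sizePerThread along the contiguous dim of a blocked encoding."""
    contig_dim = order[0]                  # fastest-varying dimension
    if size_per_thread[contig_dim] <= copy_contig_size:
        return None                        # match failure: nothing to clip
    clipped = list(size_per_thread)
    clipped[contig_dim] = copy_contig_size
    return clipped

# First case of the lit test below: sizePerThread = [1, 16], order = [1, 0],
# max contiguous copy = 8 elements.
print(clip_size_per_thread([1, 16], [1, 0], 8))   # -> [1, 8]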

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 36 additions & 0 deletions
@@ -1154,4 +1154,40 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
   patterns.add<ForOpDeadArgElimination>(patterns.getContext());
 }

+std::optional<LinearLayout>
+getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                     Attribute srcEnc, Attribute dstEnc, int elemBitWidth) {
+  StringAttr kBlock = StringAttr::get(ctx, ("block"));
+  int rank = shape.size();
+
+  std::optional<LinearLayout> regLayout =
+      triton::gpu::toLinearLayout(shape, srcEnc);
+  std::optional<LinearLayout> sharedLayout =
+      triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth);
+  if (!regLayout.has_value() || !sharedLayout.has_value()) {
+    return std::nullopt;
+  }
+  auto sharedOrder = triton::gpu::getOrder(dstEnc);
+
+  // sharedLayout's in-dims are currently (offset, block). Reshape to
+  // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
+  // shmem strides. (The offsetX's appear in minor-to-major order.)
+  auto sharedLegacy = cast<triton::gpu::SharedEncodingAttr>(dstEnc);
+  SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
+  for (int i = 0; i < rank; i++) {
+    int dim = sharedOrder[i];
+    int64_t size = std::max(
+        int64_t{1},
+        shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
+    multiDimSharedSize.push_back(
+        {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size});
+  }
+  multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
+  sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);
+
+  // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
+  // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
+  return regLayout->invertAndCompose(*sharedLayout);
+}
+
 } // namespace mlir

python/src/passes.cc

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
                             createTritonGPUOptimizeAccumulatorInit);
   ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling",
                             createTritonGPULoopScheduling, int);
+  ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
+                     createTritonGPUCoalesceAsyncCopy);
 }

 void init_triton_passes_convert(py::module &&m) {
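
For users of the Python bindings, the new wrapper is invoked like any other ttgpuir pass. A sketch, assuming the usual triton._C.libtriton imports and an already-constructed ttgir module, mirroring how make_ttgir drives the pass manager in the NVIDIA backend change below:

from triton._C.libtriton import ir, passes

def run_coalesce_async_copy(mod):
    # mod is assumed to be a ttgir ModuleOp that has already been parsed/built
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.ttgpuir.add_coalesce_async_copy(pm)  # wrapper registered above
    pm.run(mod)
    return mod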
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s
+
+// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
+                         %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
+                         %mask: tensor<64x16xi1, #blocked>,
+                         %other: tensor<64x16xi8, #blocked>) {
+    %token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
+                                          %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) {
+    %token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
+    tt.return
+  }
+}
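
A quick arithmetic check of the encodings in the test above, under a simplified model of how the warp's 32 threads are redistributed after clipping (the real layout is produced by BlockedEncodingAttr::get; this only verifies the numbers appearing in the CHECK lines):

# 64x16 i8 tensor, shared vec = 8, original blocked sizePerThread = [1, 16].
shape = [64, 16]
order = [1, 0]                     # dim 1 is the contiguous dim
size_per_thread = [1, 16]
shared_vec = 8                     # max contiguous copy, in elements
warp_size = 32

size_per_thread[order[0]] = min(size_per_thread[order[0]], shared_vec)
assert size_per_thread == [1, 8]   # matches #[[NEW_BLOCKED]]

# Lay out threads minor-to-major: fill the contiguous dim first, then the rest.
threads_contig = min(warp_size, shape[order[0]] // size_per_thread[order[0]])  # 16 // 8 = 2
threads_outer = warp_size // threads_contig                                    # 32 // 2 = 16
assert [threads_outer, threads_contig] == [16, 2]  # threadsPerWarp in the CHECK lines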

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ def make_ttgir(mod, metadata, opt, capability):
     passes.ttgpuir.add_pipeline(pm, opt.num_stages)
     passes.ttgpuir.add_prefetch(pm)
     passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+    passes.ttgpuir.add_coalesce_async_copy(pm)
     passes.ttgpuir.add_remove_layout_conversions(pm)
     passes.ttgpuir.add_reduce_data_duplication(pm)
     passes.ttgpuir.add_reorder_instructions(pm)
