
Commit 54606e8

[hopper][WS] Use required layout for buffers (#7284)
When creating buffers, the required layout should be decided by the consumer. We previously used the MMA layout unconditionally, which broke the case where the consumer wasn't actually a dot op.
1 parent d78b4f9 commit 54606e8
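
In short: instead of giving every shared-memory buffer the NVMMA (mma) encoding, the pass now picks the encoding the consumer actually needs. A minimal sketch of the selection logic, condensed from the WSCodePartition.cpp hunk below — the helper name pickSharedLayout and its parameter list are illustrative only; in the commit the logic is inlined in createBuffer and relies on that file's existing getActualConsumers helper and its tt/ttg/ttng namespace aliases:

// Illustrative sketch, not the literal commit: createBuffer inlines this
// logic. Assumes the aliases used in WSCodePartition.cpp (tt = mlir::triton,
// ttg = mlir::triton::gpu, ttng = mlir::triton::nvidia_gpu) and its
// getActualConsumers helper.
static mlir::Attribute
pickSharedLayout(mlir::MLIRContext *context, mlir::Operation *srcOp,
                 mlir::Operation *dstOp, llvm::ArrayRef<int64_t> sliceShape,
                 llvm::ArrayRef<unsigned> order, ttg::CTALayoutAttr CTALayout,
                 mlir::Type elemType) {
  // A dot-op consumer still requires the NVMMA shared encoding.
  if (llvm::any_of(getActualConsumers(dstOp), [](mlir::Operation *op) {
        return isa<mlir::triton::DotOpInterface>(op);
      }))
    return ttg::NVMMASharedEncodingAttr::get(context, sliceShape, order,
                                             CTALayout, elemType,
                                             /*fp4Padded*/ false);
  // A TMA-load producer keeps the encoding its descriptor dictates.
  if (auto tmaLoad = dyn_cast<tt::DescriptorLoadOp>(srcOp))
    return ttng::getEncodingFromDescriptor(tmaLoad, tmaLoad.getType(),
                                           tmaLoad.getDesc());
  // Otherwise fall back to a plain unswizzled shared layout for now.
  return ttg::SwizzledSharedEncodingAttr::get(context, 1, 1, 1, order,
                                              CTALayout);
}

The new test pins this down: its 128xf32 scale buffer never feeds a dot op, so the CHECK lines expect a #ttg.swizzled_shared encoding for it, while the dot operands keep #ttg.nvmma_shared.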

3 files changed: +89, -8 lines

test/Hopper/WarpSpecialization/ws_code_partition.mlir

Lines changed: 45 additions & 0 deletions
@@ -260,3 +260,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+
+// -----
+
+// CHECK-DAG: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+// CHECK-DAG: #[[$SHARED1:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+// CHECK-LABEL: @_fbgemm_grouped_gemm_fp8_rowwise_ws
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x64x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128xf32, #[[$SHARED]], #smem, mutable>
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_fbgemm_grouped_gemm_fp8_rowwise_ws(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg1: i32, %arg2: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg3: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
+    %c2048_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 2048 : i32
+    %c64_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 64 : i32
+    %cst = arith.constant {async_task_id = array<i32: 0, 1, 2>} dense<0.000000e+00> : tensor<64x128xf32, #mma>
+    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
+    %1 = ttng.reinterpret_tensor_descriptor %arg0 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>>
+    %2 = ttng.reinterpret_tensor_descriptor %arg2 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>>
+    %3 = ttng.reinterpret_tensor_descriptor %arg3 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128xf32, #shared1>>
+    scf.for %arg4 = %0 to %arg1 step %c64_i32 : i32 {
+      %4 = arith.muli %arg4, %c2048_i32 {async_task_id = array<i32: 0>} : i32
+      %5 = scf.for %arg5 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg6 = %cst) -> (tensor<64x128xf32, #mma>) : i32 {
+        %8 = tt.descriptor_load %1[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>> -> tensor<64x64xf8E4M3FN, #blocked>
+        %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1>} : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem>
+        %10 = tt.descriptor_load %2[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>> -> tensor<128x64xf8E4M3FN, #blocked>
+        %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem>
+        %12 = ttg.memdesc_trans %11 {async_task_id = array<i32: 1, 2>, order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem>
+        %13 = ttng.warp_group_dot %9, %12, %arg6 {async_task_id = array<i32: 1>, inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem> * !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem> -> tensor<64x128xf32, #mma>
+        scf.yield {async_task_id = array<i32: 1, 2>} %13 : tensor<64x128xf32, #mma>
+      } {async_task_id = array<i32: 0, 1, 2>}
+      %6 = tt.descriptor_load %3[%4] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128xf32, #shared1>> -> tensor<128xf32, #blocked1>
+      %7 = ttg.convert_layout %6 {async_task_id = array<i32: 1, 2>} : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    } {async_task_id = array<i32: 1, 2>}
+    tt.return
+  }
+}

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp

Lines changed: 38 additions & 2 deletions
@@ -19,6 +19,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include <unordered_set>

 namespace tt = mlir::triton;
@@ -725,10 +726,19 @@ DenseMap<Channel *, Value> createBuffer(
   auto &channels = channelsGroupedByProducers[channelInOrder];
   auto srcValue = channelInOrder->getSrcOperand();
   auto srcOp = channelInOrder->getSrcOp();
+  auto dstOp = channelInOrder->getDstOp();
   auto *channel = channels.front();
   unsigned numBuffers = channel->numBuffers;
   Value buffer;

+  LLVM_DEBUG({
+    LDBG("Creating buffers for channel:");
+    LDBG("Producer:");
+    DBGS() << *srcOp << "\n";
+    LDBG("Consumer:");
+    DBGS() << *dstOp << "\n";
+  });
+
   // For TMEM channel, multi-buffer TMEM alloc
   if (channel->channelKind == DataChannelKind::TMEM) {
     // Move TMEM alloc to the beginning of the function.
@@ -745,8 +755,34 @@ DenseMap<Channel *, Value> createBuffer(

   // Get shape, layout and type of a slice
   auto sliceShape = tensorType.getShape();
-  auto sharedLayout = ttg::NVMMASharedEncodingAttr::get(
-      context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false);
+  // Check the consumer type
+  auto actualConsumers = getActualConsumers(dstOp);
+  LLVM_DEBUG({
+    DBGS() << "actual consumers: \n";
+    for (auto consumerOp : actualConsumers) {
+      DBGS() << *consumerOp << "\n";
+    }
+  });
+
+  bool requireMMASharedEncoding =
+      llvm::any_of(actualConsumers, [](Operation *op) {
+        return isa<mlir::triton::DotOpInterface>(op);
+      });
+
+  Attribute sharedLayout;
+  if (requireMMASharedEncoding) {
+    sharedLayout = ttg::NVMMASharedEncodingAttr::get(
+        context, sliceShape, order, CTALayout, elemType,
+        /*fp4Padded*/ false);
+  } else if (auto tmaLoad = dyn_cast<tt::DescriptorLoadOp>(srcOp)) {
+    sharedLayout = ttng::getEncodingFromDescriptor(
+        tmaLoad, tmaLoad.getType(), tmaLoad.getDesc());
+  } else {
+    // Create an unswizzled layout for now.
+    // TODO: optimize it based on the consumer.
+    sharedLayout = ttg::SwizzledSharedEncodingAttr::get(context, 1, 1, 1,
+                                                        order, CTALayout);
+  }

   // Get shape, layout and type of the complete buffer
   SmallVector<int64_t> bufferShape(sliceShape.begin(), sliceShape.end());

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp

Lines changed: 6 additions & 6 deletions
@@ -57,8 +57,8 @@ createAsyncCopy(const DenseMap<Channel *, Value> &bufferMap, Channel *c,

   // Get shape, layout and type of a slice
   auto sliceShape = tensorType.getShape();
-  auto sharedLayout = ttg::NVMMASharedEncodingAttr::get(
-      context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false);
+  auto sharedLayout =
+      dyn_cast<triton::gpu::MemDescType>(buffer.getType()).getEncoding();
   auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout);

   Attribute sharedMemorySpace =
@@ -118,8 +118,8 @@ createLocalCopy(const DenseMap<Channel *, Value> &bufferMap, Channel *channel,

   // Get shape, layout and type of a slice
   auto sliceShape = tensorType.getShape();
-  auto sharedLayout = ttg::NVMMASharedEncodingAttr::get(
-      context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false);
+  auto sharedLayout =
+      dyn_cast<triton::gpu::MemDescType>(buffer.getType()).getEncoding();
   auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout);

   Attribute sharedMemorySpace =
@@ -205,8 +205,8 @@ Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder,

   // Get shape, layout and type of a slice
   auto sliceShape = tensorType.getShape();
-  auto sharedLayout = ttg::NVMMASharedEncodingAttr::get(
-      context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false);
+  auto sharedLayout =
+      dyn_cast<triton::gpu::MemDescType>(buffer.getType()).getEncoding();
   auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout);

   Attribute sharedMemorySpace =
