Commit 33c0c1c

[AMD] Fix shared layout order for batch dimension in pipeline passes (#4796)
The batch dimension should be the slowest-varying one; other orderings are not supported by the MFMA/WMMA/MMA pipelines.
1 parent 5f77e8c commit 33c0c1c
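
For background on why the order matters: in a #triton_gpu.shared layout, `order` lists dimensions from fastest- to slowest-varying, so the dimension placed last gets the largest stride. The sketch below is illustrative plain C++, not Triton's API; the 64x8x32 shape is borrowed from the test added in this commit. With order = [2, 1, 0] the batch dimension (dim 0) is the slowest one, whereas order = [2, 0, 1] would make dim 1 slowest instead.

// Illustrative sketch only (plain C++, not Triton's API): compute the
// per-dimension strides implied by an `order` permutation, where order[0]
// is the fastest-varying dimension and the last entry is the slowest.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> stridesFor(const std::vector<int64_t> &shape,
                                const std::vector<unsigned> &order) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned dim : order) { // fastest dimension first
    strides[dim] = stride;
    stride *= shape[dim];
  }
  return strides;
}

int main() {
  // Shape of the B operand in the new test: batch = 64, K = 8, N = 32.
  std::vector<int64_t> shape = {64, 8, 32};

  // order = [2, 1, 0]: the batch dimension (0) gets the largest stride,
  // i.e. it is the slowest-varying one, as the pipelines require.
  auto good = stridesFor(shape, {2, 1, 0});
  assert(good[0] == 8 * 32);

  // order = [2, 0, 1]: dimension 1 would be slowest instead, which the
  // MFMA/WMMA/MMA pipelines do not support for a batched dot.
  auto bad = stridesFor(shape, {2, 0, 1});
  assert(bad[1] > bad[0]);

  std::printf("strides for order [2,1,0]: [%lld, %lld, %lld]\n",
              (long long)good[0], (long long)good[1], (long long)good[2]);
  return 0;
}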

3 files changed (+68, -4 lines)

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 35 additions & 0 deletions
@@ -198,3 +198,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 } // end module
+
+// -----
+
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0]
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+
+// CHECK-LABEL: tt.func public @slowest_dim_is_batch
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr<f32>, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr<f32>, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked>
+    %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2>
+    %cst_1 = arith.constant dense<128> : tensor<64x8x32xi32, #blocked1>
+    %c1_i32 = arith.constant 1 : i32
+    %c5_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>) : i32 {
+      %39 = tt.load %arg9 : tensor<1x512x!tt.ptr<f32>, #blocked2>
+      %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+      %41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5>
+      %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+      %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+      %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked>
+      %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<1x512xi32, #blocked2>
+      %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>, tensor<64x8x32xi32, #blocked1>
+      scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+    }
+    tt.store %arg2, %33#0 : tensor<64x1x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 18 additions & 3 deletions
@@ -403,9 +403,24 @@ void LoopPipeliner::createBufferTypes() {
   // unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
   //                     ? 32 / dotOpEnc.getMMAv2kWidth()
   //                     : ty.getElementType().getIntOrFloatBitWidth();
-  auto sharedEnc = ttg::SharedEncodingAttr::get(
-      ty.getContext(), dotOpEnc, ty.getShape(),
-      ttg::getOrder(ty.getEncoding()), CTALayout, eType);
+  auto srcOrder = ttg::getOrder(ty.getEncoding());
+  SmallVector<unsigned> sharedOrder;
+  int rank = srcOrder.size();
+  // TODO rework this when shared -> dotOp conversions support arbitrary
+  // shared memory ordering
+  if (rank == 3) {
+    // Move the batch dimension (dim #0) to be the last so that it will be the
+    // slowest varying dimension.
+    for (unsigned i = 0; i < rank; ++i)
+      if (srcOrder[i] != 0)
+        sharedOrder.emplace_back(srcOrder[i]);
+    sharedOrder.emplace_back(0);
+  } else {
+    sharedOrder = srcOrder;
+  }
+  auto sharedEnc =
+      ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
+                                   sharedOrder, CTALayout, eType);
   loadsBufferType[loadOp] = triton::MemDescType::get(
       bufferShape, eType, sharedEnc,
       triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()),
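
For reference, here is a self-contained sketch of the reordering this hunk introduces, using std::vector in place of llvm::SmallVector so it compiles standalone; the helper name toSharedOrder is hypothetical, not a function in the Triton codebase. For the rank-3 order [2, 0, 1] used by #blocked1 in the new test, the result is [2, 1, 0], which is what the CHECK lines in loop-pipeline-hip.mlir expect.

// Standalone illustration of the sharedOrder computation above; the helper
// name and the use of std::vector are assumptions made for this sketch only.
#include <cassert>
#include <vector>

std::vector<unsigned> toSharedOrder(const std::vector<unsigned> &srcOrder) {
  // Non-rank-3 orders pass through unchanged.
  if (srcOrder.size() != 3)
    return srcOrder;
  // Move the batch dimension (0) to the last, slowest-varying position.
  std::vector<unsigned> sharedOrder;
  for (unsigned dim : srcOrder)
    if (dim != 0)
      sharedOrder.push_back(dim);
  sharedOrder.push_back(0);
  return sharedOrder;
}

int main() {
  // The test's #blocked1 layout uses order = [2, 0, 1]; the shared
  // encoding built from it uses order = [2, 1, 0].
  assert((toSharedOrder({2, 0, 1}) == std::vector<unsigned>{2, 1, 0}));
  // Rank-2 orders are left untouched.
  assert((toSharedOrder({1, 0}) == std::vector<unsigned>{1, 0}));
  return 0;
}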

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp

Lines changed: 15 additions & 1 deletion
@@ -207,8 +207,22 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
   auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
   auto order = ttg::getOrder(srcTy.getEncoding());
   unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+  SmallVector<unsigned> sharedOrder;
+  int rank = order.size();
+  // TODO rework this when shared -> dotOp conversions support arbitrary
+  // shared memory ordering
+  if (rank == 3) {
+    // Move the batch dimension (dim #0) to be the last so that it will be
+    // the slowest varying dimension.
+    for (unsigned i = 0; i < rank; ++i)
+      if (order[i] != 0)
+        sharedOrder.emplace_back(order[i]);
+    sharedOrder.emplace_back(0);
+  } else {
+    sharedOrder = order;
+  }
   tempAttr = ttg::SharedEncodingAttr::get(
-      val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+      val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout,
       bitWidth, /*needTrans=*/false);
 }
 // Check that the shared encodings needed by the users are compatible.
