Skip to content

Commit 93367dc

Browse files
authored
[AMD] Add general swizzling support for allocation (#7606)
This PR removes the old assertion now that general swizzling for convert layout is supported in the AMD backend. Until recently, the AMD backend lacked support for general swizzling in ConvertLayoutOp operations. To handle this limitation, the code contained a blocking assertion that would terminate execution if the swizzling path was attempted. This wasn't immediately problematic because the system could fall back to an alternative approach that achieved the same functional result using `defaultAllocationAnalysisScratchSizeFn`. Now that general swizzling is operational, this assertion can be safely removed, allowing ConvertLayoutOp to use `getConvertLayoutScratchInBytes` directly rather than relying on external implementations.
1 parent 33462c8 commit 93367dc

File tree

4 files changed

+30
-28
lines changed

4 files changed

+30
-28
lines changed

include/triton/Analysis/Allocation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy);
6363
ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
6464
RankedTensorType dstTy);
6565

66+
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
67+
RankedTensorType dstTy);
68+
69+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
70+
RankedTensorType dstTy);
6671
} // namespace triton
6772

6873
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ static unsigned getBitwidth(RankedTensorType ty) {
3939
return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
4040
}
4141

42-
static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
43-
RankedTensorType dstTy) {
42+
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
43+
RankedTensorType dstTy) {
4444
auto *ctx = srcTy.getContext();
4545
auto srcLayout = gpu::toLinearLayout(srcTy);
4646
auto dstLayout = gpu::toLinearLayout(dstTy);
@@ -52,8 +52,8 @@ static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
5252
return smem.getTotalOutDimSize() / reps;
5353
}
5454

55-
static unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56-
RankedTensorType dstTy) {
55+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56+
RankedTensorType dstTy) {
5757
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
5858
return getNumScratchElements(scratchConfig.paddedRepShape);
5959
}

test/TritonGPU/amd/optimize-lds-usage.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
3636
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
3737
#smem = #ttg.shared_memory
3838
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
39-
tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
39+
tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<256x128xf16, #blocked>) attributes {noinline = false} {
4040
%1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
41-
%2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
41+
%2 = ttg.convert_layout %arg1 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #mma>
4242
%3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
4343
tt.return
4444
}
@@ -62,9 +62,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
6262
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}>
6363
#smem = #ttg.shared_memory
6464
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
65-
tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} {
65+
tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x256x128xf16, #blocked>) attributes {noinline = false} {
6666
%1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #smem>
67-
%2 = ttg.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma>
67+
%2 = ttg.convert_layout %arg1 : tensor<1x256x128xf16, #blocked> -> tensor<1x256x128xf16, #mma>
6868
%3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #smem> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
6969
tt.return
7070
}
@@ -87,9 +87,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
8787
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
8888
#smem = #ttg.shared_memory
8989
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
90-
tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} {
90+
tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
9191
%1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
92-
%2 = ttg.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma>
92+
%2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
9393
%3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>>
9494
tt.return
9595
}
@@ -98,29 +98,29 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
9898
// -----
9999

100100
// Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion)
101-
// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
101+
// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
102102
// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
103103
// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
104104
// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
105105
// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
106106

107107
// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
108108
// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
109-
// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
110-
// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
109+
// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<256x128xf32, [[BLOCKED_1]]> -> tensor<256x128xf32, [[BLOCKED_2]]>
110+
// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<256x128xf32, [[BLOCKED_2]]> -> tensor<256x128xf32, [[MMA_2]]>
111111
// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] {{.*}}: tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
112112
// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
113-
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
113+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
114114
#mma1 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
115115
#mma2 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
116116
#dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}>
117117
#dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}>
118118
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
119119
#smem = #ttg.shared_memory
120120
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
121-
tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} {
121+
tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<256x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} {
122122
%alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
123-
%convert_1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1>
123+
%convert_1 = ttg.convert_layout %arg1 : tensor<256x128xf32, #blocked> -> tensor<256x128xf32, #mma1>
124124
%convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2>
125125
%load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #dotop1>
126126
tt.return

third_party/amd/lib/Analysis/AMDGPUAllocation.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,22 @@ unsigned getConvertLayoutScratchInBytes(RankedTensorType srcTy,
1919
return 0;
2020
unsigned elems = 0;
2121
if (usePadding) {
22-
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
23-
elems = getNumScratchElements(scratchConfig.paddedRepShape);
22+
elems = getNumScratchElemsPaddedCvt(srcTy, dstTy);
2423
} else {
25-
assert(false && "General swizzling for convert layout is not suported in "
26-
"AMD backend yet");
27-
// TODO use swizzling
24+
elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
2825
}
2926
return elems * getBitwidth(srcTy) / 8;
3027
}
3128

3229
unsigned AMDAllocationAnalysisScratchSizeFn(Operation *op) {
33-
if (op->hasAttr(AttrSharedMemPadded)) {
34-
if (auto cvtLayout = dyn_cast<mlir::triton::gpu::ConvertLayoutOp>(op)) {
35-
auto srcTy = cvtLayout.getSrc().getType();
36-
auto dstTy = cvtLayout.getType();
37-
return getConvertLayoutScratchInBytes(srcTy, dstTy,
38-
op->hasAttr(AttrSharedMemPadded));
39-
}
30+
31+
if (auto cvtLayout = dyn_cast<mlir::triton::gpu::ConvertLayoutOp>(op)) {
32+
auto srcTy = cvtLayout.getSrc().getType();
33+
auto dstTy = cvtLayout.getType();
34+
return getConvertLayoutScratchInBytes(srcTy, dstTy,
35+
op->hasAttr(AttrSharedMemPadded));
4036
}
37+
4138
return defaultAllocationAnalysisScratchSizeFn(op);
4239
}
4340

0 commit comments

Comments
 (0)