Skip to content

Commit e3093b4

Browse files
authored
[AMD] Fix OptimizeLDSUsagePass lit tests (#6271)
This PR fixes and re-enables two lit tests for the optimize-amd-lds-usage pass, and adds basic debug logging to the pass.
1 parent 88b6833 commit e3093b4

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

test/TritonGPU/amd/optimize-lds-usage.mlir

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,19 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
4444
}
4545
}
4646

47-
// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840
48-
// // -----
47+
// -----
4948

49+
// CHECK-DAG: [[$BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
50+
// CHECK-DAG: [[$BLOCKED2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 8], warpsPerCTA = [1, 2, 4], order = [1, 2, 0]}>
51+
// CHECK-DAG: [[$MMA:#.*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
5052
// Check that optimization works with 3d tensors
5153
// in case of relatively small scratch buffer
52-
// DISABLE-CHECK-LABEL: alloc_convert_3d_load
53-
// DISABLE-CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
54-
// DISABLE-CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
55-
// DISABLE-CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
56-
// DISABLE-CHECK: %2 = ttg.convert_layout %1 : {{.*}}#mma{{.*}}#mma1
57-
// DISABLE-CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
54+
// CHECK-LABEL: alloc_convert_3d_load
55+
// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
56+
// CHECK: [[V0:%.*]] = ttg.local_alloc {{.*}}[[$BLOCKED1]]{{.*}}
57+
// CHECK: [[V1:%.*]] = ttg.convert_layout {{.*}}[[$BLOCKED1]]{{.*}}[[$BLOCKED2]]
58+
// CHECK: [[V2:%.*]] = ttg.convert_layout [[V1]] : {{.*}}[[$BLOCKED2]]{{.*}}[[$MMA]]
59+
// CHECK: [[V3:%.*]] = ttg.local_load [[V0]] : {{.*}}#ttg.dot_op<{opIdx = 0, parent = [[$MMA]], kWidth = 4}>>
5860
#blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
5961
#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
6062
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}>
@@ -93,22 +95,21 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
9395
}
9496
}
9597

96-
// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840
9798
// -----
9899

99100
// Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion)
100-
// DISABLE-CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
101-
// DISABLE-CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
102-
// DISABLE-CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
103-
// DISABLE-CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
104-
// DISABLE-CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
101+
// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
102+
// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
103+
// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
104+
// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
105+
// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
105106

106-
// DISABLE-CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
107-
// DISABLE-CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
108-
// DISABLE-CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
109-
// DISABLE-CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
110-
// DISABLE-CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
111-
// DISABLE-CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
107+
// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
108+
// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
109+
// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
110+
// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
111+
// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
112+
// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
112113
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
113114
#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
114115
#mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
@@ -125,6 +126,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
125126
tt.return
126127
}
127128
}
129+
128130
// -----
129131

130132
// Checks that optimization do not crash on 1d tensor

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
3030
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
3131

32+
#define DEBUG_TYPE "optimize-amd-lds-usage"
33+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
34+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
35+
3236
using namespace mlir;
3337

3438
namespace mlir::triton {
@@ -86,6 +90,7 @@ class OptimizeAMDLDSUsage
8690
// times do not intersect, therefore this transformation lowers LDS
8791
// consumption.
8892
void tryFitCvtIntoLDS(triton::gpu::ConvertLayoutOp cvtOp, int targetLDSSize) {
93+
LDBG("Trying fit " << cvtOp << " into " << targetLDSSize << " bytes");
8994
OpBuilder builder(cvtOp);
9095

9196
auto srcType = cvtOp.getSrc().getType();
@@ -131,6 +136,7 @@ class OptimizeAMDLDSUsage
131136
SmallVector<Attribute> tmpLayouts;
132137
for (int i = 0; i < factorizedNumWarps.size(); i++) {
133138
auto warpsPerCTA = factorizedNumWarps[i];
139+
134140
auto pushNotNull = [&](Attribute enc) {
135141
if (enc)
136142
tmpLayouts.push_back(enc);
@@ -147,6 +153,8 @@ class OptimizeAMDLDSUsage
147153
for (int i = 0; i < tmpLayouts.size(); i++) {
148154
auto resources = mlir::triton::AMD::estimateResourcesForReplacement(
149155
builder, cvtOp, tmpLayouts[i]);
156+
LDBG("layout " << tmpLayouts[i] << " requires " << resources.LDS
157+
<< " bytes");
150158
// TODO analyze performance along with LDS consumption
151159
if (resources.LDS < minLDSUsage) {
152160
minLDSUsage = resources.LDS;

0 commit comments

Comments (0)