Skip to content

Commit 28d84c6

Browse files
committed
[AMD] Added a fix to elementwise refinement (#789)
- Added a lit-test to `elementwise.mlir` which tests refinement of `convert_layout` from one `mma` layout to another `mma` layout - Added a bug fix
1 parent 4fb6cc2 commit 28d84c6

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

test/TritonGPU/amd/ops-refinement/elementwise.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,42 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
125125
tt.return %3 : tensor<128x32xf32, #blocked>
126126
}
127127
}
128+
129+
// -----
130+
131+
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
132+
#mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
133+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
134+
tt.func public @convert_layout(%arg0: tensor<128x64xf16, #mma>) attributes {noinline = false} {
135+
// CHECK-LABEL: convert_layout
136+
137+
// CHECK: [[ES_0:%.*]] = amdgpu.extract_slice %arg0 [0, 0] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
138+
// CHECK: [[CL_0:%.*]] = ttg.convert_layout [[ES_0]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
139+
// CHECK: [[ES_1:%.*]] = amdgpu.extract_slice %arg0 [0, 16] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
140+
// CHECK: [[CL_1:%.*]] = ttg.convert_layout [[ES_1]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
141+
// CHECK: [[ES_2:%.*]] = amdgpu.extract_slice %arg0 [0, 32] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
142+
// CHECK: [[CL_2:%.*]] = ttg.convert_layout [[ES_2]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
143+
// CHECK: [[ES_3:%.*]] = amdgpu.extract_slice %arg0 [0, 48] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
144+
// CHECK: [[CL_3:%.*]] = ttg.convert_layout [[ES_3]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
145+
// CHECK: %8 = amdgpu.concat [[CL_0]], [[CL_1]], [[CL_2]], [[CL_3]] [1, 4] {loweringOrder = array<i64: 1, 0>} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
146+
147+
%0 = ttg.convert_layout %arg0 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
148+
amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
149+
tt.return
150+
}
151+
}
152+
153+
// -----
154+
155+
// blocked layout cta tile has size of whole tensor, no transformation should happen
156+
// CHECK-LABEL: @convert_layout_kernel_neg
157+
// CHECK-NOT: amdgpu.extract_slice
158+
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
159+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
160+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
161+
tt.func public @convert_layout_kernel_neg(%arg0: tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked2> attributes {noinline = false} {
162+
amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
163+
%0 = ttg.convert_layout %arg0 : tensor<128x32xf32, #blocked1> -> tensor<128x32xf32, #blocked2>
164+
tt.return %0 : tensor<128x32xf32, #blocked2>
165+
}
166+
}

third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -902,9 +902,10 @@ struct ElementWiseOpPattern : public RefineRewritePattern<OpTy> {
902902
SmallVector<int64_t> numReps;
903903
for (int i = 0; i < rank; ++i) {
904904
// src and res can have different refineable shapes if different layouts.
905-
refinedShape.push_back(
906-
std::max(srcShapePerCtaTile[i], resShapePerCtaTile[i]));
907-
numReps.push_back(srcShape[i] / srcShapePerCtaTile[i]);
905+
const auto refinedDim =
906+
std::max(srcShapePerCtaTile[i], resShapePerCtaTile[i]);
907+
refinedShape.push_back(refinedDim);
908+
numReps.push_back(srcShape[i] / refinedDim);
908909
}
909910

910911
if (product<int64_t>(numReps) == 1)

0 commit comments

Comments
 (0)