Skip to content

Commit bc5471c

Browse files
committed
Fix layout order
1 parent a6a3f13 commit bc5471c

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

test/TritonIntelGPU/optimize-reduction.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
88

99
// CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
10-
// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}>
10+
// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}>
1111
// CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
12-
// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [2, 0, 1]}>
12+
// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [1, 2, 0]}>
1313

1414
// CHECK: tt.func @test_single(
1515
// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> {
@@ -53,9 +53,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
5353
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
5454

5555
// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
56-
// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}>
56+
// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}>
5757
// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}>
58-
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [2, 0, 1]}>
58+
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [1, 2, 0]}>
5959

6060
// CHECK: tt.func @test_single_twice(
6161
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> {
@@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
9797
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}>
9898

9999
// CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
100-
// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}>
100+
// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}>
101101
// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}>
102-
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [2, 0, 1]}>
102+
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [1, 2, 0]}>
103103

104104
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
105105

@@ -142,10 +142,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
142142

143143
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}>
144144

145-
// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}>
145+
// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}>
146146
// CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
147147
// CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
148-
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}>
148+
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}>
149149

150150
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
151151

@@ -219,9 +219,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
219219
// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension.
220220

221221
// CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}>
222-
// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}>
222+
// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}>
223223
// CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}>
224-
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [2, 0, 1]}>
224+
// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}>
225225

226226
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}>
227227

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
230230
std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0], 1,
231231
1, oldEncoding.getWarpsPerCTA()[1],
232232
1};
233-
std::array<unsigned, rank> order{4, 0, 1, 2, 3};
233+
std::array<unsigned, rank> order{1, 2, 3, 4, 0};
234234
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
235235

236236
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -247,7 +247,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
247247
Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
248248
int axis) const {
249249
assert(axis >= 0 && "Expecting positive axis");
250-
250+
251251
auto newOp = rewriter.create<ReduceOp>(op.getLoc(), val, /*axis=*/axis);
252252
auto &newCombineOp = newOp.getCombineOp();
253253
rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
@@ -286,7 +286,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
286286
1, 1};
287287
std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1,
288288
dpasEncoding.getWarpsPerCTA()[1]};
289-
std::array<unsigned, rank> order{2, 0, 1};
289+
std::array<unsigned, rank> order{1, 2, 0};
290290
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
291291

292292
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(

0 commit comments

Comments
 (0)