
Commit 220e4e1

Address comments
1 parent aa25956 commit 220e4e1

2 files changed: +12, -9 lines

test/TritonIntelGPU/optimize-reduction.mlir

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@
 
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 
 // CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 // CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
@@ -52,7 +52,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 
 // CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 // CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
@@ -106,7 +106,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
 // CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2], order = [3, 4, 0, 1, 2]}>
 // CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 2], order = [3, 0, 1, 2]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 
 // CHECK-LABEL: tt.func @test_two_warps_red(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> {
@@ -154,7 +154,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
 // CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 2], order = [3, 4, 0, 1, 2]}>
 // CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 2, 2], order = [3, 0, 1, 2]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 
 // CHECK-LABEL: tt.func @test_two_warps(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> {
@@ -235,7 +235,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 // CHECK: tt.func @test(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> {
 // CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<16x1x4x16x2x2x1xf32, #[[$ATTR_12]]>
@@ -315,7 +315,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 // CHECK: tt.func @test(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$DPAS]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> {
 // CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x64xf32, #[[$DPAS]]> -> tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]>
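Note on the reshape CHECK lines: the 7-D target shapes can be recomputed from the DPAS parameters. The standalone C++ sketch below (not part of this commit) does so for the 128x64 case above. shape[0..2] mirror the Y-axis computation visible in OptimizeReductionLocality.cpp further down; shape[3..6] are an assumption inferred from the CHECK lines rather than taken from the pass.

// Hypothetical sketch: recompute the 7-D reshape target for the 128x64
// test case. shape[3..6] (the X axis) are inferred, not from the pass.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // DPAS parameters from the last #mma in optimize-reduction.mlir.
  const int64_t repeatCount = 8, executionSize = 16;
  const std::array<int64_t, 2> repCluster{4, 2};
  const std::array<int64_t, 2> warpsPerCTA{4, 2};
  const std::array<int64_t, 2> oldShape{128, 64};

  const std::array<int64_t, 7> shape{
      // Y axis: contiguous per-thread elements, split for the transpose.
      executionSize,
      (repeatCount * repCluster[0]) / executionSize,
      // Y axis: rest.
      oldShape[0] / (repeatCount * repCluster[0]),
      // X axis (assumed): threads, rep cluster, warps, rest.
      executionSize,
      repCluster[1],
      warpsPerCTA[1],
      oldShape[1] / (executionSize * repCluster[1] * warpsPerCTA[1])};

  for (int64_t d : shape)
    std::cout << d << 'x'; // prints 16x2x4x16x2x2x1x
  std::cout << "f32\n";    // matching tensor<16x2x4x16x2x2x1xf32>
}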

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 6 additions & 3 deletions
@@ -87,7 +87,7 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// ```
 ///                                  warpsPerCTA[5]
 /// <------------------------------------------------------------------------------->
-///               size[4]
+///            getShape()[4]
 /// <---------------------------------->
 ///   threadsPerWarp[3]
 /// <---------------->
@@ -103,7 +103,7 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
 /// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
 /// ```
-/// So we can reduce on dimensions 4 and 2 to get to:
+/// So we can reduce on dimensions 6 and 4 to get to:
 /// ```
 ///                                  warpsPerCTA[3]
 /// <------------------------------------------------------------------------------->
@@ -257,10 +257,13 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
     constexpr size_t rank = 7;
     std::array<int64_t, rank> shape{
-        // Y axis
+        // Y axis contiguous elements handled by a single thread.
         oldEncoding.getExecutionSize(),
+        // Y axis contiguous elements handled by a single thread.
+        // Needs to be split from previous dimension to perform transpose.
         (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]) /
             oldEncoding.getExecutionSize(),
+        // Y axis rest.
         oldShape[0] /
             (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]),
         // X axis contiguous elements distributed within individual threads in a
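For context on the "6 and 4" doc-comment fix above: reducing those two dimensions collapses the 7-D view into the 5-D intermediate layout the comment goes on to describe. A minimal hedged illustration, assuming the reduction simply drops each reduced dimension (highest index first), using the 128x64 case's 7-D shape:

// Hypothetical illustration, not part of the commit.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> shape{16, 2, 4, 16, 2, 2, 1}; // 7-D view of 128x64
  shape.erase(shape.begin() + 6); // reduce dimension 6 (per the new comment)
  shape.erase(shape.begin() + 4); // then dimension 4
  for (int64_t d : shape)
    std::cout << d << ' '; // prints: 16 2 4 16 2
  std::cout << '\n';
}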
