From fdd05121ae1c7375f6d73a6d135cd86b01b75d14 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Mon, 21 Oct 2024 13:43:08 +0100 Subject: [PATCH] [OptRed] Extend `-tritonintelgpu-optimize-reduction-locality` to support `repCluster[0] > 2` Support `repCluster[0] > 2` by using 7-D tensors and adding a `convert_layout` operation before the final `reshape`. See code for implementation details. Signed-off-by: victor-eds --- test/TritonIntelGPU/optimize-reduction.mlir | 276 ++++++++++++------ .../TritonIntelGPU/Transforms/Passes.td | 54 ++-- .../OptimizeReductionLocality.cpp | 217 +++++++++----- 3 files changed, 356 insertions(+), 191 deletions(-) diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir index 79ff12e072..c6e069303b 100644 --- a/test/TritonIntelGPU/optimize-reduction.mlir +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -7,32 +7,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> -// CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}> -// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [1, 2, 0]}> +// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [3, 4, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}> // CHECK: tt.func @test_single( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x16x1x1x1xf32, #[[$ATTR_0]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x1x1x16x1x1x1xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x16x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : 
(tensor<16x1x1x16x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x1x1x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>>) -> tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -> tensor<16x16x1xf32, #[[$ATTR_3]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x16x1xf32, #[[$ATTR_3]]> -> tensor<16x16xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x1x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>>) -> tensor<16x1x1x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x1x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>}>> -> tensor<16x1x1x16x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x1x16x1xf32, #[[$ATTR_3]]> -> tensor<16x1x1x16xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<16x16xf32, #[[$ATTR_1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// CHECK: }) : (tensor<16x1x1x16xf32, #[[$ATTR_1]]>) -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_1]]}>> -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> // CHECK: } tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -53,32 +55,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, 
threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> -// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [1, 2, 0]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1], order = [3, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1], order = [3, 4, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 2, 1], order = [3, 0, 1, 2]}> // CHECK: tt.func @test_single_twice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<32x16x1x1x1xf32, #[[$ATTR_3]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<16x1x2x16x1x1x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<32x16x1x1x1xf32, #[[$ATTR_3]]>) -> tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : (tensor<16x1x2x16x1x1x1xf32, #[[$ATTR_3]]>) -> tensor<16x1x2x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>>) -> tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -> tensor<32x16x1xf32, #[[$BLOCKED]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<32x16x1xf32, #[[$BLOCKED]]> -> tensor<32x16xf32, #[[$ATTR_4]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x2x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>>) -> tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = 
#triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>}>> -> tensor<16x1x2x16x1xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x2x16x1xf32, #[[$BLOCKED]]> -> tensor<16x1x2x16xf32, #[[$ATTR_4]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<32x16xf32, #[[$ATTR_4]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> +// CHECK: }) : (tensor<16x1x2x16xf32, #[[$ATTR_4]]>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>> -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> // CHECK: } tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -97,34 +101,36 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}> // CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> -// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [1, 2, 0]}> +// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2], order = [3, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2], order = [3, 4, 0, 1, 2]}> +// CHECK-DAG: 
#[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 2], order = [3, 0, 1, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps_red( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x16x1x2x1xf32, #[[$ATTR_6]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x1x1x16x1x2x1xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x16x1x2x1xf32, #[[$ATTR_6]]>) -> tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : (tensor<16x1x1x16x1x2x1xf32, #[[$ATTR_6]]>) -> tensor<16x1x1x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>>) -> tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -> tensor<16x16x2xf32, #[[$BLOCKED]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x16x2xf32, #[[$BLOCKED]]> -> tensor<16x32xf32, #[[$ATTR_7]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x1x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>>) -> tensor<16x1x1x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x1x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>}>> -> tensor<16x1x1x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x1x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x1x32xf32, #[[$ATTR_7]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<16x32xf32, #[[$ATTR_7]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_7]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_7]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = 
#[[$ATTR_8]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> +// CHECK: }) : (tensor<16x1x1x32xf32, #[[$ATTR_7]]>) -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_7]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_7]]}>> -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> // CHECK: } tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -142,35 +148,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> -// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> -// CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 2], order = [3, 0, 1, 2]}> // CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 2], order = [3, 4, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 2, 2], order = [3, 0, 1, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<32x16x1x2x1xf32, #[[$ATTR_9]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<16x1x2x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, 
%[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<32x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : (tensor<16x1x2x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<16x1x2x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<32x16x2xf32, #[[$BLOCKED]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<32x16x2xf32, #[[$BLOCKED]]> -> tensor<32x32xf32, #[[$ATTR_10]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x2x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>>) -> tensor<16x1x2x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x2x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>> -> tensor<16x1x2x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x2x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x2x32xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<32x32xf32, #[[$ATTR_10]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: }) : (tensor<16x1x2x32xf32, #[[$ATTR_10]]>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>> -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> // CHECK: } tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ 
-183,26 +191,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tt.func @test_two_warps_twice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<64x32xf32, #[[$ATTR_11]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<64x16x1x2x1xf32, #[[$ATTR_9]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<16x1x4x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<64x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : (tensor<16x1x4x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<16x1x4x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<64x16x2xf32, #[[$BLOCKED]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<64x16x2xf32, #[[$BLOCKED]]> -> tensor<64x32xf32, #[[$ATTR_10]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x4x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>>) -> tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>> -> tensor<16x1x4x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x4x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x4x32xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<64x32xf32, #[[$ATTR_10]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: }) : (tensor<16x1x4x32xf32, #[[$ATTR_10]]>) -> tensor<16x1x4xf32, 
#triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>> -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> // CHECK: } tt.func @test_two_warps_twice(%arg0: tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -219,35 +228,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension. // CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}> -// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> -// CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}> +// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +// CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 2], order = [3, 4, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}> #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK: tt.func @test( // CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<16x1x4x16x2x2x1xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: 
tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<64x16x2x2x1xf32, #[[$ATTR_12]]>) -> tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : (tensor<16x1x4x16x2x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x1x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<64x16x2xf32, #[[$BLOCKED]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<64x16x2xf32, #[[$BLOCKED]]> -> tensor<64x32xf32, #[[$ATTR_13]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>>) -> tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>> -> tensor<16x1x4x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x4x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x4x32xf32, #[[$ATTR_13]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<64x32xf32, #[[$ATTR_13]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: }) : (tensor<16x1x4x32xf32, #[[$ATTR_13]]>) -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>> -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> // CHECK: } tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -260,26 +271,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 
: // CHECK: tt.func @test_repeat_layout( // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32, #[[$ATTR_14]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<128x16x2x2x2xf32, #[[$ATTR_12]]> -// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<16x1x8x16x2x2x2xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<128x16x2x2x2xf32, #[[$ATTR_12]]>) -> tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>> -// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: }) : (tensor<16x1x8x16x2x2x2xf32, #[[$ATTR_12]]>) -> tensor<16x1x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<128x16x2xf32, #[[$BLOCKED]]> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<128x16x2xf32, #[[$BLOCKED]]> -> tensor<128x32xf32, #[[$ATTR_13]]> -// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: }) : (tensor<16x1x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>>) -> tensor<16x1x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>> -> tensor<16x1x8x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x1x8x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x8x32xf32, #[[$ATTR_13]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<128x32xf32, #[[$ATTR_13]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: }) : (tensor<16x1x8x32xf32, #[[$ATTR_13]]>) -> tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>> +// CHECK: 
%[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>> -> tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> // CHECK: } tt.func @test_repeat_layout(%arg0: tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -290,3 +302,83 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.return %0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> } } + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension with repCluster[0] = 4. + +// CHECK-DAG: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2], A = [32, 8], B = [8, 32], C = [32, 32]}> +// CHECK-DAG: #[[$BLOCKED_EW:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED_RED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED_TRANS:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 2], order = [3, 4, 0, 1, 2]}> +// CHECK-DAG: #[[$BLOCKED_FINAL:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}> + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { +// CHECK: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$DPAS]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x64xf32, #[[$DPAS]]> -> tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]>) -> tensor<16x2x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x2x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>) -> tensor<16x2x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = 
#[[$BLOCKED_EW]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x2x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> -> tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]>) -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> +// CHECK: } + tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } + +// CHECK: tt.func @test_repeat_layout( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x64xf32, #[[$DPAS]]>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<256x64xf32, #[[$DPAS]]> -> tensor<16x2x8x16x2x2x1xf32, #[[$BLOCKED_EW]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x2x8x16x2x2x1xf32, #[[$BLOCKED_EW]]>) -> tensor<16x2x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x2x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>) -> tensor<16x2x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x2x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> -> tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> 
({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]>) -> tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> -> tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> +// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> +// CHECK: tt.return %[[VAL_16]] : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> +// CHECK: } + tt.func @test_repeat_layout(%arg0: tensor<256x64xf32, #mma>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<256x64xf32, #mma>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 3b738b880e..68d5bfde6f 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -293,43 +293,45 @@ def TritonIntelGPUOptimizeReductionLocality sub-group reductions are converted to `tt.reshape`, `tt.reduce`, and `triton_gpu.convert_layout` operations, e.g.: ```mlir -#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> -tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { - %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2]}> +tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %1 = arith.addf %arg1, %arg2 : f32 tt.reduce.return %1 : f32 - }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> } ``` Is converted to: ```mlir -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> -#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, 
threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [3, 4, 0, 1, 2]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}> tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { - %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked> - %1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({ + %0 = tt.reshape %arg0 {allow_reorder = true, efficient_layout} : tensor<32x32xf32, #mma> -> tensor<16x1x2x16x2x1x1xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 6 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): - %7 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %7 : f32 - }) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>> - %2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({ + %8 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %8 : f32 + }) : (tensor<16x1x2x16x2x1x1xf32, #blocked>) -> tensor<16x1x2x16x2x1xf32, #triton_gpu.slice<{dim = 6, parent = #blocked}>> + %2 = "tt.reduce"(%1) <{axis = 4 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): - %7 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %7 : f32 - }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> - %3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1> - %4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2> - %5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({ + %8 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %8 : f32 + }) : (tensor<16x1x2x16x2x1xf32, #triton_gpu.slice<{dim = 6, parent = #blocked}>>) -> tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #blocked}>}>> + %3 = triton_gpu.convert_layout %2 : tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #blocked}>}>> -> tensor<16x1x2x16x1xf32, #blocked1> + %4 = tt.reshape %3 {allow_reorder = true, efficient_layout} : tensor<16x1x2x16x1xf32, #blocked1> -> tensor<16x1x2x16xf32, #blocked2> + %5 = "tt.reduce"(%4) <{axis = 3 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): - %7 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %7 : f32 - }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - tt.return %6 : tensor<32xf32, 
#triton_gpu.slice<{dim = 1, parent = #mma}>> + %8 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %8 : f32 + }) : (tensor<16x1x2x16xf32, #blocked2>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked2}>> + %6 = triton_gpu.convert_layout %5 : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked2}>> -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked3}>> + %7 = tt.reshape %6 {allow_reorder = true, efficient_layout} : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked3}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %7 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> } ``` The `tt.reshape` operation is a NOP so that the following `tt.reduce` diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp index 131e85c2d3..d0901ca505 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -69,65 +69,88 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc, /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | /// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | /// ``` - /// Blocked (#triton_gpu.blocked<{sizePerThread = [repCluster[0]*repeatCount, 1, 1, 1, 1], threadsPerWarp = [1, executionSize, 1, 1, 1], warpsPerCTA = [warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [1, 2, 3, 4, 0]}>): + /// Blocked (#triton_gpu.blocked<{sizePerThread = [executionSize, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [3, 4, 5, 6, 0, 1, 2]}>): /// ``` - /// warpsPerCTA[3] - /// <-------------------------------------------------------------------------------> - /// size[2] - /// <----------------------------------> - /// threadsPerWarp[1] - /// <----------------> - /// ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^ - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[0] - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | - /// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// warpsPerCTA[5] + /// <-------------------------------------------------------------------------------> + /// size[4] + /// <----------------------------------> + /// threadsPerWarp[3] + /// <----------------> + /// ^ ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^ + /// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... 
tnn  |
+    ///          |                    |  t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  |
+    ///          |                    v  t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  |
+    ///          |                       ..................................................................................|
+    ///          |                       t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  | warpsPerCTA[2]
+    ///          |                       t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  |
+    ///  size[1] |                       t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  |
+    ///          |                       t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  |
+    ///          v                       t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn  |
     /// ```
-    /// So we can reduce on dimensions 4 and 2 to get to:
+    /// So we can reduce on dimensions 6 and 4 to get to:
     /// ```
-    ///                     warpsPerCTA[2]
-    ///                   <------------------------------------>
-    ///                     threadsPerWarp[1]
-    ///                   <------------------>
-    ///                 ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[0]
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    /// sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    ///                 | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
-    ///                 v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///                               warpsPerCTA[3]
+    ///          <------------------------------------------------------------------------------->
+    ///                              threadsPerWarp[3]
+    ///                             <---------------->
+    ///  ^                    ^  t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
+    ///  |                    |  t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  |   sizePerThread[0] |  t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  |                    |  t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  |                    v  t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  |                       .......................................|
+    ///  |                       t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[2]
+    ///  |                       t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  size[1] |               t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  |                       t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+    ///  v                       t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
     /// ```
     /// After reshaping and layout conversion, we can get to the actual layout
     /// optimization we wanted to achieve:
-    /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, repCluster[0]*repeatCount], threadsPerWarp = [executionSize, 1], warpsPerCTA = [warpsPerCTA[0], warpsPerCTA[1]], order = [1, 0]}>):
+    /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>):
     /// ```
-    ///                     warpsPerCTA[1]
+    ///                     warpsPerCTA[3]
     ///                   <------------------------------------>
-    ///                     sizePerThread[1]
+    ///                     sizePerThread[3]
     ///                   <------------------>
     ///                 ^ t0 t0 t0 t0 ... t0 tn1 tn1 tn1 ... tn1 ^
     ///                 | t1 t1 t1 t1 ... t1 tn2 tn2 tn2 ... tn2 |
-    /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[0]
+    /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[2]
     ///                 | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 |
     /// ```
-    /// And reducing on dimension 1 and converting the layout to the original one
-    /// leads to the same output as the original operation.
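+    /// As a concrete trace (mirroring the `repCluster = [4, 2]` tests added in
+    /// this patch; the shapes are illustrative): a 128x64 input with
+    /// repeatCount = 8, executionSize = 16, and warpsPerCTA = [4, 2] is
+    /// reshaped to 16x2x4x16x2x2x1, reduced over dimensions 6 and 4 down to
+    /// 16x2x4x16x2, and then converted and reshaped to 16x2x4x32.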
+/// After reducing on dimension 3, we'd get:
+/// ```
+///                   ^ t0 ^
+///                   | t1 |
+/// threadsPerWarp[0] | t2 | warpsPerCTA[2]
+///                   | t3 |
+/// ```
+/// Reshaping from this layout to the final state would not work, as we would
+/// end up modifying the number of elements per work-item (not allowed in
+/// `reshape`).
+///
+/// In order to avoid that, we can just convert the layout to a sliced layout
+/// equivalent to the end product we want to achieve:
+/// Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>)
+/// Sliced (#triton_gpu.sliced<{dim = 3, parent = #blocked}>)
+/// ```
+///                   ^ t0 ^
+///                   | t0 |
+/// threadsPerWarp[0] | t0 | warpsPerCTA[2]
+///                   | t0 |
+/// ```
+/// And just reshape to the final type using a NOP `reshape`.
 // clang-format on
 struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;
 
   static constexpr int preferredNonReductionAxis = 0;
+  static constexpr int finalReductionAxis = 3;
   static constexpr int preferredReductionAxis = 1;
-  static constexpr int repCountReshapedAxis = 2;
-  static constexpr int withinWarpXAxisReshapedAxis = 4;
+  static constexpr int repCountReshapedAxis = 4;
+  static constexpr int withinWarpXAxisReshapedAxis = 6;
 
   LogicalResult matchAndRewrite(ReduceOp op,
                                 PatternRewriter &rewriter) const final {
@@ -148,12 +171,14 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     if (op.getAxis() != preferredReductionAxis)
       return failure();
 
-    // We want to transpose matrices of (threads_per_warp)^2 shape for now.
+    // We want to transpose matrices of (N * threads_per_warp) x
+    // threads_per_warp shape.
     if ( // X axis condition
         encoding.getExecutionSize() != encoding.getSubGroupSize() ||
-        // Y axis condition
-        encoding.getRepeatCount() * encoding.getRepCluster()[0] !=
-            encoding.getSubGroupSize())
+        // Y axis conditions
+        (encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
+                encoding.getSubGroupSize() !=
+            0)
       return failure();
 
     LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");
@@ -190,7 +215,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     LLVM_DEBUG(llvm::dbgs()
                << "Final reduction performed: " << operand << "\n");
 
-    operand = convertToOriginalType(op, rewriter, operand);
+    operand = convertLayoutToOriginalType(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Converted layout to original type: " << operand << "\n");
+
+    operand = reshapeToOriginalType(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Reshaped to original type: " << operand << "\n");
 
     rewriter.replaceOp(op, operand);
 
@@ -207,10 +240,14 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     ArrayRef<int64_t> oldShape = oldType.getShape();
     auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());
 
-    constexpr size_t rank = 5;
+    constexpr size_t rank = 7;
     std::array<int64_t, rank> shape{
         // Y axis
-        oldShape[0],
+        oldEncoding.getExecutionSize(),
+        (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]) /
+            oldEncoding.getExecutionSize(),
+        oldShape[0] /
+            (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]),
         // X axis contiguous elements distributed within individual threads in a
         // warp.
         oldEncoding.getExecutionSize(),
@@ -222,15 +259,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
         oldShape[1] / (oldEncoding.getExecutionSize() *
                        oldEncoding.getRepCluster()[1] *
                        oldEncoding.getWarpsPerCTA()[1])};
-    std::array<unsigned, rank> sizePerThread{oldEncoding.getRepeatCount() *
-                                                 oldEncoding.getRepCluster()[0],
-                                             1, 1, 1, 1};
-    std::array<unsigned, rank> threadsPerWarp{1, oldEncoding.getExecutionSize(),
-                                              1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0], 1,
-                                           1, oldEncoding.getWarpsPerCTA()[1],
-                                           1};
-    std::array<unsigned, rank> order{1, 2, 3, 4, 0};
+    std::array<unsigned, rank> sizePerThread{
+        oldEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1};
+    std::array<unsigned, rank> threadsPerWarp{
+        1, 1, 1, oldEncoding.getExecutionSize(), 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{
+        1, 1, oldEncoding.getWarpsPerCTA()[0],
+        1, 1, oldEncoding.getWarpsPerCTA()[1],
+        1};
+    std::array<unsigned, rank> order{3, 4, 5, 6, 0, 1, 2};
     CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -278,15 +315,16 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
         cast<RankedTensorType>(op.getOperands().front().getType())
             .getEncoding());
 
-    constexpr size_t rank = 3;
+    constexpr size_t rank = 5;
     ArrayRef<int64_t> shape = oldType.getShape();
-    std::array<unsigned, rank> sizePerThread{1, dpasEncoding.getExecutionSize(),
-                                             1};
+    std::array<unsigned, rank> sizePerThread{
+        1, 1, 1, dpasEncoding.getExecutionSize(), 1};
     std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
-                                              1, 1};
+                                              1, 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1,
+    std::array<unsigned, rank> warpsPerCTA{1, 1,
+                                           dpasEncoding.getWarpsPerCTA()[0], 1,
                                            dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rank> order{1, 2, 0};
+    std::array<unsigned, rank> order{3, 4, 0, 1, 2};
     CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -304,15 +342,16 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     ArrayRef<int64_t> oldShape = oldType.getShape();
     auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
 
-    constexpr size_t rank = 2;
-    std::array<int64_t, rank> shape{oldShape[0], oldShape[1] * oldShape[2]};
-    std::array<unsigned, rank> sizePerThread{1,
-                                             oldEncoding.getSizePerThread()[1]};
+    constexpr size_t rank = 4;
+    std::array<int64_t, rank> shape{oldShape[0], oldShape[1], oldShape[2],
+                                    oldShape[3] * oldShape[4]};
+    std::array<unsigned, rank> sizePerThread{1, 1, 1,
+                                             oldEncoding.getSizePerThread()[3]};
     std::array<unsigned, rank> threadsPerWarp{
-        oldEncoding.getThreadsPerWarp()[0], 1};
-    std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0],
-                                           oldEncoding.getWarpsPerCTA()[2]};
-    std::array<unsigned, rank> order{1, 0};
+        oldEncoding.getThreadsPerWarp()[0], 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{
+        1, 1, oldEncoding.getWarpsPerCTA()[2], oldEncoding.getWarpsPerCTA()[4]};
+    std::array<unsigned, rank> order{3, 0, 1, 2};
     CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -328,13 +367,45 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   }
 
   Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,
                               Value val) const {
-    return performReduction(op, rewriter, val, /*axis=*/preferredReductionAxis);
+    return performReduction(op, rewriter, val, /*axis=*/finalReductionAxis);
+  }
+
+  Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
+                                    Value val) const {
+    auto oldType = cast<RankedTensorType>(val.getType());
+    auto dpasEncoding = cast<DpasEncodingAttr>(
+        cast<RankedTensorType>(op.getOperands().front().getType())
+            .getEncoding());
+
+    // Only Y axis (X axis has already been reduced)
+    constexpr size_t rankBeforeLastReduction = 4;
+    ArrayRef<int64_t> shape = oldType.getShape();
+    std::array<unsigned, rankBeforeLastReduction> sizePerThread{
+        dpasEncoding.getExecutionSize(), 1, 1, 1};
+    std::array<unsigned, rankBeforeLastReduction> threadsPerWarp{
+        1, 1, 1, dpasEncoding.getExecutionSize()};
+    std::array<unsigned, rankBeforeLastReduction> warpsPerCTA{
+        1, 1, dpasEncoding.getWarpsPerCTA()[0],
+        dpasEncoding.getWarpsPerCTA()[1]};
+    std::array<unsigned, rankBeforeLastReduction> order{3, 0, 1, 2};
+    CTALayoutAttr ctaLayout =
+        getIdentityCTALayoutAttr(rewriter, rankBeforeLastReduction);
+
+    auto blockedEncoding = rewriter.getAttr<BlockedEncodingAttr>(
+        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+    auto encoding = rewriter.getAttr<SliceEncodingAttr>(finalReductionAxis,
+                                                        blockedEncoding);
+
+    RankedTensorType type =
+        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+
+    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
   }
 
-  Value convertToOriginalType(ReduceOp op, PatternRewriter &rewriter,
+  Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter,
                               Value val) const {
-    return rewriter.create<ConvertLayoutOp>(
-        op.getLoc(), op.getResult().front().getType(), val);
+    return createReshapeForReduction(rewriter, op.getLoc(),
+                                     op.getResult().front().getType(), val);
  }
 };
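
For intuition, a standalone C++ sketch (not part of the patch; plain integers stand in for the `DpasEncodingAttr` getters, which is an assumption of this illustration) of the relaxed Y-axis condition: `repeatCount * repCluster[0]` now only has to be a multiple of the sub-group size rather than exactly equal to it, which is what admits `repCluster[0] > 2`:

```cpp
#include <cstdio>

// Hypothetical stand-ins for the DpasEncodingAttr getters; the values mirror
// the DPAS configurations used in the tests (repeatCount = 8, size 16).
struct DpasParams {
  int repeatCount;
  int repClusterY; // repCluster[0]
  int executionSize;
  int subGroupSize;
};

// Old check: repeatCount * repCluster[0] must equal the sub-group size.
bool oldApplicable(const DpasParams &p) {
  return p.executionSize == p.subGroupSize &&
         p.repeatCount * p.repClusterY == p.subGroupSize;
}

// New check: repeatCount * repCluster[0] must be a multiple of the
// sub-group size, so repCluster[0] > 2 also qualifies.
bool newApplicable(const DpasParams &p) {
  return p.executionSize == p.subGroupSize &&
         (p.repeatCount * p.repClusterY) % p.subGroupSize == 0;
}

int main() {
  for (int repClusterY : {1, 2, 4, 8}) {
    DpasParams p{/*repeatCount=*/8, repClusterY, /*executionSize=*/16,
                 /*subGroupSize=*/16};
    std::printf("repCluster[0]=%d old=%d new=%d\n", repClusterY,
                oldApplicable(p), newApplicable(p));
  }
  // Prints: repCluster[0]=1 old=0 new=0
  //         repCluster[0]=2 old=1 new=1
  //         repCluster[0]=4 old=0 new=1
  //         repCluster[0]=8 old=0 new=1
}
```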
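
The 7-D shape built in `reshapeForElementWiseReduction` decomposes the Y axis into three dimensions (sub-group rows, Y repetitions inside one rep cluster block, and everything else along Y) and the X axis into four. A sketch of that arithmetic, again with assumed plain-integer parameters; the 64x128 tensor and the `repCluster = [4, 2]`, `warpsPerCTA = [2, 2]` configuration are hypothetical values chosen so that `repCluster[0] > 2`:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Mirrors the shape computation in reshapeForElementWiseReduction, with the
// DpasEncodingAttr getters replaced by plain integers (assumed values).
std::array<int64_t, 7> reshapedShape(int64_t rows, int64_t cols,
                                     int64_t repeatCount, int64_t repClusterY,
                                     int64_t repClusterX, int64_t executionSize,
                                     int64_t warpsY, int64_t warpsX) {
  (void)warpsY; // The Y warp count stays folded into dimension 2.
  return {
      // Y axis: sub-group rows, row repetitions within one rep cluster
      // block, and everything else along Y (warps and further repetitions).
      executionSize,
      (repeatCount * repClusterY) / executionSize,
      rows / (repeatCount * repClusterY),
      // X axis: contiguous per-thread elements, rep cluster columns,
      // X warps, and the remaining X repetitions.
      executionSize,
      repClusterX,
      warpsX,
      cols / (executionSize * repClusterX * warpsX)};
}

int main() {
  // Hypothetical DPAS config with repCluster[0] = 4 > 2, the case this patch
  // enables: repeatCount=8, repCluster=[4,2], executionSize=16,
  // warpsPerCTA=[2,2], operating on a 64x128 tensor.
  auto s = reshapedShape(64, 128, 8, 4, 2, 16, 2, 2);
  assert((s == std::array<int64_t, 7>{16, 2, 2, 16, 2, 2, 2}));
  // The reshape is a NOP: the total element count is preserved.
  int64_t n = 1;
  for (int64_t d : s)
    n *= d;
  assert(n == 64 * 128);
}
```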
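
Replaying the whole rewrite on the same hypothetical configuration shows how the ranks and reduction axes line up end to end (reshape to 7-D, reduce axes 6 and 4, convert and reshape to 4-D, reduce axis 3, convert to the sliced layout, and a final NOP reshape back to 1-D):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Drop one dimension, as tt.reduce does for its reduction axis.
std::vector<int64_t> reduceAxis(std::vector<int64_t> shape, int axis) {
  shape.erase(shape.begin() + axis);
  return shape;
}

int main() {
  // Same hypothetical config as above: 64x128 tensor, repeatCount=8,
  // repCluster=[4,2], executionSize=16, warpsPerCTA=[2,2].
  // 1. Reshape to 7-D (a NOP in terms of data movement).
  std::vector<int64_t> shape{16, 2, 2, 16, 2, 2, 2};

  // 2. Reduce withinWarpXAxisReshapedAxis = 6, then repCountReshapedAxis = 4.
  shape = reduceAxis(shape, 6); // {16, 2, 2, 16, 2, 2}
  shape = reduceAxis(shape, 4); // {16, 2, 2, 16, 2}

  // 3. convert_layout (shape unchanged), then reshape 5-D -> 4-D by merging
  //    the two trailing X dimensions (16 * 2 = 32).
  shape = {shape[0], shape[1], shape[2], shape[3] * shape[4]};
  assert((shape == std::vector<int64_t>{16, 2, 2, 32}));

  // 4. Reduce finalReductionAxis = 3: the X axis is now fully reduced.
  shape = reduceAxis(shape, 3); // {16, 2, 2}

  // 5. convert_layout to the sliced layout (shape unchanged), then the final
  //    NOP reshape folds the three Y dimensions back into one: 16*2*2 = 64,
  //    the row count of the original operand.
  int64_t rows = shape[0] * shape[1] * shape[2];
  assert(rows == 64);
}
```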