From 7a268e6a941209358ef7d981773d02523b7d5a69 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Tue, 15 Oct 2024 13:02:57 +0100 Subject: [PATCH 1/9] [OptRed] Define `-triton-intelgpu-optimize-reduction` pass Add pass splitting reductions and performing layout conversions to avoid using sub-group operations. Signed-off-by: victor-eds --- test/TritonIntelGPU/optimize-reduction.mlir | 252 +++++++++++++ .../TritonIntelGPU/Transforms/Passes.td | 62 +++ .../TritonIntelGPUTransforms/CMakeLists.txt | 1 + .../OptimizeReductionLocality.cpp | 353 ++++++++++++++++++ 4 files changed, 668 insertions(+) create mode 100644 test/TritonIntelGPU/optimize-reduction.mlir create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir new file mode 100644 index 0000000000..59ca572aae --- /dev/null +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -0,0 +1,252 @@ +// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality -canonicalize | FileCheck %s + +// Test reduction in a single warp (16x16->16). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + +// CHECK-DAG: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}> + +// CHECK: tt.func @test_single( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[DPAS]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x16xf32, #[[DPAS]]> -> tensor<16x16x1x1x1xf32, #[[BLOCKED0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x16x1x1x1xf32, #[[BLOCKED0]]>) -> tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<16x16xf32, #[[BLOCKED1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// 
CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<16x16xf32, #[[BLOCKED1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: } + tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the non-reduction dimension (32x16->32). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [1, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + +// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}> + +// CHECK: tt.func @test_single_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<32x16x1x1x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<32x16x1x1x1xf32, #[[$ATTR_3]]>) -> tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>>) -> tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -> tensor<32x16xf32, #[[$ATTR_4]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> 
({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<32x16xf32, #[[$ATTR_4]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> +// CHECK: } + tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x16xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the reduction dimension (16x32->16). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [1, 1]}> + +// CHECK-DAG: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + +// CHECK: tt.func @test_two_warps( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[DPAS]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x32xf32, #[[DPAS]]> -> tensor<16x16x1x2x1xf32, #[[BLOCKED0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x16x1x2x1xf32, #[[BLOCKED0]]>) -> tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<16x32xf32, #[[BLOCKED1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : 
i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<16x32xf32, #[[BLOCKED1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: } + tt.func @test_two_warps(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across both dimensions (32x32->32). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> + +// CHECK-DAG: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + +// CHECK: tt.func @test_two_warps_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[DPAS]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<32x32xf32, #[[DPAS]]> -> tensor<32x16x1x2x1xf32, #[[BLOCKED0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<32x16x1x2x1xf32, #[[BLOCKED0]]>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<32x32xf32, #[[BLOCKED1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : 
i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<32x32xf32, #[[BLOCKED1]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: } + tt.func @test_two_warps_twice(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension with different layout in reduction dimension. + +#mma0 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1]}> +#mma1 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 2]}> + +// CHECK-DAG: #[[DPAS0:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +// CHECK-DAG: #[[DPAS1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 2], A = [8, 8], B = [8, 32], C = [8, 32]}> +// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + +// CHECK: tt.func @test_0( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[DPAS0]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS0]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x64xf32, #[[DPAS0]]> -> tensor<64x16x1x2x2xf32, #[[BLOCKED0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<64x16x1x2x2xf32, #[[BLOCKED0]]>) -> tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : 
(tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<64x32xf32, #[[BLOCKED1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<64x32xf32, #[[BLOCKED1]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS0]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS0]]}>> +// CHECK: } + tt.func @test_0(%arg0: tensor<64x64xf32, #mma0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x64xf32, #mma0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + } + +// CHECK: tt.func @test_1( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[DPAS1]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS1]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x64xf32, #[[DPAS1]]> -> tensor<64x16x2x2x1xf32, #[[BLOCKED0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<64x16x2x2x1xf32, #[[BLOCKED0]]>) -> tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<64x32xf32, #[[BLOCKED1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<64x32xf32, #[[BLOCKED1]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS1]]}>> +// CHECK: tt.return %[[VAL_15]] : 
tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS1]]}>> +// CHECK: } + tt.func @test_1(%arg0: tensor<64x64xf32, #mma1>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x64xf32, #mma1>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + } +} diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 42e386fe29..996cf0d2ec 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -284,4 +284,66 @@ def TritonIntelGPUMaterializeBlockPointer : Pass<"tritonintelgpu-materialize-blo "mlir::arith::ArithDialect"]; } +def TritonIntelGPUOptimizeReductionLocality + : Pass<"tritonintelgpu-optimize-reduction-locality", "mlir::ModuleOp"> { + let summary = "Minimize number of reductions within sub-groups"; + + let description = [{ + This pass performs layout conversions so `tt.reduce` operations resulting in + sub-group reductions are converted to `tt.reshape`, `tt.reduce`, and + `triton_gpu.convert_layout` operations, e.g.: + ```mlir +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + tt.func @test.work(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + } +} + ``` + Is converted to: + ```mlir +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 2, 1], order = [0, 2, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 2, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + tt.func @test.work(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { + %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x32x1xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 2 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %2 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %2 : f32 + }) : (tensor<32x32x1xf32, #blocked>) -> tensor<32x32xf32, #triton_gpu.slice<{dim = 2, 
parent = #blocked}>>
+    %3 = triton_gpu.convert_layout %1 : tensor<32x32xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<32x32xf32, #blocked2>
+    %4 = "tt.reduce"(%3) <{axis = 0 : i32}> ({
+    ^bb0(%arg3: f32, %arg4: f32):
+      %5 = arith.maxnumf %arg3, %arg4 : f32
+      tt.reduce.return %5 : f32
+    }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+    %6 = triton_gpu.convert_layout %4 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
+    tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
+  }
+}
+  ```
+  The `tt.reshape` operation is a NOP so that the following `tt.reduce`
+  operation performs a reduction within the work-item. Then,
+  `triton_gpu.convert_layout` performs the actual data movement needed to
+  avoid sub-group reductions, so that the following `tt.reduce` again
+  performs a reduction within the work-item. Finally, we convert back to the
+  expected layout.
+
+  Note this pass only supports `triton_intel_gpu.dpas` input layouts at the
+  moment, but it should be easy to extend.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect",
+                           "mlir::triton::gpu::TritonGPUDialect"];
+}
+
 #endif // TRITON_INTEL_GPU_PASSES
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
index 8c2e290ada..d215107a53 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_triton_library(TritonIntelGPUTransforms
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp
+  OptimizeReductionLocality.cpp
   Pipeliner/MatmulLoopPipeline.cpp
   Pipeliner/SoftwarePipeliner.cpp
   PrefetchBlock.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
new file mode 100644
index 0000000000..73edb3a378
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -0,0 +1,353 @@
+//===- OptimizeReductionLocality.cpp ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// This file implements the `tritonintelgpu-optimize-reduction-locality` pass.
+//===----------------------------------------------------------------------===//
+
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#define DEBUG_TYPE "tritonintelgpu-optimize-reduction-locality"
+
+namespace mlir::triton::gpu::intel {
+#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEREDUCTIONLOCALITY
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
+
+namespace {
+// clang-format off
+  /// Optimize reduction with DPAS-encoded input.
+  ///
+  /// This optimization reshapes and converts input tensor layouts to split the
+  /// reduction into three equivalent ones.
+  ///
+  /// This only works if the number of items for a given thread across dimension
+  /// 0 and the execution size are both equal to the sub-group size.
+  ///
+  /// First, we go from a DPAS layout to an equivalent blocked layout as follows:
+  ///
+  /// DPAS:
+  /// ```
+  ///                                                                 warpsPerCTA[1]
+  ///                                   <-------------------------------------------------------------------------------->
+  ///                                                    repCluster[1]
+  ///                                   <----------------------------------->
+  ///                                     execution size
+  ///                                   <---------------->
+  ///                   ^             ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
+  ///                   |             | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                   | repeatCount | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                   |             | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                   |             v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                   |               t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[0]
+  ///                   |               t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///     repCluster[0] |               t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                   |               t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                   v               t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  /// ```
+  /// Blocked (#triton_gpu.blocked<{sizePerThread = [repCluster[0]*repeatCount, 1, 1, 1, 1], threadsPerWarp = [1, executionSize, 1, 1, 1], warpsPerCTA = [warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [4, 0, 1, 2, 3]}>):
+  /// ```
+  ///                                                       warpsPerCTA[3]
+  ///                      <------------------------------------------------------------------------------->
+  ///                                       size[2]
+  ///                      <---------------------------------->
+  ///                       threadsPerWarp[1]
+  ///                      <---------------->
+  ///                    ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[0]
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///   sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  ///                    v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+  /// ```
+  /// So we can reduce on dimensions 4 and 2 to get to:
+  /// ```
+  ///                                   warpsPerCTA[2]
+  ///                      <------------------------------------>
+  ///                       threadsPerWarp[1]
+  ///                      <------------------>
+  ///                    ^ t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn ^
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn | warpsPerCTA[0]
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///   sizePerThread[0] | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///                    | t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  ///                    v t0 t1 t2 t3 ... tn  tn1 tn2 tn3 ... tnn |
+  /// ```
+  /// After reshaping and layout conversion, we can get to the actual layout
+  /// optimization we wanted to achieve:
+  /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, repCluster[0]*repeatCount], threadsPerWarp = [executionSize, 1], warpsPerCTA = [warpsPerCTA[0], warpsPerCTA[1]], order = [1, 0]}>):
+  /// ```
+  ///                                         warpsPerCTA[1]
+  ///                      <------------------------------------>
+  ///                        sizePerThread[1]
+  ///                      <------------------>
+  ///                    ^ t0 t0 t0 t0 ... t0  tn1 tn1 tn1 ... tn1 ^
+  ///                    | t1 t1 t1 t1 ... t1  tn2 tn2 tn2 ... tn2 |
+  ///   sizePerThread[0] | t2 t2 t2 t2 ... t2  tn3 tn3 tn3 ... tn3 | warpsPerCTA[0]
+  ///                    | t3 t3 t3 t3 ... t3  tn4 tn4 tn4 ... tn4 |
+  /// ```
+  /// And reducing on dimension 1 and converting the layout to the original one
+  /// leads to the same output as the original operation.
+// clang-format on
+struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
+  using OpRewritePattern<ReduceOp>::OpRewritePattern;
+
+  static constexpr int preferredNonReductionAxis = 0;
+  static constexpr int preferredReductionAxis = 1;
+  static constexpr int repCountReshapedAxis = 2;
+  static constexpr int withinWarpXAxisReshapedAxis = 4;
+
+  LogicalResult matchAndRewrite(ReduceOp op,
+                                PatternRewriter &rewriter) const final {
+    ValueRange operands = op.getOperands();
+    // Allow a single operand for now.
+    if (operands.size() != 1)
+      return failure();
+    // Check this has a `triton_intel_gpu.dpas` encoding.
+    Value operand = operands.front();
+    auto type = cast<RankedTensorType>(operand.getType());
+    auto encoding =
+        llvm::dyn_cast_or_null<DpasEncodingAttr>(type.getEncoding());
+    if (!encoding)
+      return failure();
+
+    // Axis 1 will lead to within-warp reduction.
+    assert(type.getRank() == 2 && "Expecting 2D tensor");
+    if (op.getAxis() != preferredReductionAxis)
+      return failure();
+
+    // We want to transpose matrices of (threads_per_warp)^2 shape for now.
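+    // E.g., a 32x16 tensor with warpsPerCTA = [2, 1] qualifies: each warp owns
+    // a 16x16, i.e., (sub-group size)^2, tile (see the accompanying lit tests).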
+    if ( // X axis condition
+        encoding.getExecutionSize() != encoding.getSubGroupSize() ||
+        // Y axis condition
+        type.getShape()[0] / encoding.getWarpsPerCTA()[0] !=
+            encoding.getSubGroupSize())
+      return failure();
+
+    LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");
+
+    operand = reshapeForElementWiseReduction(op, rewriter);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Reshaped for elementwise reduction: " << operand << "\n");
+
+    operand = performElementWiseReductionAcrossRepCounts(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Performed elementwise reduction across repCount: " << operand
+               << "\n");
+
+    operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Performed elementwise reduction within repCount: " << operand
+               << "\n");
+
+    operand = convertLayoutForFinalReduction(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Converted layout for final reduction: " << operand << "\n");
+
+    operand = reshapeForFinalReduction(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Reshaped for final reduction: " << operand << "\n");
+
+    operand = performFinalReduction(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Final reduction performed: " << operand << "\n");
+
+    operand = convertToOriginalType(op, rewriter, operand);
+
+    rewriter.replaceOp(op, operand);
+
+    return success();
+  }
+
+  Value reshapeForElementWiseReduction(ReduceOp op,
+                                       PatternRewriter &rewriter) const {
+    assert(op.getOperands().size() == 1 && "Expecting a single operand");
+
+    Value val = op.getOperands().front();
+    auto oldType = cast<RankedTensorType>(val.getType());
+    ArrayRef<int64_t> oldShape = oldType.getShape();
+    auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());
+
+    constexpr std::size_t rank = 5;
+    std::array<int64_t, rank> shape{
+        // Y axis
+        oldShape[0],
+        // X axis contiguous elements distributed within individual threads in a
+        // warp.
+        oldEncoding.getExecutionSize(),
+        // X axis contiguous elements distributed within a warp.
+        oldEncoding.getRepCluster()[1],
+        // X axis number of warps.
+        oldEncoding.getWarpsPerCTA()[1],
+        // X axis rest.
+        oldShape[1] /
+            (oldEncoding.getExecutionSize() * oldEncoding.getRepCluster()[1] *
+             oldEncoding.getWarpsPerCTA()[1])};
+    std::array<unsigned, rank> sizePerThread{oldEncoding.getRepeatCount() *
+                                                 oldEncoding.getRepCluster()[0],
+                                             1, 1, 1, 1};
+    std::array<unsigned, rank> threadsPerWarp{
+        1, oldEncoding.getExecutionSize(), 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0], 1,
+                                           1, oldEncoding.getWarpsPerCTA()[1],
+                                           1};
+    std::array<unsigned, rank> order{4, 0, 1, 2, 3};
+    std::array<unsigned, rank> ctasPerCGA{1, 1, 1, 1, 1};
+    std::array<unsigned, rank> ctaSplitNum{1, 1, 1, 1, 1};
+    std::array<unsigned, rank> ctaOrder{4, 3, 2, 1, 0};
+    auto ctaLayout =
+        rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+
+    auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
+        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+
+    RankedTensorType type =
+        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+
+    // Although this is a NOP, we have to pass allow_reorder=true as static
+    // analysis will fail to infer it.
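+    // E.g., a 32x32 tensor with warpsPerCTA = [2, 2] and repCluster = [1, 1]
+    // is reshaped to 32x16x1x2x1 here (see the accompanying lit tests).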
+    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+                                      /*allow_reorder=*/true);
+  }
+
+  Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
+                         int axis) const {
+    auto newOp = rewriter.create<ReduceOp>(op.getLoc(), val, /*axis=*/axis);
+    auto &newCombineOp = newOp.getCombineOp();
+    rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
+                               newCombineOp.end());
+    assert(newOp.getResult().size() == 1 && "Expecting single result");
+    return newOp.getResult().front();
+  }
+
+  Value performElementWiseReductionWithinRepCount(ReduceOp op,
+                                                  PatternRewriter &rewriter,
+                                                  Value val) const {
+    return performReduction(op, rewriter, val, /*axis=*/repCountReshapedAxis);
+  }
+
+  Value performElementWiseReductionAcrossRepCounts(ReduceOp op,
+                                                   PatternRewriter &rewriter,
+                                                   Value val) const {
+    return performReduction(op, rewriter, val,
+                            /*axis=*/withinWarpXAxisReshapedAxis);
+  }
+
+  Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
+                                       Value val) const {
+    assert(op.getOperands().size() == 1 && "Expecting a single operand");
+
+    auto oldType = cast<RankedTensorType>(val.getType());
+    auto dpasEncoding = cast<DpasEncodingAttr>(
+        cast<RankedTensorType>(op.getOperands().front().getType())
+            .getEncoding());
+
+    constexpr std::size_t rank = 3;
+    ArrayRef<int64_t> shape = oldType.getShape();
+    std::array<unsigned, rank> sizePerThread{
+        1, dpasEncoding.getExecutionSize(), 1};
+    std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
+                                              1, 1};
+    std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1,
+                                           dpasEncoding.getWarpsPerCTA()[1]};
+    std::array<unsigned, rank> order{2, 0, 1};
+    std::array<unsigned, rank> ctasPerCGA{1, 1, 1};
+    std::array<unsigned, rank> ctaSplitNum{1, 1, 1};
+    std::array<unsigned, rank> ctaOrder{2, 1, 0};
+    auto ctaLayout =
+        rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+
+    auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
+        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+
+    RankedTensorType type =
+        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+
+    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
+  }
+
+  Value reshapeForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
+                                 Value val) const {
+    auto oldType = cast<RankedTensorType>(val.getType());
+    ArrayRef<int64_t> oldShape = oldType.getShape();
+    auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
+
+    constexpr std::size_t rank = 2;
+    std::array<int64_t, rank> shape{oldShape[0], oldShape[1] * oldShape[2]};
+    std::array<unsigned, rank> sizePerThread{
+        1, oldEncoding.getSizePerThread()[1]};
+    std::array<unsigned, rank> threadsPerWarp{
+        oldEncoding.getThreadsPerWarp()[0], 1};
+    std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0],
+                                           oldEncoding.getWarpsPerCTA()[2]};
+    std::array<unsigned, rank> order{1, 0};
+    std::array<unsigned, rank> ctasPerCGA{1, 1};
+    std::array<unsigned, rank> ctaSplitNum{1, 1};
+    std::array<unsigned, rank> ctaOrder{1, 0};
+    auto ctaLayout =
+        rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+
+    auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
+        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+
+    RankedTensorType type =
+        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+
+    // Although this is a NOP, we have to pass allow_reorder=true as static
+    // analysis will fail to infer it.
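+    // E.g., continuing the example above, the 32x16x2 tensor produced by the
+    // partial reductions and layout conversion is reshaped to 32x32.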
+    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+                                      /*allow_reorder=*/true);
+  }
+
+  Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,
+                              Value val) const {
+    return performReduction(op, rewriter, val, /*axis=*/preferredReductionAxis);
+  }
+
+  Value convertToOriginalType(ReduceOp op, PatternRewriter &rewriter,
+                              Value val) const {
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), op.getResult().front().getType(), val);
+  }
+};
+
+struct TritonIntelGPUOptimizeReductionLocality final
+    : impl::TritonIntelGPUOptimizeReductionLocalityBase<
+          TritonIntelGPUOptimizeReductionLocality> {
+  using impl::TritonIntelGPUOptimizeReductionLocalityBase<
+      TritonIntelGPUOptimizeReductionLocality>::
+      TritonIntelGPUOptimizeReductionLocalityBase;
+
+  void runOnOperation() final {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<DPasOperandPattern>(ctx);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+} // namespace mlir::triton::gpu::intel

From 807676f53ea86b11c1c598d15efa847cbd91bc0a Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Tue, 15 Oct 2024 13:14:24 +0100
Subject: [PATCH 2/9] Update doc

---
 .../TritonIntelGPU/Transforms/Passes.td       | 44 +++++++++++--------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index 996cf0d2ec..de5827b23e 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -299,7 +299,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func @test.work(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
     ^bb0(%arg1: f32, %arg2: f32):
-      %1 = arith.maxnumf %arg1, %arg2 : f32
+      %1 = arith.addf %arg1, %arg2 : f32
       tt.reduce.return %1 : f32
     }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
     tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
@@ -308,26 +308,32 @@
   ```
   Is converted to:
   ```mlir
-#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 2, 1], order = [0, 2, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 2, 1]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-
+#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 
8], B = [8, 16], C = [8, 16]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { - tt.func @test.work(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { - %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x32x1xf32, #blocked> - %1 = "tt.reduce"(%0) <{axis = 2 : i32}> ({ + tt.func @test_two_warps_twice(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>> + %2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> + %3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1> + %4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2> + %5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): - %2 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %2 : f32 - }) : (tensor<32x32x1xf32, #blocked>) -> tensor<32x32xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> - %3 = triton_gpu.convert_layout %1 : tensor<32x32xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<32x32xf32, #blocked2> - %4 = "tt.reduce"(%3) <{axis = 0 : i32}> ({ - ^bb0(%arg3: f32, %arg4: f32): - %5 = arith.maxnumf %arg3, %arg4 : f32 - tt.reduce.return %5 : f32 - }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %6 = triton_gpu.convert_layout %4 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> } } ``` From dc51329735979c30018d0d9097e7647d2235754c Mon Sep 17 00:00:00 2001 From: victor-eds Date: Tue, 15 Oct 2024 17:22:12 +0100 Subject: [PATCH 3/9] Update code --- test/TritonIntelGPU/optimize-reduction.mlir | 178 ++++++++++-------- .../OptimizeReductionLocality.cpp | 2 +- 2 files changed, 104 insertions(+), 76 deletions(-) diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir index 59ca572aae..1fdc5f967d 100644 --- a/test/TritonIntelGPU/optimize-reduction.mlir +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -2,35 +2,35 @@ // Test reduction in 
a single warp (16x16->16). -#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { -// CHECK-DAG: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> -// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> -// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +// CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}> // CHECK: tt.func @test_single( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[DPAS]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x16xf32, #[[DPAS]]> -> tensor<16x16x1x1x1xf32, #[[BLOCKED0]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x16x1x1x1xf32, #[[$ATTR_0]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x16x1x1x1xf32, #[[BLOCKED0]]>) -> tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: }) : (tensor<16x16x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<16x16xf32, #[[BLOCKED1]]> +// CHECK: }) : (tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = 
#[[$ATTR_0]]}>>) -> tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -> tensor<16x16xf32, #[[$ATTR_1]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<16x16xf32, #[[BLOCKED1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: }) : (tensor<16x16xf32, #[[$ATTR_1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> // CHECK: } tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ @@ -46,12 +46,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // Test reduction in two warps across the non-reduction dimension (32x16->32). 
-#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [1, 1]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { -// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> -// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> // CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}> // CHECK: tt.func @test_single_twice( @@ -90,37 +90,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // Test reduction in two warps across the reduction dimension (16x32->16). -#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [1, 1]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}> -// CHECK-DAG: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> -// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> -// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> +// CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { -// CHECK: tt.func @test_two_warps( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[DPAS]]>) -> tensor<16xf32, 
#triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x32xf32, #[[DPAS]]> -> tensor<16x16x1x2x1xf32, #[[BLOCKED0]]> +// CHECK-LABEL: tt.func @test_two_warps_red( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x16x1x2x1xf32, #[[$ATTR_6]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x16x1x2x1xf32, #[[BLOCKED0]]>) -> tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: }) : (tensor<16x16x1x2x1xf32, #[[$ATTR_6]]>) -> tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<16x32xf32, #[[BLOCKED1]]> +// CHECK: }) : (tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>>) -> tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -> tensor<16x32xf32, #[[$ATTR_7]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<16x32xf32, #[[BLOCKED1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: }) : (tensor<16x32xf32, #[[$ATTR_7]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_7]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_7]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> // CHECK: } - tt.func @test_two_warps(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ 
^bb0(%arg1: f32, %arg2: f32): %1 = arith.addf %arg1, %arg2 : f32 @@ -134,37 +134,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // Test reduction in two warps across both dimensions (32x32->32). -#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> -// CHECK-DAG: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> -// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> -// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { -// CHECK: tt.func @test_two_warps_twice( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[DPAS]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<32x32xf32, #[[DPAS]]> -> tensor<32x16x1x2x1xf32, #[[BLOCKED0]]> +// CHECK-LABEL: tt.func @test_two_warps( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<32x16x1x2x1xf32, #[[$ATTR_9]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<32x16x1x2x1xf32, #[[BLOCKED0]]>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: }) : (tensor<32x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] 
{allow_reorder = true} : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<32x32xf32, #[[BLOCKED1]]> +// CHECK: }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<32x32xf32, #[[$ATTR_10]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<32x32xf32, #[[BLOCKED1]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS]]}>> +// CHECK: }) : (tensor<32x32xf32, #[[$ATTR_10]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> // CHECK: } - tt.func @test_two_warps_twice(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %1 = arith.addf %arg1, %arg2 : f32 @@ -172,81 +172,109 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> } + +// CHECK-LABEL: tt.func @test_two_warps_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x32xf32, #[[$ATTR_11]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<64x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<64x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> +// CHECK: %[[VAL_10:.*]] 
= tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<64x32xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<64x32xf32, #[[$ATTR_10]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: } + tt.func @test_two_warps_twice(%arg0: tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } } // ----- -// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension with different layout in reduction dimension. +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension. -#mma0 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1]}> -#mma1 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 2]}> +// CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}> +// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}> -// CHECK-DAG: #[[DPAS0:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> -// CHECK-DAG: #[[DPAS1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 2], A = [8, 8], B = [8, 32], C = [8, 32]}> -// CHECK-DAG: #[[BLOCKED0:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> -// CHECK-DAG: #[[BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, 
triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { - -// CHECK: tt.func @test_0( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[DPAS0]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS0]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x64xf32, #[[DPAS0]]> -> tensor<64x16x1x2x2xf32, #[[BLOCKED0]]> +// CHECK: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<64x16x1x2x2xf32, #[[BLOCKED0]]>) -> tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: }) : (tensor<64x16x2x2x1xf32, #[[$ATTR_12]]>) -> tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<64x32xf32, #[[BLOCKED1]]> +// CHECK: }) : (tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<64x32xf32, #[[$ATTR_13]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<64x32xf32, #[[BLOCKED1]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS0]]}>> -// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS0]]}>> +// CHECK: }) : (tensor<64x32xf32, #[[$ATTR_13]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> // CHECK: } - tt.func @test_0(%arg0: tensor<64x64xf32, 
#mma0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> { + tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %1 = arith.maxnumf %arg1, %arg2 : f32 tt.reduce.return %1 : f32 - }) : (tensor<64x64xf32, #mma0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> } -// CHECK: tt.func @test_1( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[DPAS1]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS1]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x64xf32, #[[DPAS1]]> -> tensor<64x16x2x2x1xf32, #[[BLOCKED0]]> +// CHECK: tt.func @test_repeat_layout( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32, #[[$ATTR_14]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<128x16x2x2x2xf32, #[[$ATTR_12]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<64x16x2x2x1xf32, #[[BLOCKED0]]>) -> tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>> +// CHECK: }) : (tensor<128x16x2x2x2xf32, #[[$ATTR_12]]>) -> tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[BLOCKED0]]}>}>> -> tensor<64x32xf32, #[[BLOCKED1]]> +// CHECK: }) : (tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<128x32xf32, #[[$ATTR_13]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 // CHECK: tt.reduce.return %[[VAL_14]] : f32 -// CHECK: }) : (tensor<64x32xf32, #[[BLOCKED1]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS1]]}>> -// CHECK: tt.return 
%[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[DPAS1]]}>>
+// CHECK: }) : (tensor<128x32xf32, #[[$ATTR_13]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>>
+// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
+// CHECK: tt.return %[[VAL_15]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
 // CHECK: }
-  tt.func @test_1(%arg0: tensor<64x64xf32, #mma1>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> {
+  tt.func @test_repeat_layout(%arg0: tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
     ^bb0(%arg1: f32, %arg2: f32):
       %1 = arith.maxnumf %arg1, %arg2 : f32
       tt.reduce.return %1 : f32
-    }) : (tensor<64x64xf32, #mma1>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
-    tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
+    }) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    tt.return %0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
   }
 }
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 73edb3a378..1516eedcf7 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -135,7 +135,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
     if ( // X axis condition
         encoding.getExecutionSize() != encoding.getSubGroupSize() ||
         // Y axis condition
-        type.getShape()[0] / encoding.getWarpsPerCTA()[0] !=
+        encoding.getRepeatCount() * encoding.getRepCluster()[0] !=
             encoding.getSubGroupSize())
       return failure();

From 9a1a7027e06e3de9232f188c88a7cc4fd437c656 Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Tue, 15 Oct 2024 17:53:24 +0100
Subject: [PATCH 4/9] Use auxiliary function to get CTA layout

---
 .../OptimizeReductionLocality.cpp             | 27 +++++++++----------
 1 file changed, 12 insertions(+), 15 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 1516eedcf7..3cc88a399e 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -22,6 +22,15 @@ namespace mlir::triton::gpu::intel {
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

 namespace {
+static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
+                                              std::size_t rank) {
+  SmallVector<unsigned> ctasPerCGA(rank, 1);
+  SmallVector<unsigned> ctaSplitNum(rank, 1);
+  SmallVector<unsigned> ctaOrder(rank);
+  std::iota(std::rbegin(ctaOrder), std::rend(ctaOrder), 0);
+  return rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+}
+
 // clang-format off
 /// Optimize reduction with DPAS-encoded input.
 ///
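For reference, a minimal sketch of what the helper evaluates to for `rank = 3`, assuming the `CTALayoutAttr` builder shown in the hunk above; it mirrors the hand-written boilerplate deleted at each call site in the next hunks:

```cpp
// Equivalent expansion of getIdentityCTALayoutAttr(rewriter, /*rank=*/3):
// a single CTA per CGA, no splitting of the tensor across CTAs, and the
// identity order rank-1..0 produced by std::iota over a reversed range.
SmallVector<unsigned> ctasPerCGA{1, 1, 1};
SmallVector<unsigned> ctaSplitNum{1, 1, 1};
SmallVector<unsigned> ctaOrder{2, 1, 0};
CTALayoutAttr ctaLayout =
    rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
```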
@@ -213,11 +222,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
                                    1, oldEncoding.getWarpsPerCTA()[1], 1};
     std::array<unsigned, rank> order{4, 0, 1, 2, 3};
-    std::array<unsigned, rank> ctasPerCGA{1, 1, 1, 1, 1};
-    std::array<unsigned, rank> ctaSplitNum{1, 1, 1, 1, 1};
-    std::array<unsigned, rank> ctaOrder{4, 3, 2, 1, 0};
-    auto ctaLayout =
-        rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+    CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);

     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -272,11 +277,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
     std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1,
                                            dpasEncoding.getWarpsPerCTA()[1]};
     std::array<unsigned, rank> order{2, 0, 1};
-    std::array<unsigned, rank> ctasPerCGA{1, 1, 1};
-    std::array<unsigned, rank> ctaSplitNum{1, 1, 1};
-    std::array<unsigned, rank> ctaOrder{2, 1, 0};
-    auto ctaLayout =
-        rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+    CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);

     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -302,11 +303,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
     std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0],
                                            oldEncoding.getWarpsPerCTA()[2]};
     std::array<unsigned, rank> order{1, 0};
-    std::array<unsigned, rank> ctasPerCGA{1, 1};
-    std::array<unsigned, rank> ctaSplitNum{1, 1};
-    std::array<unsigned, rank> ctaOrder{1, 0};
-    auto ctaLayout =
-        rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
+    CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);

     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);

From adcd22dd914f03d1eacba3876e2d11534466637c Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Wed, 16 Oct 2024 11:06:53 +0100
Subject: [PATCH 5/9] Set efficient_layout in reshape operations

---
 test/TritonIntelGPU/optimize-reduction.mlir   | 46 ++++++++++++-------
 .../OptimizeReductionLocality.cpp             | 14 ++++--
 2 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir
index 1fdc5f967d..c1b8aab5cd 100644
--- a/test/TritonIntelGPU/optimize-reduction.mlir
+++ b/test/TritonIntelGPU/optimize-reduction.mlir
@@ -9,10 +9,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

 // CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 // CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}>
 // CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
+// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [2, 0, 1]}>

 // CHECK: tt.func @test_single(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x16x1x1x1xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x16x1x1x1xf32, #[[$ATTR_0]]>
 // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32,
%[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 @@ -23,7 +24,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>>) -> tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -> tensor<16x16xf32, #[[$ATTR_1]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -> tensor<16x16x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x16x1xf32, #[[$ATTR_3]]> -> tensor<16x16xf32, #[[$ATTR_1]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 @@ -53,10 +55,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> // CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> // CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [2, 0, 1]}> // CHECK: tt.func @test_single_twice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<32x16x1x1x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<32x16x1x1x1xf32, #[[$ATTR_3]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 @@ -67,7 +70,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>>) -> tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -> tensor<32x16xf32, #[[$ATTR_4]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = 
#triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -> tensor<32x16x1xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<32x16x1xf32, #[[$BLOCKED]]> -> tensor<32x16xf32, #[[$ATTR_4]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 @@ -95,12 +99,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> // CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> // CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [2, 0, 1]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { // CHECK-LABEL: tt.func @test_two_warps_red( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x16x1x2x1xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x16x1x2x1xf32, #[[$ATTR_6]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 @@ -111,7 +116,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>>) -> tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -> tensor<16x32xf32, #[[$ATTR_7]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -> tensor<16x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x16x2xf32, #[[$BLOCKED]]> -> tensor<16x32xf32, #[[$ATTR_7]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 @@ -136,15 +142,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, 
systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> -// CHECK: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> -// CHECK: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> -// CHECK: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { // CHECK-LABEL: tt.func @test_two_warps( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<32x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<32x16x1x2x1xf32, #[[$ATTR_9]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 @@ -155,7 +162,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<32x32xf32, #[[$ATTR_10]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<32x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<32x16x2xf32, #[[$BLOCKED]]> -> tensor<32x32xf32, #[[$ATTR_10]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 @@ -175,7 +183,7 @@ module 
attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tt.func @test_two_warps_twice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<64x32xf32, #[[$ATTR_11]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<64x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<64x16x1x2x1xf32, #[[$ATTR_9]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 @@ -186,7 +194,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<64x32xf32, #[[$ATTR_10]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<64x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<64x16x2xf32, #[[$BLOCKED]]> -> tensor<64x32xf32, #[[$ATTR_10]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 @@ -212,13 +221,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}> // CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> // CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [2, 0, 1]}> #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { // CHECK: tt.func @test( // CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_1:.*]] = tt.reshape 
%[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 @@ -229,7 +239,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<64x32xf32, #[[$ATTR_13]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<64x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<64x16x2xf32, #[[$BLOCKED]]> -> tensor<64x32xf32, #[[$ATTR_13]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 @@ -249,7 +260,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: tt.func @test_repeat_layout( // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32, #[[$ATTR_14]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true} : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<128x16x2x2x2xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<128x16x2x2x2xf32, #[[$ATTR_12]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 @@ -260,7 +271,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -// CHECK: %[[VAL_10:.*]] = tt.reshape %[[VAL_6]] {allow_reorder = true} : tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<128x32xf32, #[[$ATTR_13]]> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<128x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<128x16x2xf32, #[[$BLOCKED]]> -> tensor<128x32xf32, #[[$ATTR_13]]> // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): // CHECK: %[[VAL_14:.*]] = 
arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 3cc88a399e..f6919d8237 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -31,6 +31,14 @@ static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
   return rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
 }

+static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
+                                       Type type, Value val) {
+  auto reshapeOp =
+      rewriter.create<ReshapeOp>(loc, type, val, /*allow_reorder=*/true);
+  reshapeOp.setEfficientLayout(true);
+  return reshapeOp;
+}
+
 // clang-format off
 /// Optimize reduction with DPAS-encoded input.
 ///
@@ -232,8 +240,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {

     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
-                                      /*allow_reorder=*/true);
+    return createReshapeForReduction(rewriter, op.getLoc(), type, val);
   }

   Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
@@ -313,8 +320,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {

     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
-                                      /*allow_reorder=*/true);
+    return createReshapeForReduction(rewriter, op.getLoc(), type, val);
   }

   Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,

From a6a3f1354f248033a7564a351ffe6cef2a7b2e7b Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Fri, 18 Oct 2024 16:45:02 +0100
Subject: [PATCH 6/9] Address comments

---
 test/TritonIntelGPU/optimize-reduction.mlir   | 12 ++--
 .../TritonIntelGPU/Transforms/Passes.td       | 61 +++++++++----------
 .../OptimizeReductionLocality.cpp             | 15 +++--
 3 files changed, 43 insertions(+), 45 deletions(-)

diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir
index c1b8aab5cd..2ecf30e0b3 100644
--- a/test/TritonIntelGPU/optimize-reduction.mlir
+++ b/test/TritonIntelGPU/optimize-reduction.mlir
@@ -1,10 +1,10 @@
-// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality -canonicalize | FileCheck %s
+// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s

 // Test reduction in a single warp (16x16->16).
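// Note on the reshape shapes checked in these tests (a reading of the checks
// above and below, not new functionality): the first reshape splits the N
// axis of an MxN input into
//   M x executionSize x repCluster[1] x warpsPerCTA[1] x rest,
// with rest = N / (executionSize * repCluster[1] * warpsPerCTA[1]). For
// example, 64x64 with warpsPerCTA = [4, 2] and repCluster = [2, 2] becomes
// 64x16x2x2x1, and 128x128 with the same layout becomes 128x16x2x2x2.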
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> // CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> @@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> // CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> @@ -101,7 +101,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> // CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [2, 0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps_red( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> { @@ -147,7 +147,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> // CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { @@ -225,7 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK: tt.func @test( // CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { // CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]> diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index de5827b23e..52d1bb9091 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -294,16 +294,13 @@ def TritonIntelGPUOptimizeReductionLocality `triton_gpu.convert_layout` operations, e.g.: ```mlir #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { - tt.func @test.work(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { - %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - } +tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> } ``` Is converted to: @@ -312,29 +309,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], 
threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { - tt.func @test_two_warps_twice(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { - %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked> - %1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %7 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %7 : f32 - }) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>> - %2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %7 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %7 : f32 - }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> - %3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1> - %4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2> - %5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %7 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %7 : f32 - }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - } +tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>> + %2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> + %3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1> + %4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2> + %5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = 
#blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
 }
 ```
 The `tt.reshape` operation is a NOP so that the following `tt.reduce`
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index f6919d8237..bc01f9899c 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -23,7 +23,7 @@ namespace mlir::triton::gpu::intel {

 namespace {
 static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
-                                              std::size_t rank) {
+                                              size_t rank) {
   SmallVector<unsigned> ctasPerCGA(rank, 1);
   SmallVector<unsigned> ctaSplitNum(rank, 1);
   SmallVector<unsigned> ctaOrder(rank);
@@ -121,7 +121,7 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// And reducing on dimension 1 and converting the layout to the original one
 /// leads to the same output as the original operation.
 // clang-format on
-struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
+struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;

   static constexpr int preferredNonReductionAxis = 0;
@@ -197,6 +197,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
     return success();
   }

+private:
   Value reshapeForElementWiseReduction(ReduceOp op,
                                        PatternRewriter &rewriter) const {
     assert(op.getOperands().size() == 1 && "Expecting a single operand");
@@ -206,7 +207,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
     ArrayRef oldShape = oldType.getShape();
     auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());

-    constexpr std::size_t rank = 5;
+    constexpr size_t rank = 5;
     std::array<int64_t, rank> shape{
         // Y axis
         oldShape[0],
@@ -245,6 +246,8 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {

   Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
                          int axis) const {
+    assert(axis >= 0 && "Expecting positive axis");
+
     auto newOp = rewriter.create<ReduceOp>(op.getLoc(), val, /*axis=*/axis);
     auto &newCombineOp = newOp.getCombineOp();
     rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
@@ -275,7 +278,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
         cast<RankedTensorType>(op.getOperands().front().getType())
             .getEncoding());

-    constexpr std::size_t rank = 3;
+    constexpr size_t rank = 3;
     ArrayRef shape = oldType.getShape();
     std::array<unsigned, rank> sizePerThread{1, dpasEncoding.getExecutionSize(), 1};
@@ -301,7 +304,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
     ArrayRef oldShape = oldType.getShape();
     auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());

-    constexpr std::size_t rank = 2;
+    constexpr size_t rank = 2;
     std::array<int64_t, rank> shape{oldShape[0], oldShape[1] * oldShape[2]};
     std::array<unsigned, rank> sizePerThread{1, oldEncoding.getSizePerThread()[1]};
@@ -346,7 +349,7 @@ struct TritonIntelGPUOptimizeReductionLocality final
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
     RewritePatternSet patterns(ctx);
-    patterns.add<DPasOperandPattern>(ctx);
+    patterns.add<DpasOperandPattern>(ctx);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();

From bc5471cd364d16ab65e93e4de197db59fcd18828 Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Fri, 18 Oct 2024 16:49:18 +0100
Subject: [PATCH 7/9] Fix layout order

---
 test/TritonIntelGPU/optimize-reduction.mlir   | 20 +++++++++----------
 .../OptimizeReductionLocality.cpp             |  6 +++---
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git
a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir index 2ecf30e0b3..79ff12e072 100644 --- a/test/TritonIntelGPU/optimize-reduction.mlir +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -7,9 +7,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> // CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}> -// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [2, 0, 1]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [1, 2, 0]}> // CHECK: tt.func @test_single( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> { @@ -53,9 +53,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> // CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [2, 0, 1]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [1, 2, 0]}> // CHECK: tt.func @test_single_twice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> { @@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}> // CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 
1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> // CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [2, 0, 1]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [1, 2, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { @@ -142,10 +142,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> -// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}> +// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> // CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> // CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> -// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { @@ -219,9 +219,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension. 
 // CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}>
-// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}>
+// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}>
 // CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}>
-// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [2, 0, 1]}>
+// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}>

 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}>

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index bc01f9899c..3fac79046c 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -230,7 +230,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0], 1, 1,
                                            oldEncoding.getWarpsPerCTA()[1], 1};
-    std::array<unsigned, rank> order{4, 0, 1, 2, 3};
+    std::array<unsigned, rank> order{1, 2, 3, 4, 0};
     CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);

     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -247,7 +247,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
                          int axis) const {
     assert(axis >= 0 && "Expecting non-negative axis");
-
+
     auto newOp = rewriter.create<ReduceOp>(op.getLoc(), val, /*axis=*/axis);
     auto &newCombineOp = newOp.getCombineOp();
     rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
@@ -286,7 +286,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
                                            1, 1};
     std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1,
                                            dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rank> order{2, 0, 1};
+    std::array<unsigned, rank> order{1, 2, 0};
     CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);

     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(

From 5ed11cb9e357e64993fe23e7ce1d28b08aaec48a Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Fri, 18 Oct 2024 16:52:20 +0100
Subject: [PATCH 8/9] Update doc

---
 .../Dialect/TritonIntelGPU/Transforms/Passes.td |  4 ++--
 .../OptimizeReductionLocality.cpp               | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index 52d1bb9091..3b738b880e 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -305,8 +305,8 @@ tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slic
 ```
 Is converted to:
 ```mlir
-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [4, 0, 1, 2, 3]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [2, 0, 1]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
 tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 3fac79046c..131e85c2d3 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -69,7 +69,7 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
 /// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
 /// ```
-/// Blocked (#triton_gpu.blocked<{sizePerThread = [repCluster[0]*repeatCount, 1, 1, 1, 1], threadsPerWarp = [1, executionSize, 1, 1, 1], warpsPerCTA = [warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [4, 0, 1, 2, 3]}>):
+/// Blocked (#triton_gpu.blocked<{sizePerThread = [repCluster[0]*repeatCount, 1, 1, 1, 1], threadsPerWarp = [1, executionSize, 1, 1, 1], warpsPerCTA = [warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [1, 2, 3, 4, 0]}>):
 /// ```
 ///                                                  warpsPerCTA[3]
 ///                    <------------------------------------------------------------------------------->
@@ -113,10 +113,10 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 ///                     <------------------------------------>
 ///                                sizePerThread[1]
 ///                     <------------------>
-///                   ^ t0 t0 t0 t0 ... t0  tn1 tn1 tn1 ... tn1 ^
-///                   | t1 t1 t1 t1 ... t1  tn2 tn2 tn2 ... tn2 |
-/// sizePerThread[0]  | t2 t2 t2 t2 ... t2  tn3 tn3 tn3 ... tn3 | warpsPerCTA[0]
-///                   | t3 t3 t3 t3 ... t3  tn4 tn4 tn4 ... tn4 |
+///                   ^ t0 t0 t0 t0 ... t0  tn1 tn1 tn1 ... tn1 ^
+///                   | t1 t1 t1 t1 ... t1  tn2 tn2 tn2 ... tn2 |
+/// threadsPerWarp[0] | t2 t2 t2 t2 ... t2  tn3 tn3 tn3 ... tn3 | warpsPerCTA[0]
+///                   | t3 t3 t3 t3 ... t3  tn4 tn4 tn4 ... tn4 |
 /// ```
 /// And reducing on dimension 1 and converting the layout to the original one
 /// leads to the same output as the original operation.
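A note on the `order` changes made in the last two patches (this gloss and the sketch below are commentary, not part of any patch): in `triton_gpu` blocked encodings, `order` lists tensor dimensions from fastest- to slowest-varying when elements are linearized. The old `order = [4, 0, 1, 2, 3]` made the non-reduction dimension 0 the second-fastest axis; the fixed `order = [1, 2, 3, 4, 0]` makes dimension 0 the slowest, so the reshaped reduction dimensions vary first. A minimal C++ sketch of that linearization rule, using made-up names and an arbitrary index:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Linearize a multi-dimensional index, assuming the triton_gpu convention
// that `order` lists dimensions from fastest- to slowest-varying.
// Illustration only; this is not code from the pass.
template <std::size_t Rank>
std::uint64_t linearize(const std::array<std::int64_t, Rank> &idx,
                        const std::array<std::int64_t, Rank> &shape,
                        const std::array<unsigned, Rank> &order) {
  std::uint64_t linear = 0;
  std::uint64_t stride = 1;
  for (unsigned dim : order) { // order[0] varies fastest
    linear += static_cast<std::uint64_t>(idx[dim]) * stride;
    stride *= static_cast<std::uint64_t>(shape[dim]);
  }
  return linear;
}

int main() {
  // Shape of the 5-D reshape in the 32x32 example; the index is arbitrary.
  std::array<std::int64_t, 5> shape{32, 16, 1, 2, 1};
  std::array<std::int64_t, 5> idx{3, 5, 0, 1, 0};
  std::printf("old order [4,0,1,2,3] -> %llu\n",
              (unsigned long long)linearize(idx, shape, {4, 0, 1, 2, 3}));
  std::printf("new order [1,2,3,4,0] -> %llu\n",
              (unsigned long long)linearize(idx, shape, {1, 2, 3, 4, 0}));
  return 0;
}
```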
From 07fed7015bf1a16f8e3c710f69df4c8432c59fbe Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Tue, 22 Oct 2024 13:34:12 +0100
Subject: [PATCH 9/9] Improve documentation

---
 .../OptimizeReductionLocality.cpp | 27 ++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 131e85c2d3..ccc564ea66 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -43,12 +43,26 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// Optimize reduction with DPAS-encoded input.
 ///
 /// This optimization reshapes and converts input tensor layouts to split the
-/// reduction in three equivalent ones:
+/// reduction in three equivalent ones.
 ///
 /// This only works if the number of items for a given thread across dimension
 /// 0 and the execution size are equal to the sub-group size.
 ///
-/// First, we go from a DPAS layout to an equivalent blocked layout as follows:
+/// We first want to reshape the input tensor to obtain a tensor with an
+/// equivalent encoding in terms of how elements are distributed across the
+/// device, but with more dimensions across the reduction axis. This way, we
+/// will be able to split the reduction in three steps:
+///
+/// 1. Reduce within the work-item
+/// 2. Convert layout for better locality
+/// 3. Reduce within the sub-group and work-group
+///
+/// Step 1 may involve more than one dimension depending on the input encoding
+/// (2 in this case). After step 1, each thread will hold a single element
+/// across the reduction axis dimension, so step 2 will be cheaper.
+///
+/// For step 1, we first go from a DPAS layout to an equivalent blocked layout
+/// as follows:
 ///
 /// DPAS:
 /// ```
@@ -105,8 +119,9 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
 /// v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
 /// ```
-/// After reshaping and layout conversion, we can get to the actual layout
-/// optimization we wanted to achieve:
+///
+/// Now on with step 2: After reshaping and layout conversion, we can get to
+/// the actual layout optimization we wanted to achieve:
 /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, repCluster[0]*repeatCount], threadsPerWarp = [executionSize, 1], warpsPerCTA = [warpsPerCTA[0], warpsPerCTA[1]], order = [1, 0]}>):
 /// ```
 ///                                 warpsPerCTA[1]
 ///                     <------------------------------------>
 ///                                sizePerThread[1]
 ///                     <------------------>
 ///                   ^ t0 t0 t0 t0 ... t0  tn1 tn1 tn1 ... tn1 ^
 ///                   | t1 t1 t1 t1 ... t1  tn2 tn2 tn2 ... tn2 |
 /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2  tn3 tn3 tn3 ... tn3 | warpsPerCTA[0]
 ///                   | t3 t3 t3 t3 ... t3  tn4 tn4 tn4 ... tn4 |
 /// ```
-/// And reducing on dimension 1 and converting the layout to the original one
-/// leads to the same output as the original operation.
+/// And on with step 3, reducing on dimension 1 and converting the layout to
+/// the original one leads to the same output as the original operation.
 // clang-format on
 struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;
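To make the three documented steps concrete outside MLIR (again commentary, not patch content): the following scalar model mimics the single-warp 16x16 case from the tests, with per-lane arrays standing in for DPAS-encoded registers and an explicit transpose standing in for `convert_layout`. All names and sizes are illustrative assumptions, not the pass's API.

```cpp
#include <array>
#include <cstdio>

// Scalar model of the three-step reduction. 16 lanes hold a 16x16 tile
// column-wise, as in the DPAS layout sketched above, and we reduce along
// axis 1 (rows) without any cross-lane reduction operations.
constexpr int N = 16; // sub-group size and tile dimension

int main() {
  // laneData[l][i]: the i-th element lane l holds. DPAS-like start: lane l
  // owns column l, i.e. element (i, l). Reducing along axis 1 from here
  // would need one cross-lane operation per output row.
  std::array<std::array<float, N>, N> laneData{};
  for (int l = 0; l < N; ++l)
    for (int i = 0; i < N; ++i)
      laneData[l][i] = static_cast<float>(i); // tile(i, l) = i

  // Step 1 (reduce within the work-item): with repCluster or warp splits a
  // lane would first fold the extra reduction-axis dimensions it owns
  // locally. In this minimal single-warp case there is nothing to fold.

  // Step 2 (layout conversion): redistribute so lane l owns row l. On the
  // GPU this is the single convert_layout the pass inserts.
  std::array<std::array<float, N>, N> converted{};
  for (int l = 0; l < N; ++l)
    for (int i = 0; i < N; ++i)
      converted[l][i] = laneData[i][l]; // lane l gathers tile(l, 0..N-1)

  // Step 3 (final reduce): each lane sums the row it now owns; no sub-group
  // reduction operations are needed anymore.
  for (int l = 0; l < N; ++l) {
    float acc = 0.0f;
    for (float v : converted[l]) acc += v;
    std::printf("row %2d sum = %g\n", l, acc); // row l sums to 16*l here
  }
  return 0;
}
```

In the multi-warp configurations the pass also handles, step 3 still combines partials across warps; the model above elides that part.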