diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir new file mode 100644 index 0000000000..79ff12e072 --- /dev/null +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -0,0 +1,292 @@ +// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s + +// Test reduction in a single warp (16x16->16). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + +// CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> +// CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [1, 2, 0]}> + +// CHECK: tt.func @test_single( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x16x1x1x1xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x16x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>>) -> tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_0]]}>}>> -> tensor<16x16x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x16x1xf32, #[[$ATTR_3]]> -> tensor<16x16xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<16x16xf32, #[[$ATTR_1]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// 
CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// CHECK: } + tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the non-reduction dimension (32x16->32). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + +// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> +// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 1], order = [1, 2, 0]}> + +// CHECK: tt.func @test_single_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<32x16x1x1x1xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<32x16x1x1x1xf32, #[[$ATTR_3]]>) -> tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<32x16x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>>) -> tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<32x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_3]]}>}>> -> tensor<32x16x1xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<32x16x1xf32, #[[$BLOCKED]]> -> tensor<32x16xf32, #[[$ATTR_4]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<32x16xf32, #[[$ATTR_4]]>) -> 
tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> +// CHECK: } + tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x16xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the reduction dimension (16x32->16). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}> + +// CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> +// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [1, 2, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_two_warps_red( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x16x1x2x1xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x16x1x2x1xf32, #[[$ATTR_6]]>) -> tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>>) -> tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_6]]}>}>> -> tensor<16x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x16x2xf32, #[[$BLOCKED]]> -> tensor<16x32xf32, #[[$ATTR_7]]> +// CHECK: 
%[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<16x32xf32, #[[$ATTR_7]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_7]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_7]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> +// CHECK: } + tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across both dimensions (32x32->32). + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> + +// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> +// CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_two_warps( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<32x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<32x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, 
parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<32x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<32x16x2xf32, #[[$BLOCKED]]> -> tensor<32x32xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<32x32xf32, #[[$ATTR_10]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: } + tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } + +// CHECK-LABEL: tt.func @test_two_warps_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x32xf32, #[[$ATTR_11]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<64x16x1x2x1xf32, #[[$ATTR_9]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<64x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<64x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_9]]}>}>> -> tensor<64x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<64x16x2xf32, #[[$BLOCKED]]> -> tensor<64x32xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<64x32xf32, #[[$ATTR_10]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_10]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// 
CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> +// CHECK: } + tt.func @test_two_warps_twice(%arg0: tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension. + +// CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}> +// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> +// CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}> + +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { +// CHECK: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<64x16x2x2x1xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<64x16x2x2x1xf32, #[[$ATTR_12]]>) -> tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<64x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<64x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<64x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<64x16x2xf32, #[[$BLOCKED]]> -> tensor<64x32xf32, #[[$ATTR_13]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : 
(tensor<64x32xf32, #[[$ATTR_13]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: } + tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } + +// CHECK: tt.func @test_repeat_layout( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32, #[[$ATTR_14]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<128x16x2x2x2xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<128x16x2x2x2xf32, #[[$ATTR_12]]>) -> tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<128x16x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>>) -> tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<128x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #[[$ATTR_12]]}>}>> -> tensor<128x16x2xf32, #[[$BLOCKED]]> +// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<128x16x2xf32, #[[$BLOCKED]]> -> tensor<128x32xf32, #[[$ATTR_13]]> +// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: tt.reduce.return %[[VAL_14]] : f32 +// CHECK: }) : (tensor<128x32xf32, #[[$ATTR_13]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_13]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_15]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: } + tt.func @test_repeat_layout(%arg0: tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %0 : 
tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } +} diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 42e386fe29..3b738b880e 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -284,4 +284,67 @@ def TritonIntelGPUMaterializeBlockPointer : Pass<"tritonintelgpu-materialize-blo "mlir::arith::ArithDialect"]; } +def TritonIntelGPUOptimizeReductionLocality + : Pass<"tritonintelgpu-optimize-reduction-locality", "mlir::ModuleOp"> { + let summary = "Minimize number of reductions within sub-groups"; + + let description = [{ + This pass performs layout conversions so `tt.reduce` operations resulting in + sub-group reductions are converted to `tt.reshape`, `tt.reduce`, and + `triton_gpu.convert_layout` operations, e.g.: + ```mlir +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}> +tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> +} + ``` + Is converted to: + ```mlir +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [2, 1, 1, 2, 1], order = [1, 2, 3, 4, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}> +tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> { + %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<32x32xf32, #mma> -> tensor<32x16x1x2x1xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 4 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x16x1x2x1xf32, #blocked>) -> tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>> + %2 = "tt.reduce"(%1) <{axis = 2 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x16x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #blocked}>>) -> tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> + %3 = triton_gpu.convert_layout %2 : tensor<32x16x2xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<32x16x2xf32, #blocked1> + %4 = tt.reshape %3 {allow_reorder = true} : tensor<32x16x2xf32, #blocked1> -> tensor<32x32xf32, #blocked2> + %5 = "tt.reduce"(%4) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %7 : f32 + }) : (tensor<32x32xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 
1, parent = #blocked2}>> + %6 = triton_gpu.convert_layout %5 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + tt.return %6 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> +} + ``` + The `tt.reshape` operation is a NOP so that the following `tt.reduce` + operation performs a reduction within the work-item. Then, + `triton_gpu.convert_layout` performs the actual data movement so that the + following `tt.reduce` again reduces within the work-item instead of across + the sub-group. Finally, we convert back to the expected layout. + + Note this pass only supports `triton_intel_gpu.dpas` input layouts at the + moment, but it should be easy to extend. + }]; + + let dependentDialects = ["mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; +} + #endif // TRITON_INTEL_GPU_PASSES diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index 8c2e290ada..d215107a53 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonIntelGPUTransforms DistributeToWarps.cpp MatchTargetSize.cpp MaterializeBlockPointer.cpp + OptimizeReductionLocality.cpp Pipeliner/MatmulLoopPipeline.cpp Pipeliner/SoftwarePipeliner.cpp PrefetchBlock.cpp diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp new file mode 100644 index 0000000000..ccc564ea66 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -0,0 +1,374 @@ +//===- OptimizeReductionLocality.cpp ------------------------------------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// This file implements the `tritonintelgpu-optimize-reduction-locality` pass. +//===----------------------------------------------------------------------===// + +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#define DEBUG_TYPE "tritonintelgpu-optimize-reduction-locality" + +namespace mlir::triton::gpu::intel { +#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEREDUCTIONLOCALITY +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" + +namespace { +static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter, + size_t rank) { + SmallVector<unsigned> ctasPerCGA(rank, 1); + SmallVector<unsigned> ctaSplitNum(rank, 1); + SmallVector<unsigned> ctaOrder(rank); + std::iota(std::rbegin(ctaOrder), std::rend(ctaOrder), 0); + return rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder); +} + +static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc, + Type type, Value val) { + auto reshapeOp = + rewriter.create<ReshapeOp>(loc, type, val, /*allow_reorder=*/true); + reshapeOp.setEfficientLayout(true); + return reshapeOp; +} + +// clang-format off + /// Optimize reduction with DPAS-encoded input.
+ /// + /// This optimization reshapes and converts input tensor layouts to split the + /// reduction in three equivalent ones. + /// + /// This only works if the number of items for a given thread across dimension + /// 0 and the execution size are equal to the sub-group size. + /// + /// We first want to reshape the input tensor to obtain a tensor with an + /// equivalent encoding in terms of how elements are distributed across the + /// device, but with more dimensions across the reduction axis. This way, we + /// will be able to split the reduction in three steps: + /// + /// 1. Reduce within the work-item + /// 2. Convert layout for better locality + /// 3. Reduce within the sub-group and work-group + /// + /// Step 1 may involve more than one dimension depending on the input encoding + /// (2 in this case). After step 1, each thread will hold a single element + /// across the reduction axis dimension, so step 2 will be cheaper. + /// + /// For step 1, we first go from a DPAS layout to an equivalent blocked layout + /// as follows: + /// + /// DPAS: + /// ``` + /// warpsPerCTA[1] + /// <--------------------------------------------------------------------------------> + /// repCluster[1] + /// <-----------------------------------> + /// execution size + /// <----------------> + /// ^ ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^ + /// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | repeatCount | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[0] + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// repCluster[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// ``` + /// Blocked (#triton_gpu.blocked<{sizePerThread = [repCluster[0]*repeatCount, 1, 1, 1, 1], threadsPerWarp = [1, executionSize, 1, 1, 1], warpsPerCTA = [warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [1, 2, 3, 4, 0]}>): + /// ``` + /// warpsPerCTA[3] + /// <-------------------------------------------------------------------------------> + /// size[2] + /// <----------------------------------> + /// threadsPerWarp[1] + /// <----------------> + /// ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^ + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[0] + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... 
tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | + /// ``` + /// So we can reduce on dimensions 4 and 2 to get to: + /// ``` + /// warpsPerCTA[2] + /// <------------------------------------> + /// threadsPerWarp[1] + /// <------------------> + /// ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^ + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[0] + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | + /// ``` + /// + /// Now on with step 2: After reshaping and layout conversion, we can get to + /// the actual layout optimization we wanted to achieve: + /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, repCluster[0]*repeatCount], threadsPerWarp = [executionSize, 1], warpsPerCTA = [warpsPerCTA[0], warpsPerCTA[1]], order = [1, 0]}>): + /// ``` + /// warpsPerCTA[1] + /// <------------------------------------> + /// sizePerThread[1] + /// <------------------> + /// ^ t0 t0 t0 t0 ... t0 tn1 tn1 tn1 ... tn1 ^ + /// | t1 t1 t1 t1 ... t1 tn2 tn2 tn2 ... tn2 | + /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[0] + /// | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 | + /// ``` + /// And on with step 3, reducing on dimension 1 and converting the layout to + /// the original one leads to the same output as the original operation. +// clang-format on +struct DpasOperandPattern final : OpRewritePattern<ReduceOp> { + using OpRewritePattern<ReduceOp>::OpRewritePattern; + + static constexpr int preferredNonReductionAxis = 0; + static constexpr int preferredReductionAxis = 1; + static constexpr int repCountReshapedAxis = 2; + static constexpr int withinWarpXAxisReshapedAxis = 4; + + LogicalResult matchAndRewrite(ReduceOp op, + PatternRewriter &rewriter) const final { + ValueRange operands = op.getOperands(); + // Allowing a single operand for now. + if (operands.size() != 1) + return failure(); + // Check this has a `triton_intel_gpu.dpas` encoding. + Value operand = operands.front(); + auto type = cast<RankedTensorType>(operand.getType()); + auto encoding = + llvm::dyn_cast_or_null<DpasEncodingAttr>(type.getEncoding()); + if (!encoding) + return failure(); + + // Axis 1 will lead to within-warp reduction. + assert(type.getRank() == 2 && "Expecting 2D tensor"); + if (op.getAxis() != preferredReductionAxis) + return failure(); + + // We want to transpose matrices of (threads_per_warp)^2 shape for now.
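+ // For instance, the DPAS encodings exercised in the new
+ // optimize-reduction.mlir test use executionSize = 16, threadsPerWarp = 16,
+ // repeatCount = 8 and repCluster[0] = 2, so both conditions below hold
+ // (16 == 16 and 8 * 2 == 16).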
+ if ( // X axis condition + encoding.getExecutionSize() != encoding.getSubGroupSize() || + // Y axis condition + encoding.getRepeatCount() * encoding.getRepCluster()[0] != + encoding.getSubGroupSize()) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n"); + + operand = reshapeForElementWiseReduction(op, rewriter); + + LLVM_DEBUG(llvm::dbgs() + << "Reshaped for elementwise reduction: " << operand << "\n"); + + operand = performElementWiseReductionAcrossRepCounts(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() + << "Performed elementwise reduction across repCount: " << operand + << "\n"); + + operand = performElementWiseReductionWithinRepCount(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() + << "Performed elementwise reduction within repCount: " << operand + << "\n"); + + operand = convertLayoutForFinalReduction(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() + << "Converted layout for final reduction: " << operand << "\n"); + + operand = reshapeForFinalReduction(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() + << "Reshaped for final reduction: " << operand << "\n"); + + operand = performFinalReduction(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() + << "Final reduction performed: " << operand << "\n"); + + operand = convertToOriginalType(op, rewriter, operand); + + rewriter.replaceOp(op, operand); + + return success(); + } + +private: + Value reshapeForElementWiseReduction(ReduceOp op, + PatternRewriter &rewriter) const { + assert(op.getOperands().size() == 1 && "Expecting a single operand"); + + Value val = op.getOperands().front(); + auto oldType = cast<RankedTensorType>(val.getType()); + ArrayRef<int64_t> oldShape = oldType.getShape(); + auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding()); + + constexpr size_t rank = 5; + std::array<int64_t, rank> shape{ + // Y axis + oldShape[0], + // X axis contiguous elements distributed within individual threads in a + // warp. + oldEncoding.getExecutionSize(), + // X axis contiguous elements distributed within a warp. + oldEncoding.getRepCluster()[1], + // X axis number of warps. + oldEncoding.getWarpsPerCTA()[1], + // X axis rest. + oldShape[1] / + (oldEncoding.getExecutionSize() * oldEncoding.getRepCluster()[1] * + oldEncoding.getWarpsPerCTA()[1])}; + std::array<unsigned, rank> sizePerThread{oldEncoding.getRepeatCount() * + oldEncoding.getRepCluster()[0], + 1, 1, 1, 1}; + std::array<unsigned, rank> threadsPerWarp{1, oldEncoding.getExecutionSize(), + 1, 1, 1}; + std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0], 1, + 1, oldEncoding.getWarpsPerCTA()[1], + 1}; + std::array<unsigned, rank> order{1, 2, 3, 4, 0}; + CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank); + + auto encoding = rewriter.getAttr<BlockedEncodingAttr>( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType type = + RankedTensorType::get(shape, oldType.getElementType(), encoding); + + // Although this is a NOP, we have to pass allow_reorder=true as static + // analysis will fail to infer it.
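+ // For example, the 64x64 case in the new optimize-reduction.mlir test
+ // (executionSize = 16, repCluster = [2, 2], warpsPerCTA = [4, 2]) is reshaped
+ // from tensor<64x64xf32> to tensor<64x16x2x2x1xf32> here.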
+ return createReshapeForReduction(rewriter, op.getLoc(), type, val); + } + + Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val, + int axis) const { + assert(axis >= 0 && "Expecting non-negative axis"); + + auto newOp = rewriter.create<ReduceOp>(op.getLoc(), val, /*axis=*/axis); + auto &newCombineOp = newOp.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + assert(newOp.getResult().size() == 1 && "Expecting single result"); + return newOp.getResult().front(); + } + + Value performElementWiseReductionWithinRepCount(ReduceOp op, + PatternRewriter &rewriter, + Value val) const { + return performReduction(op, rewriter, val, /*axis=*/repCountReshapedAxis); + } + + Value performElementWiseReductionAcrossRepCounts(ReduceOp op, + PatternRewriter &rewriter, + Value val) const { + return performReduction(op, rewriter, val, + /*axis=*/withinWarpXAxisReshapedAxis); + } + + Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter, + Value val) const { + assert(op.getOperands().size() == 1 && "Expecting a single operand"); + + auto oldType = cast<RankedTensorType>(val.getType()); + auto dpasEncoding = cast<DpasEncodingAttr>( + cast<RankedTensorType>(op.getOperands().front().getType()) + .getEncoding()); + + constexpr size_t rank = 3; + ArrayRef<int64_t> shape = oldType.getShape(); + std::array<unsigned, rank> sizePerThread{1, dpasEncoding.getExecutionSize(), + 1}; + std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(), + 1, 1}; + std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1, + dpasEncoding.getWarpsPerCTA()[1]}; + std::array<unsigned, rank> order{1, 2, 0}; + CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank); + + auto encoding = rewriter.getAttr<BlockedEncodingAttr>( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType type = + RankedTensorType::get(shape, oldType.getElementType(), encoding); + + return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val); + } + + Value reshapeForFinalReduction(ReduceOp op, PatternRewriter &rewriter, + Value val) const { + auto oldType = cast<RankedTensorType>(val.getType()); + ArrayRef<int64_t> oldShape = oldType.getShape(); + auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding()); + + constexpr size_t rank = 2; + std::array<int64_t, rank> shape{oldShape[0], oldShape[1] * oldShape[2]}; + std::array<unsigned, rank> sizePerThread{1, + oldEncoding.getSizePerThread()[1]}; + std::array<unsigned, rank> threadsPerWarp{ + oldEncoding.getThreadsPerWarp()[0], 1}; + std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0], + oldEncoding.getWarpsPerCTA()[2]}; + std::array<unsigned, rank> order{1, 0}; + CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank); + + auto encoding = rewriter.getAttr<BlockedEncodingAttr>( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType type = + RankedTensorType::get(shape, oldType.getElementType(), encoding); + + // Although this is a NOP, we have to pass allow_reorder=true as static + // analysis will fail to infer it.
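+ // Continuing the 64x64 example from the new optimize-reduction.mlir test,
+ // this reshapes tensor<64x16x2xf32> back to a 2D tensor<64x32xf32> so the
+ // final reduction can run over dimension 1.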
+ return createReshapeForReduction(rewriter, op.getLoc(), type, val); + } + + Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter, + Value val) const { + return performReduction(op, rewriter, val, /*axis=*/preferredReductionAxis); + } + + Value convertToOriginalType(ReduceOp op, PatternRewriter &rewriter, + Value val) const { + return rewriter.create<ConvertLayoutOp>( + op.getLoc(), op.getResult().front().getType(), val); + } +}; + +struct TritonIntelGPUOptimizeReductionLocality final + : impl::TritonIntelGPUOptimizeReductionLocalityBase< + TritonIntelGPUOptimizeReductionLocality> { + using impl::TritonIntelGPUOptimizeReductionLocalityBase< + TritonIntelGPUOptimizeReductionLocality>:: + TritonIntelGPUOptimizeReductionLocalityBase; + + void runOnOperation() final { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + RewritePatternSet patterns(ctx); + patterns.add<DpasOperandPattern>(ctx); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace +} // namespace mlir::triton::gpu::intel