diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir
index 7250d1ad2c..8c7b44192f 100644
--- a/test/TritonIntelGPU/optimize-reduction.mlir
+++ b/test/TritonIntelGPU/optimize-reduction.mlir
@@ -2,39 +2,43 @@
 // Test reduction in a single warp (16x16->16).
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4]}>
+// CHECK: #[[$ATTR_2:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [0, 1, 2, 3]}>
+// CHECK: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+// CHECK: #[[$ATTR_4:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
-// CHECK-DAG: #[[$ATTR_2:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
-// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [3, 4, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}>
-
-// CHECK: tt.func @test_single(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf32, #[[$ATTR_2]]> -> tensor<16x1x1x16x1x1x1xf32, #[[$ATTR_0]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
+// CHECK-LABEL: tt.func @test_single(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_4]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> {
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf32, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
 // CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x1x16x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x1x1x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>>
+// CHECK: }) : (tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_0]]}>>
 // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
 // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x1x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>>) -> tensor<16x1x1x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x1x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_0]]}>}>> -> tensor<16x1x1x16x1xf32, #[[$ATTR_3]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x1x16x1xf32, #[[$ATTR_3]]> -> tensor<16x1x1x16xf32, #[[$ATTR_1]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x1x16xf32, #[[$ATTR_1]]>) -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_1]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_1]]}>> -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>>
+// CHECK: }) : (tensor<16x8x2x1x1x1xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x8x2x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x8x2x1x1xf32, #[[$ATTR_1]]>
+// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x1xf32, #[[$ATTR_1]]> -> tensor<16x16x1x1xf32, #[[$ATTR_2]]>
+// CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: tt.reduce.return %[[VAL_15]] : f32
+// CHECK: }) : (tensor<16x16x1x1xf32, #[[$ATTR_2]]>) -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_2]]}>>
+// CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: tt.reduce.return %[[VAL_19]] : f32
+// CHECK: }) : (tensor<16x1x1xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_2]]}>>) -> tensor<16x1xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>>
+// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_3]]}>>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>
+// CHECK: tt.return %[[VAL_21]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>
 // CHECK: }
   tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
@@ -50,39 +54,43 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 // Test reduction in two warps across the non-reduction dimension (32x16->32).
+// CHECK: #[[$ATTR_5:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 2], order = [0, 1, 2, 3, 4, 5, 6]}>
+// CHECK: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2], order = [0, 1, 2, 3, 4]}>
+// CHECK: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 1, 2], order = [0, 1, 2, 3]}>
+// CHECK: #[[$ATTR_8:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
+// CHECK: #[[$ATTR_9:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
-// CHECK-DAG: #[[$ATTR_5:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
-// CHECK-DAG: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1], order = [3, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1], order = [3, 4, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 2, 1], order = [3, 0, 1, 2]}>
-
-// CHECK: tt.func @test_single_twice(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_5]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf32, #[[$ATTR_5]]> -> tensor<16x1x2x16x1x1x1xf32, #[[$ATTR_3]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
+// CHECK-LABEL: tt.func @test_single_twice(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_9]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_9]]}>> {
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf32, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]>
+// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
 // CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x2x16x1x1x1xf32, #[[$ATTR_3]]>) -> tensor<16x1x2x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>>
+// CHECK: }) : (tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_5]]}>>
 // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
 // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x2x16x1x1xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>>) -> tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_3]]}>}>> -> tensor<16x1x2x16x1xf32, #[[$BLOCKED]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x2x16x1xf32, #[[$BLOCKED]]> -> tensor<16x1x2x16xf32, #[[$ATTR_4]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x2x16xf32, #[[$ATTR_4]]>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_4]]}>> -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>>
+// CHECK: }) : (tensor<16x8x2x1x1x2xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x8x2x1x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x8x2x1x2xf32, #[[$ATTR_6]]>
+// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x2xf32, #[[$ATTR_6]]> -> tensor<16x16x1x2xf32, #[[$ATTR_7]]>
+// CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: tt.reduce.return %[[VAL_15]] : f32
+// CHECK: }) : (tensor<16x16x1x2xf32, #[[$ATTR_7]]>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_7]]}>>
+// CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: tt.reduce.return %[[VAL_19]] : f32
+// CHECK: }) : (tensor<16x1x2xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_7]]}>>) -> tensor<16x2xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>>
+// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_8]]}>>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_9]]}>>
+// CHECK: tt.return %[[VAL_21]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_9]]}>>
 // CHECK: }
   tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
@@ -98,39 +106,43 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
 // Test reduction in two warps across the reduction dimension (16x32->16).
+// CHECK: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}>
+// CHECK: #[[$ATTR_11:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 1, 1], warpsPerCTA = [1, 1, 1, 2, 1], order = [0, 1, 2, 3, 4]}>
+// CHECK: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 1], order = [0, 1, 2, 3]}>
+// CHECK: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
+// CHECK: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}>
-// CHECK-DAG: #[[$ATTR_8:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
-// CHECK-DAG: #[[$ATTR_6:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_7:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2], order = [3, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2], order = [3, 4, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 2], order = [3, 0, 1, 2]}>
-
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: tt.func @test_two_warps_red(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_8]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf32, #[[$ATTR_8]]> -> tensor<16x1x1x16x1x2x1xf32, #[[$ATTR_6]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_14]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> {
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf32, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]>
+// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
 // CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x1x16x1x2x1xf32, #[[$ATTR_6]]>) -> tensor<16x1x1x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>>
+// CHECK: }) : (tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_10]]}>>
 // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
 // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x1x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>>) -> tensor<16x1x1x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x1x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_6]]}>}>> -> tensor<16x1x1x16x2xf32, #[[$BLOCKED]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x1x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x1x32xf32, #[[$ATTR_7]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x1x32xf32, #[[$ATTR_7]]>) -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_7]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_7]]}>> -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x1xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_8]]}>>
+// CHECK: }) : (tensor<16x8x2x2x1x1xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x8x2x2x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x8x2x2x1xf32, #[[$ATTR_11]]>
+// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x1xf32, #[[$ATTR_11]]> -> tensor<16x16x2x1xf32, #[[$ATTR_12]]>
+// CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: tt.reduce.return %[[VAL_15]] : f32
+// CHECK: }) : (tensor<16x16x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x2x1xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_12]]}>>
+// CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: tt.reduce.return %[[VAL_19]] : f32
+// CHECK: }) : (tensor<16x2x1xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_12]]}>>) -> tensor<16x1xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>>
+// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_13]]}>>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
+// CHECK: tt.return %[[VAL_21]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
 // CHECK: }
   tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
@@ -146,39 +158,43 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
 // Test reduction in two warps across both dimensions (32x32->32).
+// CHECK: #[[$ATTR_15:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 2], order = [0, 1, 2, 3, 4, 5, 6]}>
+// CHECK: #[[$ATTR_16:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 1, 1], warpsPerCTA = [1, 1, 1, 2, 2], order = [0, 1, 2, 3, 4]}>
+// CHECK: #[[$ATTR_17:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2], order = [0, 1, 2, 3]}>
+// CHECK: #[[$ATTR_18:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
+// CHECK: #[[$ATTR_19:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}>
-// CHECK-DAG: #[[$ATTR_9:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_10:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 2], order = [3, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_11:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
-// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1, 2], order = [3, 4, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 2, 2], order = [3, 0, 1, 2]}>
-
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: tt.func @test_two_warps(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_11]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf32, #[[$ATTR_11]]> -> tensor<16x1x2x16x1x2x1xf32, #[[$ATTR_9]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_19]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_19]]}>> {
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf32, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]>
+// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
 // CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x2x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<16x1x2x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>>
+// CHECK: }) : (tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_15]]}>>
 // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
 // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x2x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>>) -> tensor<16x1x2x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x2x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>> -> tensor<16x1x2x16x2xf32, #[[$BLOCKED]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x2x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x2x32xf32, #[[$ATTR_10]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x2x32xf32, #[[$ATTR_10]]>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>> -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>>
+// CHECK: }) : (tensor<16x8x2x2x1x2xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x8x2x2x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x8x2x2x2xf32, #[[$ATTR_16]]>
+// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x2xf32, #[[$ATTR_16]]> -> tensor<16x16x2x2xf32, #[[$ATTR_17]]>
+// CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: tt.reduce.return %[[VAL_15]] : f32
+// CHECK: }) : (tensor<16x16x2x2xf32, #[[$ATTR_17]]>) -> tensor<16x2x2xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_17]]}>>
+// CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: tt.reduce.return %[[VAL_19]] : f32
+// CHECK: }) : (tensor<16x2x2xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_17]]}>>) -> tensor<16x2xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>>
+// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_18]]}>>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_19]]}>>
+// CHECK: tt.return %[[VAL_21]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_19]]}>>
 // CHECK: }
   tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
    %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
@@ -188,77 +204,48 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
     tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
   }
-
-// CHECK-LABEL: tt.func @test_two_warps_twice(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x32xf32, #[[$ATTR_11]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x32xf32, #[[$ATTR_11]]> -> tensor<16x1x4x16x1x2x1xf32, #[[$ATTR_9]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
-// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
-// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
-// CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x4x16x1x2x1xf32, #[[$ATTR_9]]>) -> tensor<16x1x4x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>>
-// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
-// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
-// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
-// CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x4x16x1x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>>) -> tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_9]]}>}>> -> tensor<16x1x4x16x2xf32, #[[$BLOCKED]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x4x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x4x32xf32, #[[$ATTR_10]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x4x32xf32, #[[$ATTR_10]]>) -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_10]]}>> -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_11]]}>>
-// CHECK: }
-
-  tt.func @test_two_warps_twice(%arg0: tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
-    %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
-    ^bb0(%arg1: f32, %arg2: f32):
-      %1 = arith.addf %arg1, %arg2 : f32
-      tt.reduce.return %1 : f32
-    }) : (tensor<64x32xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-    tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-  }
 }
 // -----
 // Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension.
-// CHECK-DAG: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}>
-// CHECK-DAG: #[[$ATTR_12:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-// CHECK-DAG: #[[$ATTR_13:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 2], order = [3, 4, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}>
-
+// CHECK: #[[$ATTR_20:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8, 2, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 4], order = [0, 1, 2, 3, 4, 5, 6]}>
+// CHECK: #[[$ATTR_21:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 1, 1], warpsPerCTA = [1, 1, 1, 2, 4], order = [0, 1, 2, 3, 4]}>
+// CHECK: #[[$ATTR_22:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}>
+// CHECK: #[[$ATTR_23:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}>
+// CHECK: #[[$ATTR_24:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
-// CHECK: tt.func @test(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_14]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf32, #[[$ATTR_14]]> -> tensor<16x1x4x16x2x2x1xf32, #[[$ATTR_12]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
+// CHECK-LABEL: tt.func @test(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_24]]>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_24]]}>> {
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf32, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]>
+// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32
 // CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x4x16x2x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x1x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>>
+// CHECK: }) : (tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_20]]}>>
 // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
 // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>>) -> tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>> -> tensor<16x1x4x16x2xf32, #[[$BLOCKED]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x4x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x4x32xf32, #[[$ATTR_13]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x4x32xf32, #[[$ATTR_13]]>) -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>> -> tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
+// CHECK: }) : (tensor<16x8x2x2x1x4xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x8x2x2x4xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x8x2x2x4xf32, #[[$ATTR_21]]>
+// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x4xf32, #[[$ATTR_21]]> -> tensor<16x16x2x4xf32, #[[$ATTR_22]]>
+// CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: tt.reduce.return %[[VAL_15]] : f32
+// CHECK: }) : (tensor<16x16x2x4xf32, #[[$ATTR_22]]>) -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_22]]}>>
+// CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: tt.reduce.return %[[VAL_19]] : f32
+// CHECK: }) : (tensor<16x2x4xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_22]]}>>) -> tensor<16x4xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>>
+// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x4xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_23]]}>>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<64xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_24]]}>>
+// CHECK: tt.return %[[VAL_21]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_24]]}>>
 // CHECK: }
   tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
@@ -268,77 +255,48 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
     tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
   }
-
-// CHECK: tt.func @test_repeat_layout(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32, #[[$ATTR_14]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x128xf32, #[[$ATTR_14]]> -> tensor<16x1x8x16x2x2x2xf32, #[[$ATTR_12]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
-// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
-// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32
-// CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x1x8x16x2x2x2xf32, #[[$ATTR_12]]>) -> tensor<16x1x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>>
-// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
-// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
-// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32
-// CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x1x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>>) -> tensor<16x1x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x1x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$ATTR_12]]}>}>> -> tensor<16x1x8x16x2xf32, #[[$BLOCKED]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x1x8x16x2xf32, #[[$BLOCKED]]> -> tensor<16x1x8x32xf32, #[[$ATTR_13]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x1x8x32xf32, #[[$ATTR_13]]>) -> tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$ATTR_13]]}>> -> tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x1x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED1]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_14]]}>>
-// CHECK: }
-
-  tt.func @test_repeat_layout(%arg0: tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
-    %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
-    ^bb0(%arg1: f32, %arg2: f32):
-      %1 = arith.maxnumf %arg1, %arg2 : f32
-      tt.reduce.return %1 : f32
-    }) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-    tt.return %0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-  }
 }
 // -----
 // Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension with repCluster[0] = 4.
-// CHECK-DAG: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2], A = [32, 8], B = [8, 32], C = [32, 32]}>
-// CHECK-DAG: #[[$BLOCKED_EW:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1, 2, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED_RED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED_TRANS:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 4, 1, 2], order = [3, 4, 0, 1, 2]}>
-// CHECK-DAG: #[[$BLOCKED_FINAL:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 4, 2], order = [3, 0, 1, 2]}>
-
+// CHECK: #[[$ATTR_25:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8, 2, 4, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 4], order = [0, 1, 2, 3, 4, 5, 6]}>
+// CHECK: #[[$ATTR_26:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 2, 1, 1, 1], threadsPerWarp = [1, 4, 4, 1, 1], warpsPerCTA = [1, 1, 1, 2, 4], order = [0, 1, 2, 3, 4]}>
+// CHECK: #[[$ATTR_27:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 2, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}>
+// CHECK: #[[$ATTR_28:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}>
+// CHECK: #[[$ATTR_29:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2], A = [32, 8], B = [8, 32], C = [32, 32]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
-// CHECK: tt.func @test(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$DPAS]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$DPAS]]> -> tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
+// CHECK-LABEL: tt.func @test(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$ATTR_29]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_29]]}>> {
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$ATTR_29]]> -> tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]>
+// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32
 // CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]>) -> tensor<16x2x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>
+// CHECK: }) : (tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]>) -> tensor<16x8x4x2x1x4xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_25]]}>>
 // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
 // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
 // CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x2x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>) -> tensor<16x2x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x2x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> -> tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]>) -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
+// CHECK: }) : (tensor<16x8x4x2x1x4xf32, #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_25]]}>>) -> tensor<16x8x4x2x4xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x8x4x2x4xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> -> tensor<16x8x4x2x4xf32, #[[$ATTR_26]]>
+// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x4x2x4xf32, #[[$ATTR_26]]> -> tensor<16x32x2x4xf32, #[[$ATTR_27]]>
+// CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: tt.reduce.return %[[VAL_15]] : f32
+// CHECK: }) : (tensor<16x32x2x4xf32, #[[$ATTR_27]]>) -> tensor<32x2x4xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_27]]}>>
+// CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: tt.reduce.return %[[VAL_19]] : f32
+// CHECK: }) : (tensor<32x2x4xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_27]]}>>) -> tensor<32x4xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_27]]}>}>>
+// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<32x4xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_27]]}>}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_28]]}>>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #[[$ATTR_28]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_29]]}>>
+// CHECK: tt.return %[[VAL_21]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_29]]}>>
 // CHECK: }
   tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
     %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
@@ -348,37 +306,4 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
     tt.return %0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
   }
-
-// CHECK: tt.func @test_repeat_layout(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x64xf32, #[[$DPAS]]>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<256x64xf32, #[[$DPAS]]> -> tensor<16x2x8x16x2x2x1xf32, #[[$BLOCKED_EW]]>
-// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
-// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
-// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32
-// CHECK: tt.reduce.return %[[VAL_5]] : f32
-// CHECK: }) : (tensor<16x2x8x16x2x2x1xf32, #[[$BLOCKED_EW]]>) -> tensor<16x2x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>
-// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({
-// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
-// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32
-// CHECK: tt.reduce.return %[[VAL_9]] : f32
-// CHECK: }) : (tensor<16x2x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>) -> tensor<16x2x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>>
-// CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x2x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> -> tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]>
-// CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
-// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
-// CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK: tt.reduce.return %[[VAL_14]] : f32
-// CHECK: }) : (tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]>) -> tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>>
-// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> -> tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
-// CHECK: tt.return %[[VAL_16]] : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
-// CHECK: }
-
-  tt.func @test_repeat_layout(%arg0: tensor<256x64xf32, #mma>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
-    %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
-    ^bb0(%arg1: f32, %arg2: f32):
-      %1 = arith.maxnumf %arg1, %arg2 : f32
-      tt.reduce.return %1 : f32
-    }) : (tensor<256x64xf32, #mma>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-    tt.return %0 : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-  }
 }
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index 1d81bc4741..f6aea2e8fa 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -310,44 +310,49 @@ def TritonIntelGPUOptimizeReductionLocality
 `triton_gpu.convert_layout` operations, e.g.:
 ```mlir
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2]}>
-tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+tt.func @test(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
   %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
   ^bb0(%arg1: f32, %arg2: f32):
     %1 = arith.addf %arg1, %arg2 : f32
     tt.reduce.return %1 : f32
-  }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-  tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
 }
 ```
 Is converted to:
 ```mlir
-#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [3, 4, 0, 1, 2]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}>
-#blocked3 = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16], warpsPerCTA = [1, 1, 1, 1], order = [3, 0, 1, 2]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 2, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [0, 1, 2, 3]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
-tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
-  %0 = tt.reshape %arg0 allow_reorder efficient_layout : tensor<32x32xf32, #mma> -> tensor<16x1x2x16x2x1x1xf32, #blocked>
-  %1 = "tt.reduce"(%0) <{axis = 6 : i32}> ({
+tt.func @test(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+  %0 = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x32xf32, #mma> -> tensor<16x8x2x2x1x1x1xf32, #blocked>
+  %1 = "tt.reduce"(%0) <{axis = 2 : i32}> ({
   ^bb0(%arg1: f32, %arg2: f32):
-    %8 = arith.addf %arg1, %arg2 : f32
-    tt.reduce.return %8 : f32
-  }) : (tensor<16x1x2x16x2x1x1xf32, #blocked>) -> tensor<16x1x2x16x2x1xf32, #triton_gpu.slice<{dim = 6, parent = #blocked}>>
+    %9 = arith.addf %arg1, %arg2 : f32
+    tt.reduce.return %9 : f32
+  }) : (tensor<16x8x2x2x1x1x1xf32, #blocked>) -> tensor<16x8x2x1x1x1xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>
   %2 = "tt.reduce"(%1) <{axis = 4 : i32}> ({
   ^bb0(%arg1: f32, %arg2: f32):
-    %8 = arith.addf %arg1, %arg2 : f32
-    tt.reduce.return %8 : f32
-  }) : (tensor<16x1x2x16x2x1xf32, #triton_gpu.slice<{dim = 6, parent = #blocked}>>) -> tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #blocked}>}>>
-  %3 = triton_gpu.convert_layout %2 : tensor<16x1x2x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #blocked}>}>> -> tensor<16x1x2x16x1xf32, #blocked1>
-  %4 = tt.reshape %3 allow_reorder efficient_layout : tensor<16x1x2x16x1xf32, #blocked1> -> tensor<16x1x2x16xf32, #blocked2>
-  %5 = "tt.reduce"(%4) <{axis = 3 : i32}> ({
+    %9 = arith.addf %arg1, %arg2 : f32
+    tt.reduce.return %9 : f32
+  }) : (tensor<16x8x2x1x1x1xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>) -> tensor<16x8x2x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>>
+  %3 = triton_gpu.convert_layout %2 : tensor<16x8x2x1x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> -> tensor<16x8x2x1x1xf32, #blocked1>
+  %4 = tt.reshape %3 allow_reorder efficient_layout : tensor<16x8x2x1x1xf32, #blocked1> -> tensor<16x16x1x1xf32, #blocked2>
+  %5 = "tt.reduce"(%4) <{axis = 0 : i32}> ({
   ^bb0(%arg1: f32, %arg2: f32):
-    %8 = arith.addf %arg1, %arg2 : f32
-    tt.reduce.return %8 : f32
-  }) : (tensor<16x1x2x16xf32, #blocked2>) -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked2}>>
-  %6 = triton_gpu.convert_layout %5 : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked2}>> -> tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked3}>>
-  %7 = tt.reshape %6 allow_reorder efficient_layout : tensor<16x1x2xf32, #triton_gpu.slice<{dim = 3, parent = #blocked3}>> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
-  tt.return %7 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %9 = arith.addf %arg1, %arg2 : f32
+    tt.reduce.return %9 : f32
+  }) : (tensor<16x16x1x1xf32, #blocked2>) -> tensor<16x1x1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+  %6 = "tt.reduce"(%5) <{axis = 1 : i32}> ({
+  ^bb0(%arg1: f32, %arg2: f32):
+    %9 = arith.addf %arg1, %arg2 : f32
+    tt.reduce.return %9 : f32
+  }) : (tensor<16x1x1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<16x1xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #blocked2}>}>>
+  %7 = tt.reshape %6 allow_reorder efficient_layout : tensor<16x1xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 0, parent = #blocked2}>}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
+  %8 = triton_gpu.convert_layout %7 : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  tt.return %8 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
 }
 ```
 The `tt.reshape` operation is a NOP so that the following `tt.reduce`
@@ -355,10 +360,14 @@
 `triton_gpu.convert_layout` performs the actual data movement
to prevent reductions within the sub-group, so that the following `tt.reduce`
 performs a reduction within the work-item again. Finally, we convert back to the
-  expected layout.
+  original type. Note the order of the operations used to go back to the
+  original type is important: first reshape to the original shape, setting an
+  anchor for the layout conversion removal pass, and then convert to the
+  original layout.

 Note this pass only supports `triton_intel_gpu.dpas` input layouts at the
 moment, but it should be easy to extend.
+
+ See the pass implementation for more detailed documentation.
 }];

 let dependentDialects = ["mlir::triton::TritonDialect",
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 914e851b70..1a023d8432 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -53,7 +53,7 @@ namespace {
  /// <-------------------------------------------------------------------------------->
  ///                                      repCluster[1]
  ///                          <----------------------------------->
- ///                               execution size
+ ///                               executionSize
  ///                          <---------------->
  ///  ^                  ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
  ///  |                  | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
@@ -66,48 +66,105 @@ namespace {
  ///  | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
  ///  v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
  /// ```
- /// Blocked (#triton_gpu.blocked<{sizePerThread = [executionSize, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [3, 4, 5, 6, 0, 1, 2]}>):
+ /// - Shape: [executionSize,
+ ///           repeatCount,
+ ///           repCluster[1],
+ ///           repCluster[0],
+ ///           warpsPerCTA[1],
+ ///           oldShape[1] / (executionSize * repCluster[1] * warpsPerCTA[1]),
+ ///           warpsPerCTA[0]]
+ /// - Encoding: `#triton_gpu.blocked<{
+ ///     sizePerThread = [1, repeatCount, repCluster[1], repCluster[0], 1, oldShape[1] / (executionSize * repCluster[1] * warpsPerCTA[1]), 1],
+ ///     threadsPerWarp = [executionSize, 1, 1, 1, 1, 1, 1],
+ ///     warpsPerCTA = [1, 1, 1, 1, warpsPerCTA[1], 1, warpsPerCTA[0]],
+ ///     order = [0, 1, 2, 3, 4, 5, 6]}>`.
+ ///
+ /// Notes:
+ /// - The implicit [1, 0] order translates to taking elements from the
+ ///   original encoding referring to the X and Y dimensions alternately when
+ ///   building the block layout.
+ /// - Dimensions 1, 3 and 6 refer to the original dimension 0
+ /// - Dimensions 0, 2, 4 and 5 refer to the original dimension 1
+ /// - Order is preserved
+ /// - We enforce repeatCount * repCluster[0] * warpsPerCTA[0] = oldShape[0]
 /// ```
- ///                                                   warpsPerCTA[5]
- ///                    <------------------------------------------------------------------------------->
- ///                                 getShape()[4]
- ///                          <---------------------------------->
- ///                               threadsPerWarp[3]
- ///                          <---------------->
- ///                    ^                  ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
- ///                    |                  | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///                    | sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///                    |                  | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///                    |                  v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///                    | ..................................................................................|
- ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[2]
- ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///   getShape()[1]    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
- ///                    v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                                                   sizePerThread[5]
+ ///                    <----------------------------------------------------------------------------------
+ ///                                                    warpsPerCTA[4]
+ ///                    <------------------------------------------------------------------------------->
+ ///                                      sizePerThread[2]
+ ///                          <---------------------------------->
+ ///                              threadsPerWarp[0]
+ ///                          <---------------->
+ ///                    ^                  ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
+ ///                    |                  | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                    | sizePerThread[1] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                    |                  | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                    |                  v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                    | ..................................................................................|
+ ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[6]
+ ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///   sizePerThread[3] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                    | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
+ ///                    v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
 /// ```
- /// So we can reduce on dimensions 6 and 4 to get to:
+ /// So we can reduce on dimensions 2 and 4 (originally dimension 5, which
+ /// becomes dimension 4 once dimension 2 has been squashed) to get to:
 /// ```
- ///                                   warpsPerCTA[3]
- ///                          <------------------------------------->
- ///                               threadsPerWarp[3]
- ///                          <---------------->
- ///                    ^                  ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
- ///                    |                  | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///                    | sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///                    |                  | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///                    |                  v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///                    | .......................................|
- ///                    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[2]
- ///                    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///   getShape()[1]    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///                    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
- ///                    v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                                   warpsPerCTA[3]
+ ///                          <------------------------------------->
+ ///                               threadsPerWarp[0]
+ ///                          <---------------->
+ ///                    ^                  ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
+ ///                    |                  | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                    | sizePerThread[1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                    |                  | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                    |                  v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                    | .......................................|
+ ///                    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[4]
+ ///                    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///   sizePerThread[2] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                    | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
+ ///                    v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
 /// ```
 ///
 /// Now on with step 2: After reshaping and layout conversion, we can get to
- /// the actual layout optimization we wanted to achieve:
- /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>):
+ /// the actual layout optimization we wanted to achieve by a simple layout
+ /// conversion to:
+ /// - Shape (unchanged): [executionSize,
+ ///                       repeatCount,
+ ///                       repCluster[0],
+ ///                       warpsPerCTA[1],
+ ///                       warpsPerCTA[0]]
+ /// - Encoding: `#triton_gpu.blocked<{
+ ///     sizePerThread = [executionSize, repeatCount * repCluster[0] / executionSize, 1, 1, 1],
+ ///     threadsPerWarp = [1, executionSize / repCluster[0], repCluster[0], 1, 1],
+ ///     warpsPerCTA = [1, 1, 1, warpsPerCTA[1], warpsPerCTA[0]],
+ ///     order = [0, 1, 2, 3, 4]}>`.
+ /// Notes:
+ /// - The layout conversion performs a sub-group transpose by setting
+ ///   sizePerThread[0] to executionSize
+ /// - sizePerThread[2] = 1 as we know
+ ///   executionSize <= repeatCount * repCluster[0] (pattern application
+ ///   condition: repeatCount * repCluster[0] % executionSize == 0).
+ ///   We could say elements in dimension 2 are moved to dimension 1 to
+ ///   simplify handling.
+ /// - sizePerThread[1] is set so that the total number of elements per
+ ///   thread is preserved
+ /// - Dimensions 1, 2 and 4 refer to the original dimension 0
+ /// - Dimensions 0 and 3 refer to the original dimension 1
+ /// - Order is preserved
+ ///
+ /// Note at this point the transpose has already taken place. We just need a
+ /// reshape to act as an anchor for it (see the layout conversion elimination
+ /// pass):
+ /// - Shape (same number of elements; dimensions 1 and 2 are merged):
+ ///          [executionSize,
+ ///           repeatCount * repCluster[0],
+ ///           warpsPerCTA[1],
+ ///           warpsPerCTA[0]]
+ /// - Encoding: `#triton_gpu.blocked<{
+ ///     sizePerThread = [executionSize, repeatCount * repCluster[0] / executionSize, 1, 1],
+ ///     threadsPerWarp = [1, executionSize, 1, 1],
+ ///     warpsPerCTA = [1, 1, warpsPerCTA[1], warpsPerCTA[0]],
+ ///     order = [0, 1, 2, 3]}>`.
 /// ```
 ///                   warpsPerCTA[3]
 ///              <------------------------------------>
@@ -118,39 +175,44 @@
 ///        threadsPerWarp[0]    |    t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3    | warpsPerCTA[2]
 ///                             |    t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4    |
 /// ```
- /// And on with step 3, after reducing on dimension 3, we'd get:
- /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>):
- /// Sliced (#triton_gpu.sliced<{dim = 3, parent = #blocked}>)
+ /// Notes:
+ /// - The reshape simplifies the tensor and provides a layout anchor
+ /// - We can get shape, sizePerThread, threadsPerWarp and warpsPerCTA for
+ ///   dimension 1 by multiplying those values for dimensions 1 and 2 of the
+ ///   old tensor.
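+ ///   (For instance, with repeatCount = 8, repCluster[0] = 2 and a single
+ ///   warp per dimension, as in the pass description example, the
+ ///   16x8x2x1x1 tensor reshapes to 16x16x1x1: dimension 1 becomes
+ ///   8 * 2 = 16.)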
+ /// - Dimensions 1 and 3 refer to the original dimension 0
+ /// - Dimensions 0 and 2 refer to the original dimension 1
+ /// - Order is preserved
+ /// And on with step 3, after reducing on dimensions 0 and 1 (originally
+ /// dimension 2, which becomes dimension 1 once dimension 0 has been
+ /// squashed), we'd get:
 /// ```
 ///                             ^    t0    ^
 ///                             |    t1    |
- ///        threadsPerWarp[0]   |    t2    | warpsPerCTA[2]
+ ///        threadsPerWarp[0]   |    t2    | warpsPerCTA[1]
 ///                             |    t3    |
 /// ```
- /// Reshaping from this layout to the final state would not work, as we would
- /// end up modifying the number of elements per work-item (not allowed in
- /// `reshape`).
- ///
- /// In order to avoid that, we can just convert the layout to a sliced layout
- /// equivalent to the end product we want to achieve:
- /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>)
- /// Sliced (#triton_gpu.sliced<{dim = 3, parent = #blocked}>)
+ /// Now we can reshape to provide an anchor and go back to the original
+ /// result shape (back to a 1D tensor):
 /// ```
 ///                             ^    t0    ^
- ///                             |    t0    |
- ///        threadsPerWarp[0]   |    t0    | warpsPerCTA[2]
- ///                             |    t0    |
+ ///                             |    t1    |
+ ///        threadsPerWarp[0]   |    t2    | warpsPerCTA[0]
+ ///                             |    t3    |
 /// ```
- /// And just reshape to the final type using a NOP `reshape`.
+ /// And finally untranspose with a layout conversion back to the original
+ /// layout.
 // clang-format on
 struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;

+  // Original reduction
   static constexpr int preferredNonReductionAxis = 0;
-  static constexpr int finalReductionAxis = 3;
   static constexpr int preferredReductionAxis = 1;
-  static constexpr int repCountReshapedAxis = 4;
-  static constexpr int withinWarpXAxisReshapedAxis = 6;
+
+  // Intermediate reductions
+  static constexpr int finalElementwiseReductionAxis = 0;
+  static constexpr int finalWarpsReductionAxis = 1;
+  static constexpr int innerElementwiseReductionAxis = 2;
+  static constexpr int outerElementwiseReductionAxis = 4;

   LogicalResult matchAndRewrite(ReduceOp op,
                                 PatternRewriter &rewriter) const final {
@@ -183,107 +245,112 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
         0)
       return failure();

+    // The encoding must cover the whole Y axis.
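+    // That is, repeatCount * repCluster[0] * warpsPerCTA[0] (the dimensions
+    // that the reshape below maps to the original dimension 0) must multiply
+    // out to oldShape[0]; otherwise the pattern does not apply.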
+    if (encoding.getRepeatCount() * encoding.getRepCluster()[0] *
+            encoding.getWarpsPerCTA()[0] !=
+        type.getShape()[0])
+      return failure();
+
     LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");

-    operand = reshapeForElementWiseReduction(op, rewriter);
+    operand = reshapeForElementWiseReduction(op, rewriter, encoding);

     LLVM_DEBUG(llvm::dbgs()
                << "Reshaped for elementwise reduction: " << operand << "\n");

-    operand = performElementWiseReductionAcrossRepCounts(op, rewriter, operand);
-
-    LLVM_DEBUG(llvm::dbgs()
-               << "Performed elementwise reduction across repCount: " << operand
-               << "\n");
-
-    operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
+    operand = performInitialElementWiseReductions(op, rewriter, operand);

-    LLVM_DEBUG(llvm::dbgs()
-               << "Performed elementwise reduction within repCount: " << operand
-               << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "Performed initial elementwise reductions: "
+                            << operand << "\n");

-    operand = convertLayoutForFinalReduction(op, rewriter, operand);
+    operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding);

     LLVM_DEBUG(llvm::dbgs()
                << "Converted layout for final reduction: " << operand << "\n");

-    operand = reshapeForFinalReduction(op, rewriter, operand);
+    operand = reshapeForFinalReduction(op, rewriter, operand, encoding);

     LLVM_DEBUG(llvm::dbgs() << "Reshaped for final reduction: " << operand
                             << "\n");

-    operand = performFinalReduction(op, rewriter, operand);
+    operand = performFinalElementwiseReduction(op, rewriter, operand);

     LLVM_DEBUG(llvm::dbgs()
-               << "Final reduction performed: " << operand << "\n");
+               << "Final elementwise reduction performed: " << operand << "\n");

-    operand = convertLayoutToOriginalType(op, rewriter, operand);
+    operand = performFinalAcrossWarpsReduction(op, rewriter, operand);

-    LLVM_DEBUG(llvm::dbgs()
-               << "Converted layout to original type: " << operand << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "Final across-warps reduction performed: "
+                            << operand << "\n");

-    operand = reshapeToOriginalType(op, rewriter, operand);
+    operand = reshapeToOriginalType(op, rewriter, operand, encoding);

     LLVM_DEBUG(llvm::dbgs() << "Reshaped to original type: " << operand
                             << "\n");

+    operand = convertLayoutToOriginalType(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Converted layout to original type: " << operand << "\n");
+
     rewriter.replaceOp(op, operand);
     return success();
   }

 private:
-  Value reshapeForElementWiseReduction(ReduceOp op,
-                                       PatternRewriter &rewriter) const {
+  Value reshapeForElementWiseReduction(ReduceOp op, PatternRewriter &rewriter,
+                                       DpasEncodingAttr dpasEncoding) const {
     assert(op.getOperands().size() == 1 && "Expecting a single operand");

     Value val = op.getOperands().front();
     auto oldType = cast<RankedTensorType>(val.getType());
     ArrayRef<int64_t> oldShape = oldType.getShape();
-    auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());

     constexpr size_t rank = 7;
     std::array<int64_t, rank> shape{
-        // Y axis contiguous elements handled by a single thread.
-        oldEncoding.getExecutionSize(),
-        // Y axis contiguous elements handled by a single thread.
-        // Needs to be split from previous dimension to perform transpose.
-        (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]) /
-            oldEncoding.getExecutionSize(),
-        // Y axis rest.
-        oldShape[0] /
-            (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]),
-        // X axis contiguous elements distributed within individual threads in a
-        // warp.
-        oldEncoding.getExecutionSize(),
-        // X axis contiguous elements distributed within a warp.
-        oldEncoding.getRepCluster()[1],
-        // X axis number of warps.
-        oldEncoding.getWarpsPerCTA()[1],
-        // X axis rest.
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount(),
+        dpasEncoding.getRepCluster()[1],
+        dpasEncoding.getRepCluster()[0],
+        dpasEncoding.getWarpsPerCTA()[1],
         oldShape[1] /
-            (oldEncoding.getExecutionSize() * oldEncoding.getRepCluster()[1] *
-             oldEncoding.getWarpsPerCTA()[1])};
+            (dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
+             dpasEncoding.getWarpsPerCTA()[1]),
+        dpasEncoding.getWarpsPerCTA()[0]};
     std::array<unsigned, rank> sizePerThread{
-        oldEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1};
-    std::array<unsigned, rank> threadsPerWarp{
-        1, 1, 1, oldEncoding.getExecutionSize(), 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{
-        1, 1, oldEncoding.getWarpsPerCTA()[0],
-        1, 1, oldEncoding.getWarpsPerCTA()[1],
+        1,
+        dpasEncoding.getRepeatCount(),
+        dpasEncoding.getRepCluster()[1],
+        dpasEncoding.getRepCluster()[0],
+        1,
+        static_cast<unsigned>(oldShape[1]) /
+            (dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
+             dpasEncoding.getWarpsPerCTA()[1]),
         1};
-    std::array<unsigned, rank> order{3, 4, 5, 6, 0, 1, 2};
+    std::array<unsigned, rank> threadsPerWarp{
+        dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1,
+                                           1,
+                                           1,
+                                           1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           1,
+                                           dpasEncoding.getWarpsPerCTA()[0]};
+    constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5, 6};
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);

-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    RankedTensorType::Builder type(oldType);
+    type.setShape(shape);
+    type.setEncoding(encoding);

     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
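     // That is, the reshape moves no data: each work-item keeps the elements
     // it already holds; only the logical shape and encoding change.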
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                       /*efficient_layout=*/true);
   }
@@ -300,124 +367,129 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     return newOp.getResult().front();
   }

-  Value performElementWiseReductionWithinRepCount(ReduceOp op,
-                                                  PatternRewriter &rewriter,
-                                                  Value val) const {
-    return performReduction(op, rewriter, val, /*axis=*/repCountReshapedAxis);
-  }
-
-  Value performElementWiseReductionAcrossRepCounts(ReduceOp op,
-                                                   PatternRewriter &rewriter,
-                                                   Value val) const {
-    return performReduction(op, rewriter, val,
-                            /*axis=*/withinWarpXAxisReshapedAxis);
+  Value performInitialElementWiseReductions(ReduceOp op,
+                                            PatternRewriter &rewriter,
+                                            Value val) const {
+    return performReduction(
+        op, rewriter,
+        performReduction(op, rewriter, val,
+                         /*axis=*/innerElementwiseReductionAxis),
+        outerElementwiseReductionAxis);
   }

   Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                                       Value val) const {
-    assert(op.getOperands().size() == 1 && "Expecting a single operand");
-
+                                       Value val,
+                                       DpasEncodingAttr dpasEncoding) const {
     auto oldType = cast<RankedTensorType>(val.getType());
-    auto dpasEncoding = cast<DpasEncodingAttr>(
-        cast<RankedTensorType>(op.getOperands().front().getType())
-            .getEncoding());
+    RankedTensorType::Builder type(oldType);

     constexpr size_t rank = 5;
-    ArrayRef<int64_t> shape = oldType.getShape();
     std::array<unsigned, rank> sizePerThread{
-        1, 1, 1, dpasEncoding.getExecutionSize(), 1};
-    std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
-                                              1, 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{1, 1,
-                                           dpasEncoding.getWarpsPerCTA()[0], 1,
-                                           dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rank> order{3, 4, 0, 1, 2};
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
+            dpasEncoding.getExecutionSize(),
+        1, 1, 1};
+    std::array<unsigned, rank> threadsPerWarp{
+        1, dpasEncoding.getExecutionSize() / dpasEncoding.getRepCluster()[0],
+        dpasEncoding.getRepCluster()[0], 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1, 1, 1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0]};
+    constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4};
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);

-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    type.setEncoding(encoding);

-    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), static_cast<RankedTensorType>(type), val);
   }

   Value reshapeForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                                 Value val) const {
+                                 Value val,
+                                 DpasEncodingAttr dpasEncoding) const {
     auto oldType = cast<RankedTensorType>(val.getType());
     ArrayRef<int64_t> oldShape = oldType.getShape();
-    auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());

     constexpr size_t rank = 4;
-    std::array<int64_t, rank> shape{oldShape[0], oldShape[1], oldShape[2],
-                                    oldShape[3] * oldShape[4]};
-    std::array<unsigned, rank> sizePerThread{1, 1, 1,
-                                             oldEncoding.getSizePerThread()[3]};
+    std::array<int64_t, rank> shape{
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0],
+        dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]};
+    std::array<unsigned, rank> sizePerThread{
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
+            dpasEncoding.getExecutionSize(),
+        1, 1};
     std::array<unsigned, rank> threadsPerWarp{
-        oldEncoding.getThreadsPerWarp()[0], 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{
-        1, 1, oldEncoding.getWarpsPerCTA()[2],
-        oldEncoding.getWarpsPerCTA()[4]};
-    std::array<unsigned, rank> order{3, 0, 1, 2};
+        1, dpasEncoding.getExecutionSize(), 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1, 1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0]};
+    constexpr std::array<unsigned, rank> order{0, 1, 2, 3};
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);

-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    RankedTensorType::Builder type(oldType);
+    type.setShape(shape);
+    type.setEncoding(encoding);

     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                       /*efficient_layout=*/true);
   }

-  Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                              Value val) const {
-    return performReduction(op, rewriter, val, /*axis=*/finalReductionAxis);
+  Value performFinalElementwiseReduction(ReduceOp op, PatternRewriter &rewriter,
+                                         Value val) const {
+    return performReduction(op, rewriter, val,
+                            /*axis=*/finalElementwiseReductionAxis);
   }

-  Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
-                                    Value val) const {
-    auto oldType = cast<RankedTensorType>(val.getType());
-    auto dpasEncoding = cast<DpasEncodingAttr>(
-        cast<RankedTensorType>(op.getOperands().front().getType())
-            .getEncoding());
-
-    // Only Y axis (X axis has already been reduced)
-    constexpr size_t rankBeforeLastReduction = 4;
-    ArrayRef<int64_t> shape = oldType.getShape();
-    std::array<unsigned, rankBeforeLastReduction> sizePerThread{
-        dpasEncoding.getExecutionSize(), 1, 1, 1};
-    std::array<unsigned, rankBeforeLastReduction> threadsPerWarp{
-        1, 1, 1, dpasEncoding.getExecutionSize()};
-    std::array<unsigned, rankBeforeLastReduction> warpsPerCTA{
-        1, 1, dpasEncoding.getWarpsPerCTA()[0],
-        dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rankBeforeLastReduction> order{3, 0, 1, 2};
-    CTALayoutAttr ctaLayout =
-        CTALayoutAttr::getDefault(getContext(), rankBeforeLastReduction);
-
-    auto blockedEncoding = rewriter.getAttr<BlockedEncodingAttr>(
-        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
-    auto encoding = rewriter.getAttr<SliceEncodingAttr>(finalReductionAxis,
-                                                        blockedEncoding);
+  Value performFinalAcrossWarpsReduction(ReduceOp op, PatternRewriter &rewriter,
+                                         Value val) const {
+    return performReduction(op, rewriter, val,
+                            /*axis=*/finalWarpsReductionAxis);
+  }

-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+  Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter, Value val,
+                              DpasEncodingAttr dpasEncoding) const {
+    RankedTensorType::Builder type(
+        cast<RankedTensorType>(op.getResult().front().getType()));

-    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
-  }
+    constexpr size_t rank = 2;
+    std::array<unsigned, rank> sizePerThread{
+        1, dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
+               dpasEncoding.getExecutionSize()};
+    std::array<unsigned, rank> threadsPerWarp{1,
+                                              dpasEncoding.getExecutionSize()};
+    std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0]};
+    constexpr std::array<unsigned, rank> order{0, 1};
+    CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
+
+    auto parentEncoding = rewriter.getAttr<BlockedEncodingAttr>(
+        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+
+    type.setEncoding(parentEncoding.squeeze(0));

-  Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter,
-                              Value val) const {
     return rewriter.create<ReshapeOp>(op.getLoc(),
-                                      op.getResult().front().getType(), val,
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                      /*efficient_layout=*/true);
   }
+
+  Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
+                                    Value val) const {
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), op.getResult().front().getType(), val);
+  }
 };

 struct TritonIntelGPUOptimizeReductionLocality final