|
1 |
| -// RUN: triton-opt %s -split-input-file --tritonintelgpu-optimize-block-io-encoding | FileCheck %s |
| 1 | +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s |
2 | 2 |
|
3 | 3 | #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
|
4 | 4 | #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
|
5 | 5 | #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
|
6 |
| -// CHECK: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> |
7 |
| -// CHECK: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> |
8 |
| -// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> |
| 6 | +// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> |
| 7 | +// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> |
| 8 | +// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> |
9 | 9 | #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
|
10 | 10 | module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
|
11 | 11 | tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
|
@@ -65,3 +65,74 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
|
65 | 65 | tt.return
|
66 | 66 | }
|
67 | 67 | }
|
| 68 | + |
| 69 | +// ----- |
| 70 | + |
| 71 | +// COM: test complex control flow |
| 72 | +// COM: Note that instead of using tt.advance, we make a new tensor ptr each time. This is useful because it lets us test that the MakeTensorPtr op can be found inside the scf.if. |
| 73 | +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> |
| 74 | +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> |
| 75 | +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> |
| 76 | +// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> |
| 77 | +// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> |
| 78 | +// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> |
| 79 | +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> |
| 80 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { |
| 81 | +// CHECK-LABEL: @matmul_change_block_ptr_in_prologue |
| 82 | +tt.func @matmul_change_block_ptr_in_prologue(%a_base: !tt.ptr<f16>, |
| 83 | + %b_base: !tt.ptr<f16>) { |
| 84 | + %c0_i64 = arith.constant 0 : i64 |
| 85 | + %c1_i64 = arith.constant 1 : i64 |
| 86 | + %k_tiles = arith.constant 32 : i64 |
| 87 | + %true = arith.constant true |
| 88 | + %false = arith.constant false |
| 89 | + |
| 90 | + %zero = arith.constant dense<0.0> : tensor<128x128xf32, #blocked> |
| 91 | + |
| 92 | + // CHECK: %[[A_UNDEF:.*]] = ub.poison : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>> |
| 93 | + // CHECK: %[[B_UNDEF:.*]] = ub.poison : !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>> |
| 94 | + %a_ptr_undef = ub.poison : !tt.ptr<tensor<128x64xf16, #blocked1>> |
| 95 | + %b_ptr_undef = ub.poison : !tt.ptr<tensor<64x128xf16, #blocked2>> |
| 96 | + // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[A_PTR:.*]] = %[[A_UNDEF]], %[[B_PTR:.*]] = %[[B_UNDEF]]) |
| 97 | + scf.for %k = %c0_i64 to %k_tiles step %c1_i64 iter_args(%acc = %zero, %flag = %true, %a_ptr = %a_ptr_undef, %b_ptr = %b_ptr_undef) -> (tensor<128x128xf32, #blocked>, i1, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>) : i64 { |
| 98 | + %do_prologue = "prologue_cond"(%k) : (i64) -> i1 |
| 99 | + // CHECK: %[[PTRS:.*]]:2 = scf.if {{.*}} -> (!tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>) |
| 100 | + %cur_a_ptr, %cur_b_ptr = scf.if %do_prologue -> (!tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>) { |
| 101 | + %off_m, %off_n, %off_k = "get_offsets"(%k) : (i64) -> (i32, i32, i32) |
| 102 | + // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>> |
| 103 | + %next_a_ptr = tt.make_tensor_ptr %a_base, [%k, %k], [%c1_i64, %c1_i64], [%off_m, %off_k] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>> |
| 104 | + // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>> |
| 105 | + %next_b_ptr = tt.make_tensor_ptr %b_base, [%k, %k], [%c1_i64, %c1_i64], [%off_n, %off_k] {order = array<i32: 1, 0>} : <tensor<64x128xf16, #blocked2>> |
| 106 | + // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>> |
| 107 | + scf.yield %next_a_ptr, %next_b_ptr : !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>> |
| 108 | + } else { |
| 109 | + // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>> |
| 110 | + scf.yield %a_ptr, %b_ptr : !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>> |
| 111 | + } |
| 112 | + |
| 113 | + // CHECK: %[[A:.*]] = tt.load %[[PTRS]]#0 {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>> |
| 114 | + %a = tt.load %cur_a_ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>> |
| 115 | + // CHECK: ttg.convert_layout %[[A]] : tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<128x64xf16, #blocked1> |
| 116 | + // CHECK: %[[B:.*]] = tt.load %[[PTRS]]#1 {{.*}} : !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>> |
| 117 | + %b = tt.load %cur_b_ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x128xf16, #blocked2>> |
| 118 | + // CHECK: {{.*}} = ttg.convert_layout %[[B]] : tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<64x128xf16, #blocked2> |
| 119 | + %a_dot = ttg.convert_layout %a : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> |
| 120 | + %b_dot = ttg.convert_layout %b : tensor<64x128xf16, #blocked2> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> |
| 121 | + %a_dot_dpas = ttg.convert_layout %a_dot : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> |
| 122 | + %b_dot_dpas = ttg.convert_layout %b_dot : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> |
| 123 | + %accum = ttg.convert_layout %acc : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> |
| 124 | + %c = tt.dot %a_dot_dpas, %b_dot_dpas, %accum, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> |
| 125 | + %c_out = ttg.convert_layout %c : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked> |
| 126 | + |
| 127 | + %do_epilogue = arith.cmpi eq, %k, %c0_i64 : i64 |
| 128 | + %use_acc = arith.select %do_epilogue, %false, %true : i1 |
| 129 | + scf.if %do_epilogue { |
| 130 | + "acc_user"(%c_out) : (tensor<128x128xf32, #blocked>) -> () |
| 131 | + } |
| 132 | + // CHECK: scf.yield {{.*}} : {{.*}}, i1, !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>> |
| 133 | + scf.yield %c_out, %use_acc, %cur_a_ptr, %cur_b_ptr : tensor<128x128xf32, #blocked>, i1, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>> |
| 134 | + } |
| 135 | + |
| 136 | + tt.return |
| 137 | + } |
| 138 | +} |
0 commit comments