
Commit 053938e

add complex control flow test
1 parent a1cc374 commit 053938e


test/TritonIntelGPU/optimize-block-io-encoding.mlir

Lines changed: 75 additions & 4 deletions
@@ -1,11 +1,11 @@
-// RUN: triton-opt %s -split-input-file --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
-// CHECK: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
-// CHECK: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
-// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
 tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
@@ -65,3 +65,74 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
 tt.return
 }
 }
+
+// -----
+
+// COM: test complex control flow
+// COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice because it lets us test that we can find the MakeTensorPtr op inside the scf.if.
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+// CHECK-LABEL: @matmul_change_block_ptr_in_prologue
+tt.func @matmul_change_block_ptr_in_prologue(%a_base: !tt.ptr<f16>,
+%b_base: !tt.ptr<f16>) {
+%c0_i64 = arith.constant 0 : i64
+%c1_i64 = arith.constant 1 : i64
+%k_tiles = arith.constant 32 : i64
+%true = arith.constant true
+%false = arith.constant false
+
+%zero = arith.constant dense<0.0> : tensor<128x128xf32, #blocked>
+
+// CHECK: %[[A_UNDEF:.*]] = ub.poison : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>
+// CHECK: %[[B_UNDEF:.*]] = ub.poison : !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
+%a_ptr_undef = ub.poison : !tt.ptr<tensor<128x64xf16, #blocked1>>
+%b_ptr_undef = ub.poison : !tt.ptr<tensor<64x128xf16, #blocked2>>
+// CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[A_PTR:.*]] = %[[A_UNDEF]], %[[B_PTR:.*]] = %[[B_UNDEF]])
+scf.for %k = %c0_i64 to %k_tiles step %c1_i64 iter_args(%acc = %zero, %flag = %true, %a_ptr = %a_ptr_undef, %b_ptr = %b_ptr_undef) -> (tensor<128x128xf32, #blocked>, i1, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>) : i64 {
+%do_prologue = "prologue_cond"(%k) : (i64) -> i1
+// CHECK: %[[PTRS:.*]]:2 = scf.if {{.*}} -> (!tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>)
+%cur_a_ptr, %cur_b_ptr = scf.if %do_prologue -> (!tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>) {
+%off_m, %off_n, %off_k = "get_offsets"(%k) : (i64) -> (i32, i32, i32)
+// CHECK: tt.make_tensor_ptr {{.*}} : <tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>
+%next_a_ptr = tt.make_tensor_ptr %a_base, [%k, %k], [%c1_i64, %c1_i64], [%off_m, %off_k] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+// CHECK: tt.make_tensor_ptr {{.*}} : <tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
+%next_b_ptr = tt.make_tensor_ptr %b_base, [%k, %k], [%c1_i64, %c1_i64], [%off_n, %off_k] {order = array<i32: 1, 0>} : <tensor<64x128xf16, #blocked2>>
+// CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
+scf.yield %next_a_ptr, %next_b_ptr : !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>
+} else {
+// CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
+scf.yield %a_ptr, %b_ptr : !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>
+}
+
+// CHECK: %[[A:.*]] = tt.load %[[PTRS]]#0 {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>
+%a = tt.load %cur_a_ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+// CHECK: ttg.convert_layout %[[A]] : tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<128x64xf16, #blocked1>
+// CHECK: %[[B:.*]] = tt.load %[[PTRS]]#1 {{.*}} : !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
+%b = tt.load %cur_b_ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x128xf16, #blocked2>>
+// CHECK: {{.*}} = ttg.convert_layout %[[B]] : tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<64x128xf16, #blocked2>
+%a_dot = ttg.convert_layout %a : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+%b_dot = ttg.convert_layout %b : tensor<64x128xf16, #blocked2> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+%a_dot_dpas = ttg.convert_layout %a_dot : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+%b_dot_dpas = ttg.convert_layout %b_dot : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+%accum = ttg.convert_layout %acc : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
+%c = tt.dot %a_dot_dpas, %b_dot_dpas, %accum, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+%c_out = ttg.convert_layout %c : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+
+%do_epilogue = arith.cmpi eq, %k, %c0_i64 : i64
+%use_acc = arith.select %do_epilogue, %false, %true : i1
+scf.if %do_epilogue {
+"acc_user"(%c_out) : (tensor<128x128xf32, #blocked>) -> ()
+}
+// CHECK: scf.yield {{.*}} : {{.*}}, i1, !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
+scf.yield %c_out, %use_acc, %cur_a_ptr, %cur_b_ptr : tensor<128x128xf32, #blocked>, i1, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>
+}
+
+tt.return
+}
+}
