Commit b01b9bb

validate scf.for loop type changes

1 parent 8703b9a
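
This commit extends the lit test for --tritonintelgpu-optimize-block-io-encoding so that the rewritten tt.make_tensor_ptr results are captured as FileCheck variables and traced through the scf.for iter_args, the loads and advances in the loop body, and the final scf.yield, verifying that the pass updates the loop-carried block-pointer types consistently. It also drops --allocate-shared-memory from the RUN line and the noinline = false attribute from the kernel signature.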

File tree

1 file changed: +10 -8 lines

test/TritonIntelGPU/optimize-block-io-encoding.mlir

Lines changed: 10 additions & 8 deletions
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
+// RUN: triton-opt %s -split-input-file --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
@@ -8,7 +8,7 @@
 // CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) attributes {noinline = false} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
     %c4_i32 = arith.constant 4 : i32
     %c256_i32 = arith.constant 256 : i32
     %c1024_i64 = arith.constant 1024 : i64
@@ -30,17 +30,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar
     %7 = arith.remsi %0, %c64_i32 : i32
     %8 = arith.divsi %7, %4 : i32
     %9 = arith.muli %6, %c256_i32 : i32
-    // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<256x32xf16, #mma>>
+    // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xf16, #mma>>
     %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
     %11 = arith.muli %8, %c256_i32 : i32
-    // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #mma1>>
+    // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #mma1>>
     %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked2>>
+    // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]])
     %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
       %17 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
-      // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #mma>>
+      // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #mma>>
       // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1>
       %18 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked2>>
-      // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #mma1>>
+      // CHECK: %[[B_LOAD:.*]] = tt.load %[[ARG6]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #mma1>>
       // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2>
       %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
       %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
@@ -50,10 +51,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar
       // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2>
       %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
       %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
-      // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #mma>>
+      // CHECK: %[[ADVANCE_A:.*]] = tt.advance {{.*}} : <tensor<256x32xf16, #mma>>
       %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
-      // CHECK: tt.advance {{.*}} : <tensor<32x256xf16, #mma1>>
+      // CHECK: %[[ADVANCE_B:.*]] = tt.advance {{.*}} : <tensor<32x256xf16, #mma1>>
       %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked2>>
+      // CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]]
       scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
     }
     %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
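
For reviewers unfamiliar with the idiom: FileCheck's %[[NAME:.*]] captures the matched SSA name at its definition, and a later %[[NAME]] must match the exact same text. The updated checks therefore prove not only that each op gets the new #mma/#mma1 pointer type, but that the very values produced by tt.make_tensor_ptr are the ones carried through the scf.for iter_args, loaded, advanced, and yielded back. A minimal sketch of the pattern (names here are illustrative, not taken from this test):

// CHECK: %[[PTR:.*]] = tt.make_tensor_ptr
// CHECK: scf.for {{.*}} iter_args(%[[ARG:.*]] = %[[PTR]])
// CHECK: tt.load %[[ARG]]
// CHECK: %[[NEXT:.*]] = tt.advance %[[ARG]]
// CHECK: scf.yield %[[NEXT]]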
