Commit 93ac1c1

Remove tritongen lowering from tensor-pointer-load-block-2d.mlir (#4525)
Signed-off-by: Whitney Tsang <[email protected]>
Parent: 4b882ad
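
In short: the RUN line drops --convert-tritongen-to-llvm --cse, so the lowering pipeline under test now stops at the TritonGEN dialect, and the FileCheck patterns match triton_gen.2Dblockload ops directly rather than the __spirv_Subgroup2DBlockLoadINTEL calls those ops are later lowered to. A rough sketch of the kind of op the new CHECK-COUNT lines match; the attribute values come from the checks below, while the operand list (base pointer, base width/height/pitch, x/y offsets), the trailing attributes, and the result type are illustrative assumptions:

// Sketch only; the operands, the transpose/vnni_transform attributes, and
// the result type are assumptions for illustration, not text from the test.
%v = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2, transpose = false, vnni_transform = false} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xf16>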

File tree: 1 file changed (+7, -30 lines)

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 7 additions & 30 deletions
@@ -1,6 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --cse | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
 
-// CHECK: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
 tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} {
@@ -57,10 +56,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
 %65 = tt.splat %64 : i32 -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
 %66 = arith.cmpi slt, %38, %65 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
 %67 = tt.broadcast %66 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<128x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
-// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
-// CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK-COUNT-16: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C8]], [[C2]], {{.*}})
+// CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2
 %68 = tt.load %60, %67, %cst_3 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
 %74 = tt.addptr %60, %cst_0 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
 %76 = arith.addi %58, %c1_i32 : i32
@@ -72,7 +68,6 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
 
 // -----
 
-// CHECK: llvm.func spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
 tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} {
@@ -129,11 +124,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
 %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
 %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
 %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
-// CHECK-COUNT-8: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C32]], [[C1]], {{.*}})
+// CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
 %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
 %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
 %76 = arith.addi %58, %c1_i32 : i32
@@ -154,31 +145,17 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
 %arg1: tensor<256x64x!tt.ptr<f16>, #mma_1>,
 %arg2: tensor<128x64x!tt.ptr<f16>, #mma_2>,
 %arg3: tensor<256x64x!tt.ptr<f16>, #mma_2>) {
-// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
-// CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
-
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
+// CHECK-COUNT-4: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
 %0 = tt.load %arg0 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
 
-// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
-
-// CHECK-COUNT-16: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C8]], [[C1]], {{.*}})
+// CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1
 %1 = tt.load %arg1 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma_1>
 
-// CHECK-COUNT-2: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C32]], [[C2]], {{.*}})
+// CHECK-COUNT-2: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2
 %2 = tt.load %arg3 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma_2>
 
 // COM: The data is duplicated in the warps because the warp shape is 32*8=256 larger than the tensor shape 128
-// CHECK-COUNT-2: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C32]], [[C2]], {{.*}})
+// CHECK-COUNT-2: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2
 %3 = tt.load %arg2 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma_2>
 tt.return
 }
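
Lining up each removed pattern with its replacement makes the correspondence explicit: the first four i32 arguments of the __spirv_Subgroup2DBlockLoadINTEL builtin are the element size in bytes, tile width, tile height, and block count, matching the elem_size_in_bits (in bits), tile_width, tile_height, and v_blocks attributes the new checks pin down on triton_gen.2Dblockload. Annotating the first removed pattern under that reading (the mapping is inferred from this diff; the trailing arguments stay elided as in the original {{.*}}):

// [[C2]]  -> element size in bytes (elem_size_in_bits = 16, i.e. 2 bytes)
// [[C16]] -> tile_width = 16
// [[C8]]  -> tile_height = 8
// [[C2]]  -> v_blocks = 2
// CHECK-COUNT-16: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C8]], [[C2]], {{.*}})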
