@@ -1,6 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --cse | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
 
-// CHECK: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
   tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} {
@@ -57,10 +56,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
     %65 = tt.splat %64 : i32 -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %66 = arith.cmpi slt, %38, %65 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %67 = tt.broadcast %66 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<128x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
-    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
-    // CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK-COUNT-16: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C8]], [[C2]], {{.*}})
+    // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2
     %68 = tt.load %60, %67, %cst_3 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %74 = tt.addptr %60, %cst_0 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %76 = arith.addi %58, %c1_i32 : i32
@@ -72,7 +68,6 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
 
 // -----
 
-// CHECK: llvm.func spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
   tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} {
@@ -129,11 +124,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
     %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
     %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
     %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
-    // CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
-    // CHECK-COUNT-8: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C32]], [[C1]], {{.*}})
+    // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
     %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
     %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
     %76 = arith.addi %58, %c1_i32 : i32
@@ -154,31 +145,17 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
                       %arg1: tensor<256x64x!tt.ptr<f16>, #mma_1>,
                       %arg2: tensor<128x64x!tt.ptr<f16>, #mma_2>,
                       %arg3: tensor<256x64x!tt.ptr<f16>, #mma_2>) {
-    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
-    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
-    // CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
-
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
+    // CHECK-COUNT-4: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
     %0 = tt.load %arg0 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
 
-    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
-    // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
-
-    // CHECK-COUNT-16: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C8]], [[C1]], {{.*}})
+    // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1
     %1 = tt.load %arg1 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma_1>
 
-    // CHECK-COUNT-2: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C32]], [[C2]], {{.*}})
+    // CHECK-COUNT-2: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2
     %2 = tt.load %arg3 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma_2>
 
     // COM: The data is duplicated across the warps because the warp shape 32*8 = 256 is larger than the tensor shape 128
-    // CHECK-COUNT-2: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C32]], [[C2]], {{.*}})
+    // CHECK-COUNT-2: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2
     %3 = tt.load %arg2 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma_2>
     tt.return
   }