Skip to content

Commit 2bf2c13

Browse files
[TritonGENToLLVM] Adjust base pointer in 2D block IO
1 parent d9fffb7 commit 2bf2c13

File tree

8 files changed

+165
-143
lines changed

8 files changed

+165
-143
lines changed

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
4646
// CHECK-NEXT: [[INSERT1:%.*]] = llvm.insertelement {{.*}}, [[INSERT0]][[[ONE]] : i32] : vector<2xi32>
4747
%14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf16>, 1>
4848

49-
// CHECK: llvm.call spir_funccc @_Z36__spirv_Subgroup2DBlockPrefetchINTELiiiiPU3AS1viiiDv2_i({{.*}}, %arg0, {{.*}})
49+
// CHECK: llvm.call spir_funccc @_Z36__spirv_Subgroup2DBlockPrefetchINTELiiiiPU3AS1viiiDv2_i({{.*}})
5050
ttig.prefetch %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<8x32xf16>, 1>
5151
%18 = arith.divsi %1, %c4_i32 : i32
5252
%19 = arith.andi %18, %c7_i32 : i32
@@ -65,13 +65,13 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
6565
%62 = arith.cmpi slt, %40, %c4096_i32 : i32
6666
cf.cond_br %62, ^bb2, ^bb3
6767
^bb2:
68-
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg0, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[A_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
68+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, [[A_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
6969
// CHECK: [[A:%.*]] = llvm.load [[A_PTR]] : !llvm.ptr -> vector<64xi16>
7070
// CHECK-NEXT: [[castA:%.*]] = llvm.bitcast [[A]] : vector<64xi16> to vector<64xf16>
71-
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
71+
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
7272
// CHECK: [[B0:%.*]] = llvm.load [[B_PTR]] : !llvm.ptr -> vector<32xi32>
7373
// CHECK-NEXT: [[castB:%.*]] = llvm.bitcast [[B0]] : vector<32xi32> to vector<64xf16>
74-
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
74+
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
7575
// CHECK: [[B1:%.*]] = llvm.load [[B_PTR]] : !llvm.ptr -> vector<32xi32>
7676
// CHECK: [[subA1:%.*]] = llvm.shufflevector [[castA]], [[castA]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xf16>
7777
// CHECK: [[subB1:%.*]] = llvm.shufflevector [[castB]], [[castB]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<64xf16>
@@ -102,7 +102,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
102102
cf.br ^bb1(%119, %71, %115, %117, %118 : i32, tensor<8x16xf32>, !tt.ptr<tensor<32x32xf16>, 1>, !tt.ptr<tensor<32x32xf16>, 1>, !tt.ptr<tensor<32x32xf16>, 1>)
103103
^bb3:
104104
%120 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%21, %36] {order = array<i32: 1, 0>} : <tensor<8x16xf32>, 1>
105-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %arg2, {{.*}}
105+
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}})
106106
tt.store %120, %41 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<8x16xf32>, 1>
107107
tt.return
108108
}
@@ -120,9 +120,9 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
120120
%c0_i32 = arith.constant 0 : i32
121121
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x8xf32>>
122122
%1 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf32>>
123-
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg0, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
123+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
124124
%2 = tt.load %0 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
125-
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg0, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
125+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
126126
%3 = tt.load %1 {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
127127
tt.return
128128
}
@@ -140,7 +140,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
140140
%c0_i32 = arith.constant 0 : i32
141141
%cst = arith.constant dense<0.000000e+00> : tensor<8x16xf16>
142142
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf16>>
143-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %arg0, {{.*}})
143+
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}})
144144
tt.store %0, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf16>>
145145
tt.return
146146
}

0 commit comments

Comments
 (0)