Skip to content

Commit 6f5a8a8

Browse files
[TritonGENToLLVM] Adjust base pointer in 2D block IO (#4335)
2D block IO instructions requires base pointer to be 64 bytes aligned. When base pointer is not 64 bytes aligned, we need to make it 64 bytes aligned and adjust base width and offset x accordingly to make it accesses the desire memory. To make a non 64 bytes align pointer to be 64 bytes align, we clear the lower 6 bits. We didn't clear the lower 6 bits before, because it is ignored on PVC, and we can save on the number of instructions, but it is not ignored on BMG, hence the incorrect behavior. We should not rely on a undocumented behavior, and explicitly perform the action in Triton. Fixes #4304 `Build and test B580, Linux` CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15290645209 `Triton benchmarks BMG` CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15308884810 GEMM tensor of pointer benchmark is fixed, no functional regression on other benchmarks. `Triton benchmarks PVC` CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15290677510 Potential improvement on flex attention, no geomean regression on other benchmarks. ![Screenshot 2025-05-28 100851](https://github.com/user-attachments/assets/4e5ab000-015e-461b-8f12-6ad44600de30) --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9b4d906 commit 6f5a8a8

File tree

5 files changed

+114
-97
lines changed

5 files changed

+114
-97
lines changed

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
4646
// CHECK-NEXT: [[INSERT1:%.*]] = llvm.insertelement {{.*}}, [[INSERT0]][[[ONE]] : i32] : vector<2xi32>
4747
%14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf16>, 1>
4848

49-
// CHECK: llvm.call spir_funccc @_Z36__spirv_Subgroup2DBlockPrefetchINTELiiiiPU3AS1viiiDv2_i({{.*}}, %arg0, {{.*}})
49+
// CHECK: llvm.call spir_funccc @_Z36__spirv_Subgroup2DBlockPrefetchINTELiiiiPU3AS1viiiDv2_i({{.*}})
5050
ttig.prefetch %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<8x32xf16>, 1>
5151
%18 = arith.divsi %1, %c4_i32 : i32
5252
%19 = arith.andi %18, %c7_i32 : i32
@@ -65,13 +65,13 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
6565
%62 = arith.cmpi slt, %40, %c4096_i32 : i32
6666
cf.cond_br %62, ^bb2, ^bb3
6767
^bb2:
68-
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg0, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[A_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
68+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, [[A_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
6969
// CHECK: [[A:%.*]] = llvm.load [[A_PTR]] : !llvm.ptr -> vector<64xi16>
7070
// CHECK-NEXT: [[castA:%.*]] = llvm.bitcast [[A]] : vector<64xi16> to vector<64xf16>
71-
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
71+
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
7272
// CHECK: [[B0:%.*]] = llvm.load [[B_PTR]] : !llvm.ptr -> vector<32xi32>
7373
// CHECK-NEXT: [[castB:%.*]] = llvm.bitcast [[B0]] : vector<32xi32> to vector<64xf16>
74-
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
74+
// CHECK: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, [[B_PTR:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
7575
// CHECK: [[B1:%.*]] = llvm.load [[B_PTR]] : !llvm.ptr -> vector<32xi32>
7676
// CHECK: [[subA1:%.*]] = llvm.shufflevector [[castA]], [[castA]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xf16>
7777
// CHECK: [[subB1:%.*]] = llvm.shufflevector [[castB]], [[castB]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<64xf16>
@@ -102,7 +102,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
102102
cf.br ^bb1(%119, %71, %115, %117, %118 : i32, tensor<8x16xf32>, !tt.ptr<tensor<32x32xf16>, 1>, !tt.ptr<tensor<32x32xf16>, 1>, !tt.ptr<tensor<32x32xf16>, 1>)
103103
^bb3:
104104
%120 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%21, %36] {order = array<i32: 1, 0>} : <tensor<8x16xf32>, 1>
105-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %arg2, {{.*}}
105+
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}})
106106
tt.store %120, %41 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<8x16xf32>, 1>
107107
tt.return
108108
}
@@ -120,9 +120,9 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
120120
%c0_i32 = arith.constant 0 : i32
121121
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x8xf32>>
122122
%1 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf32>>
123-
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg0, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
123+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
124124
%2 = tt.load %0 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
125-
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}, %arg0, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
125+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv({{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
126126
%3 = tt.load %1 {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
127127
tt.return
128128
}
@@ -140,7 +140,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
140140
%c0_i32 = arith.constant 0 : i32
141141
%cst = arith.constant dense<0.000000e+00> : tensor<8x16xf16>
142142
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf16>>
143-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %arg0, {{.*}})
143+
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}})
144144
tt.store %0, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf16>>
145145
tt.return
146146
}

0 commit comments

Comments
 (0)