
Commit f737843

[AMD] Get rid of flat load/store instructions (#5137)
Flat instructions can significantly slow down workloads: the address space of their pointers is undefined, so LLVM passes cannot apply some optimizations to them. This patch removes the address-space cast, keeping each pointer's original address space so that proper loads/stores are generated.

Signed-off-by: Ilya Veselov <[email protected]>
1 parent 3c189dd commit f737843
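As context for the test changes below, here is a minimal sketch of how the effect can be observed from Python, assuming a ROCm build of Triton. The _copy kernel is illustrative and not part of this commit, and pgm.asm['amdgcn'] is assumed to expose the generated GCN assembly (mirroring pgm.asm['ptx'] used on the CUDA path of the updated test):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(dst, src, CACHE: tl.constexpr):
    offsets = tl.arange(0, 128)
    x = tl.load(src + offsets, cache_modifier=CACHE)
    tl.store(dst + offsets, x)

src = torch.randn(128, device='cuda')  # PyTorch also uses 'cuda' as the ROCm device string
dst = torch.empty_like(src)
pgm = _copy[(1,)](dst, src, CACHE='.cv')
amdgcn = pgm.asm['amdgcn']

# Before this patch, the '.cv' load lowered to a flat_load; with the original
# address space preserved, an address-space-specific load is expected instead.
assert not any('flat_load' in line for line in amdgcn.splitlines())
assert any('global_load' in line or 'buffer_load' in line
           for line in amdgcn.splitlines())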

4 files changed: +8, -11 lines

python/test/unit/language/test_core.py

Lines changed: 1 addition & 2 deletions

@@ -4090,13 +4090,12 @@ def _kernel(dst, src, CACHE: tl.constexpr):
         cv_cache_modifier_str = 'sc0 sc1'
         buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line]
         global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
-        flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line]
         if cache == '' or cache == '.ca':
             assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0])
         if cache == '.cg':
             assert cg_cache_modifier_str in global_load_line[0]
         if cache == '.cv':
-            assert cv_cache_modifier_str in flat_load_line[0]
+            assert cv_cache_modifier_str in global_load_line[0]
 
     if is_cuda():
         ptx = pgm.asm['ptx']

test/Conversion/amd/load_store.mlir

Lines changed: 3 additions & 3 deletions

@@ -15,10 +15,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
     %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
     // Load 8 elements from A with two vectorized load instruction
-    // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+    // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr<1>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
     %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
     // Load 8 elements from B with two vectorized load instruction
-    // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+    // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr<1>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
     %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
     %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
     %12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
@@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %105 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #mma>
     %106 = tt.addptr %105, %104 : tensor<32x32x!tt.ptr<f16>, #mma>, tensor<32x32xi32, #mma>
     // Store 16 elements with four vectorized store instruction
-    // CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr
+    // CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr<1>
     tt.store %106, %2 : tensor<32x32x!tt.ptr<f16>, #mma>
     tt.return
   }

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 4 deletions

@@ -25,10 +25,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.cond_br
     // CHECK: llvm.atomicrmw
     // CHECK: llvm.atomicrmw
-    // CHECK: %[[ADDR1:.*]] = llvm.addrspacecast
-    // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR1]]
-    // CHECK: %[[ADDR2:.*]] = llvm.addrspacecast
-    // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR2]]
+    // CHECK: llvm.intr.masked.store
+    // CHECK: llvm.intr.masked.store
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
     tt.store %arg0, %0 : tensor<256x!tt.ptr<f32>, #blocked0>
     tt.return

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions

@@ -300,7 +300,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
       assert(wordNElems * nWords * numVecs == numElems);
 
       Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
-      Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);
+      Value ptr = ptrElems[vecStart];
 
       Value falseVal = createZeroVector(rewriter, loc, cast<VectorType>(vecTy));
      // If we need to mask the loaded value with other elements
@@ -477,7 +477,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
 
       SmallVector<std::pair<Value, std::string>> asmArgs;
       Value elem = valueElems[vecStart];
-      Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);
+      Value ptr = ptrElems[vecStart];
 
       // Create the store val
       Value storeVal = packElementRangeIntoVector(
