Skip to content

Commit 6569ca6

Browse files
AlexAUT
authored and anmyachev committed
[AMD] Allow async load global to load block dimension duplication (#8788)
Broadcasts in the `block` dimensions are not redundant, so we should not mask them. This way each CTA has its own copy in shared memory; note that in such cases the multicast mask will be set so the data is loaded efficiently.
1 parent a1acb06 commit 6569ca6

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
8181
// CHECK-LABEL: async_load_multicast_to_half_ctas
8282
tt.func public @async_load_multicast_to_half_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
8383
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
84-
// CHECK: llvm.amdgcn.cluster.workgroup.id.x
8584
// CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
8685
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32
8786
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
@@ -104,7 +103,6 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
104103
tt.func public @async_load_multicast_group_of_2_strided_by_8(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
105104
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
106105
// Skip the first cluster id because it's emitted for address calculation
107-
// CHECK: llvm.amdgcn.cluster.workgroup.id.x
108106
// CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
109107
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
110108
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
@@ -146,7 +144,6 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
146144
tt.func public @async_load_multi_cta_linear_layout(%arg0: tensor<32x32x!tt.ptr<f32>, #linear> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
147145
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
148146
// Skip the first cluster id because it's emitted for address calculation
149-
// CHECK: llvm.amdgcn.cluster.workgroup.id.x
150147
// CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
151148
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
152149
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,8 +1034,13 @@ struct AsyncCopyGlobalToLocalOpConversion
10341034
zipLoadValues(rewriter, loc, vec, srcElems, srcPtrTy, maskElements,
10351035
otherElems, otherTy, swizzledLaneOffsets);
10361036

1037-
Value threadPred = emitRedundantThreadPredicate(getFreeVariableMasks(srcTy),
1038-
rewriter, loc, targetInfo);
1037+
auto freeVarMasks = getFreeVariableMasks(srcTy);
1038+
// We load redundant data on different CTAs so each CTA has a copy in its
1039+
// shared memory; the multicast mask will be used by the hardware to
1040+
// efficiently broadcast to different CTAs.
1041+
freeVarMasks[rewriter.getStringAttr("block")] = 0;
1042+
Value threadPred =
1043+
emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
10391044

10401045
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
10411046
auto emitGlobalLoadLds =

0 commit comments

Comments
 (0)