
Commit 7562a29

[AMD] Avoid async load to pipeline for less than 32bit load (triton-lang#7250)
We can only use AsyncCopy if the final load width is >= 4 bytes. `triton::canBeConvertedToAsyncLoad` checks that the vecSize of the source is large enough. Additionally, we need to ensure that the register-to-shared layout (blocked + shared) has enough contiguous elements, since we cannot scatter into LDS. Before this PR we would abort compilation instead of falling back to pipelining through registers.
1 parent 5e56853 · commit 7562a29
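For orientation, here is a minimal, self-contained sketch of the eligibility gate this commit introduces. The helper name `asyncCopyEligible` and its parameters are invented for this illustration; the real logic lives in `canBeConvertedToAsyncLoad` in the StreamPipeline.cpp diff below, which derives the contiguous-element count from the register-to-shared linear layout.

// Illustrative sketch only (not Triton's API): mirrors the new gate in
// canBeConvertedToAsyncLoad. Invented names: asyncCopyEligible,
// contiguousElems, elemBitWidth.
#include <cstdio>

// Async copy directly into LDS is only worthwhile when we have more than one
// buffer (a single buffer would require an extra barrier after the
// local_reads) and when each write covers at least 32 contiguous bits,
// because we cannot scatter into LDS.
bool asyncCopyEligible(unsigned numBuffers, unsigned contiguousElems,
                       unsigned elemBitWidth) {
  if (numBuffers <= 1)
    return false;
  return contiguousElems * elemBitWidth >= 32;
}

int main() {
  // i8 tensor whose register and shared orders disagree: 1 contiguous element
  // -> 8 bits -> fall back to pipelining through registers.
  std::printf("%d\n", asyncCopyEligible(2, /*contiguousElems=*/1,
                                        /*elemBitWidth=*/8));
  // 4 contiguous f16 elements -> 64 bits -> async copy is allowed.
  std::printf("%d\n", asyncCopyEligible(2, /*contiguousElems=*/4,
                                        /*elemBitWidth=*/16));
  return 0;
}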

File tree: 2 files changed, +74 −7 lines

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 39 additions & 0 deletions
@@ -700,3 +700,42 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return %75#0 : tensor<128x256xf32, #blocked3>
   }
 }
+
+// -----
+
+// Check we do not get AsyncCopyGlobalToLocal because the vec width will be < 32bit.
+// The order of the shared memory will be getMemoryOrder(#linear1) == [0, 1]
+// which differs from the order [1, 0] of the blocked layout. Since we have to
+// gather into lds with AsyncCopyGlobalToLocal we have to fallback to registers
+
+// COMMON-LABEL: pipeline_scale_memory_order
+// COMMON-NOT: ttg.async_copy_global_to_local
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [64, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 4], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0], [0, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 4], [128, 0], [256, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[16, 0], [32, 0], [64, 0]], block = []}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pipeline_scale_memory_order(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg4: tensor<128x512x!tt.ptr<f32>, #mma>, %arg5: tensor<512x8x!tt.ptr<i8>, #blocked>) attributes {noinline = false} {
+    %cst = arith.constant dense<127> : tensor<128x8xi8, #linear>
+    %cst_0 = arith.constant dense<8> : tensor<512x8xi32, #blocked>
+    %c256_i64 = arith.constant 256 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x512xf32, #mma>
+    %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %1 = arith.extsi %0 : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<8xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<8xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x8xi64, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<1x8x!tt.ptr<i8>, #blocked>
+    %4 = tt.addptr %3, %2 : tensor<1x8x!tt.ptr<i8>, #blocked>, tensor<1x8xi64, #blocked>
+    %5 = tt.broadcast %4 : tensor<1x8x!tt.ptr<i8>, #blocked> -> tensor<512x8x!tt.ptr<i8>, #blocked>
+    %6:2 = scf.for %arg6 = %c0_i64 to %arg1 step %c256_i64 iter_args(%arg7 = %cst_1, %arg8 = %5) -> (tensor<128x512xf32, #mma>, tensor<512x8x!tt.ptr<i8>, #blocked>) : i64 {
+      %7 = tt.load %arg8 : tensor<512x8x!tt.ptr<i8>, #blocked>
+      %8 = ttg.convert_layout %7 : tensor<512x8xi8, #blocked> -> tensor<512x8xi8, #linear1>
+      %9 = tt.dot_scaled %arg2 scale %cst, %arg3 scale %8, %arg7 lhs = e4m3 rhs = e2m1 {fastMath = true} : tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x8xi8, #linear> * tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<512x8xi8, #linear1> -> tensor<128x512xf32, #mma>
+      %10 = tt.addptr %arg8, %cst_0 : tensor<512x8x!tt.ptr<i8>, #blocked>, tensor<512x8xi32, #blocked>
+      scf.yield %9, %10 : tensor<128x512xf32, #mma>, tensor<512x8x!tt.ptr<i8>, #blocked>
+    }
+    tt.store %arg4, %6#0 : tensor<128x512x!tt.ptr<f32>, #mma>
+    tt.return
+  }
+}
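To connect this test to the new check: the loaded scale tensor has i8 elements, and because the blocked layout's order [1, 0] disagrees with the shared memory order [0, 1], a thread cannot write 4 consecutive i8 elements into LDS in one go. The resulting width (contiguous elements × 8 bits) therefore stays below 32 bits, so the pipeliner now keeps this load in registers instead of emitting ttg.async_copy_global_to_local. The exact contiguous-element count is inferred from the test comment above, not recomputed here.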

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 35 additions & 7 deletions
@@ -706,13 +706,44 @@ void scheduleRemainingToLastStage(int numStages,
     schedule.insert(op, lastStage, cluster);
 }
 
+namespace {
+bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp,
+                               Value alloc,
+                               tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  // If we have a single buffer we would require another barrier after the
+  // local_reads so instead we fall back to pipeline with registers
+  // Removing this check will create incorrect IR, see
+  // MembarUtility.h:membarFilter
+  if (numBuffers <= 1)
+    return false;
+
+  // Compute the final vecSize we can use for the combination of sourceEncoding
+  // and sharedEncoding. We can only use AsyncCopy if the width is >= 32 bit
+  auto srcTy = cast<RankedTensorType>(loadOp.getPtr().getType());
+  auto dstTy = cast<ttg::MemDescType>(alloc.getType());
+  auto shape = srcTy.getShape();
+  auto regLayout = triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
+  auto sharedLayout = triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
+  auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
+  unsigned loadContig = regToSharedLayout.getNumConsecutiveInOut();
+  unsigned width = loadContig * dstTy.getElementTypeBitWidth();
+  if (width < 32)
+    return false;
+
+  // Checks whether the global pointer's contiguity and mask alignment allows
+  // for at least 32 bit wide loads
+  return triton::canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis);
+}
+} // namespace
+
 // Convert load ops into shared memory allocation loads and apply
 // multi-buffering based on the required number of buffers.
 SmallVector<std::pair<Operation *, Value>> createAndScheduleStreamOps(
     const llvm::MapVector<Operation *, LoadInfo> &loadToInfo, scf::ForOp &forOp,
     const int &numBuffers, bool useAsyncCopy, tt::CoarseSchedule &schedule,
     const int stages[SCHED_SIZE],
-    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters,
+    tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
   IRRewriter builder(forOp.getContext());
   Attribute sharedMemorySpace =
       ttg::SharedMemorySpaceAttr::get(forOp.getContext());
@@ -762,11 +793,8 @@ SmallVector<std::pair<Operation *, Value>> createAndScheduleStreamOps(
   // Replace tt.loads with async copies or stream copies
   for (auto &[op, alloc] : loadToAllocs) {
     if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
-      // If we have a single buffer we would require another barrier after the
-      // local_reads so instead we fall back to pipeline with registers
-      // Removing this check will create incorrect IR, see
-      // MembarUtility.h:membarFilter
-      if (useAsyncCopy && numBuffers > 1) {
+      if (useAsyncCopy && canBeConvertedToAsyncLoad(numBuffers, loadOp, alloc,
+                                                    axisInfoAnalysis)) {
         createAndScheduleAsyncCopy(loadOp, alloc, extractIdx, forOp, schedule,
                                    stages, clusters);
       } else {
@@ -820,7 +848,7 @@ LogicalResult preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
   // Convert the loads into shared memory allocations and loads from them.
   SmallVector<std::pair<Operation *, Value>> sharedMemAllocs =
       createAndScheduleStreamOps(*loadToInfo, forOp, numBuffers, useAsyncCopy,
-                                 schedule, stages, clusters);
+                                 schedule, stages, clusters, axisInfoAnalysis);
 
   scheduleDependencies(schedule, forOp, numStages);
   LLVM_DEBUG({
