Skip to content

Commit dfbea72

Browse files
authored
[AMD] Avoid async load to pipeline for less than 32bit load (#7250)
We can only use AsyncCopy if the final load width is >= 4 bytes. `triton::canBeConvertedToAsyncLoad` checks that the vecSize of the source is large enough. Additionally, we need to ensure the register-to-shared layout (blocked+shared) has enough contiguous elements, since we cannot scatter into LDS. Before this PR we would abort compilation instead of falling back to pipelining through registers.
1 parent ba5ac26 commit dfbea72

File tree

2 files changed

+74
-7
lines changed

2 files changed

+74
-7
lines changed

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,3 +700,42 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
700700
tt.return %75#0 : tensor<128x256xf32, #blocked3>
701701
}
702702
}
703+
704+
// -----
705+
706+
// Check we do not get AsyncCopyGlobalToLocal because the vec width will be < 32bit.
707+
// The order of the shared memory will be getMemoryOrder(#linear1) == [0, 1]
708+
// which differs from the order [1, 0] of the blocked layout. Since we have to
709+
// gather into LDS with AsyncCopyGlobalToLocal we have to fall back to registers
710+
711+
// COMMON-LABEL: pipeline_scale_memory_order
712+
// COMMON-NOT: ttg.async_copy_global_to_local
713+
714+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [64, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
715+
#linear = #ttg.linear<{register = [[0, 4], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0], [0, 0]], block = []}>
716+
#linear1 = #ttg.linear<{register = [[0, 4], [128, 0], [256, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[16, 0], [32, 0], [64, 0]], block = []}>
717+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 16], isTransposed = true}>
718+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
719+
tt.func public @pipeline_scale_memory_order(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg4: tensor<128x512x!tt.ptr<f32>, #mma>, %arg5: tensor<512x8x!tt.ptr<i8>, #blocked>) attributes {noinline = false} {
720+
%cst = arith.constant dense<127> : tensor<128x8xi8, #linear>
721+
%cst_0 = arith.constant dense<8> : tensor<512x8xi32, #blocked>
722+
%c256_i64 = arith.constant 256 : i64
723+
%c0_i64 = arith.constant 0 : i64
724+
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x512xf32, #mma>
725+
%0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
726+
%1 = arith.extsi %0 : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<8xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
727+
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<8xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x8xi64, #blocked>
728+
%3 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<1x8x!tt.ptr<i8>, #blocked>
729+
%4 = tt.addptr %3, %2 : tensor<1x8x!tt.ptr<i8>, #blocked>, tensor<1x8xi64, #blocked>
730+
%5 = tt.broadcast %4 : tensor<1x8x!tt.ptr<i8>, #blocked> -> tensor<512x8x!tt.ptr<i8>, #blocked>
731+
%6:2 = scf.for %arg6 = %c0_i64 to %arg1 step %c256_i64 iter_args(%arg7 = %cst_1, %arg8 = %5) -> (tensor<128x512xf32, #mma>, tensor<512x8x!tt.ptr<i8>, #blocked>) : i64 {
732+
%7 = tt.load %arg8 : tensor<512x8x!tt.ptr<i8>, #blocked>
733+
%8 = ttg.convert_layout %7 : tensor<512x8xi8, #blocked> -> tensor<512x8xi8, #linear1>
734+
%9 = tt.dot_scaled %arg2 scale %cst, %arg3 scale %8, %arg7 lhs = e4m3 rhs = e2m1 {fastMath = true} : tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x8xi8, #linear> * tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<512x8xi8, #linear1> -> tensor<128x512xf32, #mma>
735+
%10 = tt.addptr %arg8, %cst_0 : tensor<512x8x!tt.ptr<i8>, #blocked>, tensor<512x8xi32, #blocked>
736+
scf.yield %9, %10 : tensor<128x512xf32, #mma>, tensor<512x8x!tt.ptr<i8>, #blocked>
737+
}
738+
tt.store %arg4, %6#0 : tensor<128x512x!tt.ptr<f32>, #mma>
739+
tt.return
740+
}
741+
}

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -717,13 +717,44 @@ void scheduleRemainingToLastStage(int numStages,
717717
schedule.insert(op, lastStage, cluster);
718718
}
719719

720+
namespace {
721+
bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp,
722+
Value alloc,
723+
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
724+
// If we have a single buffer we would require another barrier after the
725+
// local_reads, so instead we fall back to pipelining with registers
726+
// Removing this check will create incorrect IR, see
727+
// MembarUtility.h:membarFilter
728+
if (numBuffers <= 1)
729+
return false;
730+
731+
// Compute the final vecSize we can use for the combination of sourceEncoding
732+
// and sharedEncoding. We can only use AsyncCopy if the width is >= 32 bit
733+
auto srcTy = cast<RankedTensorType>(loadOp.getPtr().getType());
734+
auto dstTy = cast<ttg::MemDescType>(alloc.getType());
735+
auto shape = srcTy.getShape();
736+
auto regLayout = triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
737+
auto sharedLayout = triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
738+
auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
739+
unsigned loadContig = regToSharedLayout.getNumConsecutiveInOut();
740+
unsigned width = loadContig * dstTy.getElementTypeBitWidth();
741+
if (width < 32)
742+
return false;
743+
744+
// Checks whether the global pointer's contiguity and mask alignment allows
745+
// for at least 32 bit wide loads
746+
return triton::canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis);
747+
}
748+
} // namespace
749+
720750
// Convert load ops into shared memory allocation loads and apply
721751
// multi-buffering based on the required number of buffers.
722752
SmallVector<std::pair<Operation *, Value>> createAndScheduleStreamOps(
723753
const llvm::MapVector<Operation *, LoadInfo> &loadToInfo, scf::ForOp &forOp,
724754
const int &numBuffers, bool useAsyncCopy, tt::CoarseSchedule &schedule,
725755
const int stages[SCHED_SIZE],
726-
const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
756+
const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters,
757+
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
727758
IRRewriter builder(forOp.getContext());
728759
Attribute sharedMemorySpace =
729760
ttg::SharedMemorySpaceAttr::get(forOp.getContext());
@@ -773,11 +804,8 @@ SmallVector<std::pair<Operation *, Value>> createAndScheduleStreamOps(
773804
// Replace tt.loads with async copies or stream copies
774805
for (auto &[op, alloc] : loadToAllocs) {
775806
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
776-
// If we have a single buffer we would require another barrier after the
777-
// local_reads so instead we fall back to pipeline with registers
778-
// Removing this check will create incorrect IR, see
779-
// MembarUtility.h:membarFilter
780-
if (useAsyncCopy && numBuffers > 1) {
807+
if (useAsyncCopy && canBeConvertedToAsyncLoad(numBuffers, loadOp, alloc,
808+
axisInfoAnalysis)) {
781809
createAndScheduleAsyncCopy(loadOp, alloc, extractIdx, forOp, schedule,
782810
stages, clusters);
783811
} else {
@@ -832,7 +860,7 @@ LogicalResult preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
832860
// Convert the loads into shared memory allocations and loads from them.
833861
SmallVector<std::pair<Operation *, Value>> sharedMemAllocs =
834862
createAndScheduleStreamOps(*loadToInfo, forOp, numBuffers, useAsyncCopy,
835-
schedule, stages, clusters);
863+
schedule, stages, clusters, axisInfoAnalysis);
836864

837865
scheduleDependencies(schedule, forOp, numStages);
838866
LLVM_DEBUG({

0 commit comments

Comments
 (0)