Skip to content

Commit dc5b0c6

Browse files
authored
[AMD] Pipeline small tensors w/ registers only on GFX950 (#7171)
Fixes a perf regression on gfx942 but preserves functionality for gfx950 (and above).
1 parent 5389ed7 commit dc5b0c6

File tree

5 files changed

+28
-10
lines changed

5 files changed

+28
-10
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ bool isPureUnaryInlineAsm(Operation *op);
208208
int getNVIDIAComputeCapability(Operation *module);
209209

210210
// Read the amd target from the module attributes
211-
StringRef getAMDArch(Operation *module);
211+
std::optional<StringRef> getAMDArch(Operation *module);
212212

213213
std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
214214
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,14 +1055,19 @@ int getNVIDIAComputeCapability(Operation *module) {
10551055
return computeCapability;
10561056
}
10571057

1058-
StringRef getAMDArch(Operation *module) {
1058+
std::optional<StringRef> getAMDArch(Operation *module) {
10591059
StringAttr targetAttr =
10601060
module->getAttrOfType<StringAttr>(triton::gpu::AttrTargetName);
1061-
assert(targetAttr && "Expected a target attribute on the module operation");
1061+
if (!targetAttr) {
1062+
LDBG("Expected a target attribute on the module operation");
1063+
return {};
1064+
}
10621065

10631066
StringRef ref = targetAttr.strref();
1064-
assert(ref.starts_with("hip:") &&
1065-
"expected target attribute to be prefixed with \"hip:\"");
1067+
if (!ref.starts_with("hip:")) {
1068+
LDBG("expected target attribute to be prefixed with \"hip:\"");
1069+
return {};
1070+
}
10661071

10671072
return ref.drop_front(4); // drop the "hip:"
10681073
}

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
582582
#blocked4 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 2], warpsPerCTA = [8, 1], order = [1, 0]}>
583583
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [8, 1], order = [1, 0]}>
584584
#blocked6 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 2], order = [1, 0]}>
585-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
585+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
586586
tt.func public @pipeline_small_vector(%arg0: !tt.ptr<f8E5M2>, %arg1: !tt.ptr<f8E5M2>, %arg2: !tt.ptr<f32>, %arg3: !tt.ptr<i8>, %arg4: !tt.ptr<i8>, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) -> tensor<128x256xf32, #blocked3> {
587587
%c128_i32 = arith.constant 128 : i32
588588
%c256_i32 = arith.constant 256 : i32

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ class AMDFMAVectorMultiplier : public FMAVectorMultiplier {
3131
auto dElemTy = dOpTy.getElementType();
3232
auto mod = op->getParentOfType<ModuleOp>();
3333
auto arch = getAMDArch(mod);
34+
assert(arch.has_value() && "expected arch");
3435
DotIntrinsic chosenOp;
3536

36-
bool dotAvailable = AMD::supportsVDot(arch);
37+
bool dotAvailable = AMD::supportsVDot(*arch);
3738
auto b = TritonLLVMOpBuilder(loc, rewriter);
3839
if (dotAvailable) {
3940
if ((aElemTy.isF16() || aElemTy.isBF16()) && dElemTy.isF32()) {

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "TritonAMDGPUTransforms/Passes.h"
2+
#include "amd/lib/TritonAMDGPUToLLVM/TargetInfo.h"
23
#include "mlir/Support/LLVM.h"
34
#include "third_party/amd/include/Analysis/AxisInfoExt.h"
45
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
@@ -485,6 +486,11 @@ findPipelineableLoads(scf::ForOp forOp,
485486
DenseSet<Operation *> seen;
486487
// Recursively visit the given op and its operands to discover all load ops
487488
// and collect their distances and uses.
489+
490+
auto arch = getAMDArch(forOp->getParentOfType<ModuleOp>());
491+
triton::AMD::ISAFamily isaFamily = triton::AMD::ISAFamily::Unknown;
492+
if (arch)
493+
isaFamily = triton::AMD::deduceISAFamily(*arch);
488494
std::function<void(Operation * op, int distance, Operation *use)> dfs =
489495
[&](Operation *op, int distance, Operation *use) {
490496
// Skip previously visited load ops.
@@ -507,12 +513,17 @@ findPipelineableLoads(scf::ForOp forOp,
507513
}
508514
auto pointeeTy = cast<tt::PointerType>(tensorTy.getElementType())
509515
.getPointeeType();
510-
// If the max contiguous bits we can read is < 32, buffer in
511-
// registers.
512-
if (vecContiguity * pointeeTy.getIntOrFloatBitWidth() >= 32) {
516+
unsigned width =
517+
vecContiguity * pointeeTy.getIntOrFloatBitWidth();
518+
// Limit shared memory sharing to width >= 32 bits.
519+
LDBG("Load " << *loadOp << " has width " << width);
520+
if (width >= 32) {
513521
sharedEncoding =
514522
getSharedEncIfAllUsersAreDotEnc(op->getResult(0))
515523
.value_or(nullptr);
524+
} else if (isaFamily != triton::AMD::ISAFamily::CDNA4) {
525+
LDBG("Skip width<32 load " << loadOp << " for arch " << arch);
526+
return;
516527
}
517528
} else if (auto useOp = dyn_cast<tt::LoadOp>(use)) {
518529
// The use of this loadOp is another loadOp. If the use is not in
@@ -790,6 +801,7 @@ LogicalResult preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
790801
int numBuffers = 1;
791802
std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusters;
792803
tt::CoarseSchedule schedule(numStages);
804+
793805
// Schedule the loads and root ops (dot ops) in the loop. This will give us
794806
// a scaffold for the final schedule.
795807
FailureOr<llvm::MapVector<Operation *, LoadInfo>> loadToInfo =

0 commit comments

Comments
 (0)