Skip to content

Commit 37888a9

Browse files
[AMD] Add ISAFamily for RDNA4 (#8054)
RDNA4 is currently treated as RDNA3. This patch splits the current family into an RDNA3 and an RDNA4 family. Behavior should be unchanged. Co-authored-by: Paul Trojahn <[email protected]>
1 parent 1d004a9 commit 37888a9

File tree

7 files changed

+18
-10
lines changed

7 files changed

+18
-10
lines changed

python/test/unit/language/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,8 @@ def is_layout_applicable(layout) -> bool:
321321
target_arch = triton.runtime.driver.active.get_current_target().arch
322322
if isinstance(layout, PaddedSharedLayout):
323323
return True
324-
elif "gfx11" in target_arch:
325-
# RDNA 3
324+
elif any(arch for arch in ["gfx11", "gfx12"] if arch in target_arch):
325+
# RDNA 3, 4
326326
return isinstance(layout, WmmaLayout)
327327
elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch):
328328
# CDNA 1, 2, 3, 4

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ enum class ISAFamily {
1515
RDNA1,
1616
RDNA2,
1717
RDNA3,
18+
RDNA4,
1819
};
1920

2021
// Deduces the corresponding ISA family for the given target gfx |arch|.

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr,
5454
// 3 = either swizzles or testing against offset field)
5555
// bits 30-31: Type (must be 0)
5656
uint32_t flags = (7 << 12) | (4 << 15);
57-
if (targetInfo.getISAFamily() == ISAFamily::RDNA2 ||
58-
targetInfo.getISAFamily() == ISAFamily::RDNA3) {
57+
if (llvm::is_contained({ISAFamily::RDNA2, ISAFamily::RDNA3, ISAFamily::RDNA4},
58+
targetInfo.getISAFamily())) {
5959
flags |= (1 << 24);
6060
uint32_t oob = 3;
6161
flags |= (oob << 28);

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
298298
return false;
299299
if (isCDNA(getISAFamily()) && getISAFamily() == ISAFamily::CDNA1)
300300
return false;
301-
if (isRDNA(getISAFamily()) && getISAFamily() != ISAFamily::RDNA3)
301+
if (isRDNA(getISAFamily()) &&
302+
llvm::is_contained({ISAFamily::RDNA1, ISAFamily::RDNA2}, getISAFamily()))
302303
return false;
303304

304305
Operation *reduxOp = op.getSingleCombiner();

third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {
2323
break;
2424
}
2525

26-
// RNDA ISA cases
27-
if (kind >= llvm::AMDGPU::GK_GFX1100 && kind <= llvm::AMDGPU::GK_GFX1201)
26+
// RDNA ISA cases
27+
if (kind >= llvm::AMDGPU::GK_GFX1200 && kind <= llvm::AMDGPU::GK_GFX1201)
28+
return ISAFamily::RDNA4;
29+
if (kind >= llvm::AMDGPU::GK_GFX1100 && kind <= llvm::AMDGPU::GK_GFX1153)
2830
return ISAFamily::RDNA3;
2931
if (kind >= llvm::AMDGPU::GK_GFX1030 && kind <= llvm::AMDGPU::GK_GFX1036)
3032
return ISAFamily::RDNA2;
@@ -42,6 +44,7 @@ bool supportsVDot(llvm::StringRef arch) {
4244
case AMD::ISAFamily::CDNA4:
4345
case AMD::ISAFamily::RDNA2:
4446
case AMD::ISAFamily::RDNA3:
47+
case AMD::ISAFamily::RDNA4:
4548
return true;
4649
default:
4750
break;
@@ -68,6 +71,7 @@ bool isRDNA(ISAFamily isaFamily) {
6871
case ISAFamily::RDNA1:
6972
case ISAFamily::RDNA2:
7073
case ISAFamily::RDNA3:
74+
case ISAFamily::RDNA4:
7175
return true;
7276
default:
7377
break;

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
151151
}
152152
} else {
153153
if (!llvm::is_contained({ISAFamily::CDNA2, ISAFamily::CDNA3,
154-
ISAFamily::CDNA4, ISAFamily::RDNA3},
154+
ISAFamily::CDNA4, ISAFamily::RDNA3,
155+
ISAFamily::RDNA4},
155156
isaFamily)) {
156-
// DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3 right now, so we
157-
// fallback to ds_swizzle for other architectures.
157+
// DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3/RDNA4 right now, so
158+
// we fall back to ds_swizzle for other architectures.
158159
//
159160
// This map facilitates the butterfly shuffle pattern for a stride less
160161
// than 16. The pattern stride is the key of the map.

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,7 @@ struct TritonAMDGPUAccelerateMatmulPass
14541454
/*benefit=*/2);
14551455
break;
14561456
case ISAFamily::RDNA3:
1457+
case ISAFamily::RDNA4:
14571458
ttg::populateDecomposeScaledBlockedPatterns(mfmaPatterns,
14581459
/*benefit=*/3);
14591460
mfmaPatterns.add<::BlockedToWMMA>(

0 commit comments

Comments
 (0)