Skip to content

Commit d0a888c

Browse files
agron911paultrojahnamd
authored andcommitted
[Cherry-pick] [AMD] Add ISAFamily for RDNA4 (#8054) (#572)
Summary: Cherry-picked from upstream OAI repository. Original Commit: 37888a9 Original Author: Paul Trojahn Original Date: 2025-09-04 02:02:45 +0200 Original commit message: ``` [AMD] Add ISAFamily for RDNA4 (#8054) RDNA4 is currently treated as RDNA3. This patch splits the current family into an RDNA3 and an RDNA4 family. Behavior should be unchanged. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #572 Reviewed By: njriasan Differential Revision: D86391879 Pulled By: agron911 fbshipit-source-id: 2b6ac2d54cb99a5227e6cff6fcc704478adb96ac Co-authored-by: Paul Trojahn <paul.trojahn@amd.com>
1 parent 9d05d76 commit d0a888c

File tree

7 files changed

+18
-10
lines changed

7 files changed

+18
-10
lines changed

python/test/unit/language/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,8 @@ def is_layout_applicable(layout) -> bool:
329329
target_arch = triton.runtime.driver.active.get_current_target().arch
330330
if isinstance(layout, PaddedSharedLayout):
331331
return True
332-
elif "gfx11" in target_arch:
333-
# RDNA 3
332+
elif any(arch for arch in ["gfx11", "gfx12"] if arch in target_arch):
333+
# RDNA 3, 4
334334
return isinstance(layout, WmmaLayout)
335335
elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch):
336336
# CDNA 1, 2, 3, 4

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ enum class ISAFamily {
1515
RDNA1,
1616
RDNA2,
1717
RDNA3,
18+
RDNA4,
1819
};
1920

2021
// Deduces the corresponding ISA family for the given target gfx |arch|.

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr,
5454
// 3 = either swizzles or testing against offset field)
5555
// bits 30-31: Type (must be 0)
5656
uint32_t flags = (7 << 12) | (4 << 15);
57-
if (targetInfo.getISAFamily() == ISAFamily::RDNA2 ||
58-
targetInfo.getISAFamily() == ISAFamily::RDNA3) {
57+
if (llvm::is_contained({ISAFamily::RDNA2, ISAFamily::RDNA3, ISAFamily::RDNA4},
58+
targetInfo.getISAFamily())) {
5959
flags |= (1 << 24);
6060
uint32_t oob = 3;
6161
flags |= (oob << 28);

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
298298
return false;
299299
if (isCDNA(getISAFamily()) && getISAFamily() == ISAFamily::CDNA1)
300300
return false;
301-
if (isRDNA(getISAFamily()) && getISAFamily() != ISAFamily::RDNA3)
301+
if (isRDNA(getISAFamily()) &&
302+
llvm::is_contained({ISAFamily::RDNA1, ISAFamily::RDNA2}, getISAFamily()))
302303
return false;
303304

304305
Operation *reduxOp = op.getSingleCombiner();

third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {
2323
break;
2424
}
2525

26-
// RNDA ISA cases
27-
if (kind >= llvm::AMDGPU::GK_GFX1100 && kind <= llvm::AMDGPU::GK_GFX1201)
26+
// RDNA ISA cases
27+
if (kind >= llvm::AMDGPU::GK_GFX1200 && kind <= llvm::AMDGPU::GK_GFX1201)
28+
return ISAFamily::RDNA4;
29+
if (kind >= llvm::AMDGPU::GK_GFX1100 && kind <= llvm::AMDGPU::GK_GFX1153)
2830
return ISAFamily::RDNA3;
2931
if (kind >= llvm::AMDGPU::GK_GFX1030 && kind <= llvm::AMDGPU::GK_GFX1036)
3032
return ISAFamily::RDNA2;
@@ -42,6 +44,7 @@ bool supportsVDot(llvm::StringRef arch) {
4244
case AMD::ISAFamily::CDNA4:
4345
case AMD::ISAFamily::RDNA2:
4446
case AMD::ISAFamily::RDNA3:
47+
case AMD::ISAFamily::RDNA4:
4548
return true;
4649
default:
4750
break;
@@ -68,6 +71,7 @@ bool isRDNA(ISAFamily isaFamily) {
6871
case ISAFamily::RDNA1:
6972
case ISAFamily::RDNA2:
7073
case ISAFamily::RDNA3:
74+
case ISAFamily::RDNA4:
7175
return true;
7276
default:
7377
break;

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
151151
}
152152
} else {
153153
if (!llvm::is_contained({ISAFamily::CDNA2, ISAFamily::CDNA3,
154-
ISAFamily::CDNA4, ISAFamily::RDNA3},
154+
ISAFamily::CDNA4, ISAFamily::RDNA3,
155+
ISAFamily::RDNA4},
155156
isaFamily)) {
156-
// DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3 right now, so we
157-
// fallback to ds_swizzle for other architectures.
157+
// DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3/RDNA4 right now, so
158+
// we fallback to ds_swizzle for other architectures.
158159
//
159160
// This map facilates the butterfly shuffle pattern for a stride less
160161
// than 16. The pattern stride is the key of the map.

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,7 @@ struct TritonAMDGPUAccelerateMatmulPass
14541454
/*benefit=*/2);
14551455
break;
14561456
case ISAFamily::RDNA3:
1457+
case ISAFamily::RDNA4:
14571458
ttg::populateDecomposeScaledBlockedPatterns(mfmaPatterns,
14581459
/*benefit=*/3);
14591460
mfmaPatterns.add<::BlockedToWMMA>(

0 commit comments

Comments
 (0)