Commit 967e498

[AMD] Correctly mask program_id(1) for architected SGPRs (#7455)
Architectures with the architected SGPRs feature (just gfx12 for now) don't use sX registers for the program id, but read from registers ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]). As a result, the upper 16 bits of ttmp7 need to be masked when reading Y. This can be omitted when Z is not used. optimize_module infers this from whether @llvm.amdgcn.workgroup.id.z() is called and, when it is not, adds the attribute `amdgpu-no-workgroup-id-z` for the backend.

`test_core.py:test_zero_strided_tensors` doesn't read Z, but still has thread blocks on that axis. I don't think we know for sure during compilation which dispatch dimensions are going to be used.

We can query this feature directly from LLVM and just set all dimensions as used. This does not increase register pressure, because these registers are reserved anyway. It only adds an `s_and_b32 s1, ttmp7, 0xffff` instruction when loading Y.

Not too happy with recreating the MCSubtargetInfo just for this, but I didn't see a better way to preserve this object.

With this change, `test_core.py:test_zero_strided_tensors` passes on gfx12.

See "3.5.3.4. Compute Shader (CS)" in https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf for details on this.

---------

Co-authored-by: Paul Trojahn <[email protected]>
1 parent 1ab4bb4 commit 967e498
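
The register layout described in the commit message can be illustrated with a small Python model. This is only a sketch of the bit arithmetic, not code from the commit: the helper names (`pack_ttmp7`, `read_y_masked`, and so on) are hypothetical, and the only facts taken from the message are that ttmp7 holds Y in bits [15:0] and Z in bits [31:16] and that the fix is an AND with `0xffff`.

```python
# Toy model of the ttmp7 layout: Y in bits [15:0], Z in bits [31:16].
# All function names here are hypothetical and exist only for illustration.

def pack_ttmp7(y: int, z: int) -> int:
    """Pack two 16-bit workgroup ids into one 32-bit value."""
    assert 0 <= y < (1 << 16) and 0 <= z < (1 << 16)
    return (z << 16) | y


def read_y_unmasked(ttmp7: int) -> int:
    """Reading Y without a mask is only correct when the Z bits are zero."""
    return ttmp7


def read_y_masked(ttmp7: int) -> int:
    """Mirrors the effect of `s_and_b32 s1, ttmp7, 0xffff`."""
    return ttmp7 & 0xFFFF


def read_z(ttmp7: int) -> int:
    """Z occupies the upper 16 bits."""
    return ttmp7 >> 16


ttmp7 = pack_ttmp7(y=3, z=2)
print(read_y_unmasked(ttmp7))  # 131075: Z leaks into the Y value
print(read_y_masked(ttmp7))    # 3
print(read_z(ttmp7))           # 2
```

When the dispatch has only one workgroup along Z, the upper bits are zero and both reads agree, which is why the mask may be dropped only if Z is known to be unused.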

3 files changed: +23 -0 lines changed

python/src/llvm.cc

Lines changed: 2 additions & 0 deletions
@@ -220,6 +220,8 @@ void init_triton_llvm(py::module &&m) {
       .def("set_calling_conv", &llvm::Function::setCallingConv)
       .def("add_fn_attr", [](llvm::Function *fn, std::string &name,
                              std::string &val) { fn->addFnAttr(name, val); })
+      .def("remove_fn_attr", [](llvm::Function *fn,
+                                std::string &name) { fn->removeFnAttr(name); })
       .def("add_fn_asan_attr",
            [](llvm::Function *fn) {
              fn->addFnAttr(llvm::Attribute::SanitizeAddress);

third_party/amd/backend/compiler.py

Lines changed: 9 additions & 0 deletions
@@ -377,6 +377,15 @@ def make_llir(src, metadata, options):
 
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
 
+        # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
+        # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
+        # optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
+        # dispatch dimension might be used even if there is no program_id() call for it.
+        if amd.has_architected_sgprs(options.arch):
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")
+
         if knobs.amd.scalarize_packed_fops:
             amd.add_scalarize_packed_fops_llvm_pass(fns[0])
 
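
The effect of this hunk on a kernel's attribute set can be sketched in isolation. The snippet below is a toy stand-in and does not use the Triton or LLVM bindings: the function `mark_all_workgroup_ids_used` and the plain Python set of attribute names are assumptions made for illustration. Only the three `amdgpu-no-workgroup-id-*` attribute names come from the diff.

```python
from typing import Set

# Attribute names taken from the diff; everything else is hypothetical.
WORKGROUP_ID_ATTRS = tuple(
    "amdgpu-no-workgroup-id-" + axis for axis in ("x", "y", "z")
)


def mark_all_workgroup_ids_used(fn_attrs: Set[str],
                                has_architected_sgprs: bool) -> Set[str]:
    """Return a copy of fn_attrs with the 'unused workgroup id' hints dropped.

    On targets without architected SGPRs the hints are left untouched,
    matching the guard in the diff.
    """
    if not has_architected_sgprs:
        return set(fn_attrs)
    return {attr for attr in fn_attrs if attr not in WORKGROUP_ID_ATTRS}


# A kernel that never calls program_id(2) may have been given the Z hint.
attrs = {"amdgpu-no-workgroup-id-z", "nounwind"}
print(mark_all_workgroup_ids_used(attrs, has_architected_sgprs=True))
# {'nounwind'}
```

Removing an attribute that is not present is a no-op in this model, which is why all three names can be dropped unconditionally once the feature check passes.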

third_party/amd/python/triton_amd.cc

Lines changed: 12 additions & 0 deletions
@@ -261,6 +261,18 @@ void init_triton_amd(py::module &&m) {
       },
       py::return_value_policy::take_ownership);
 
+  m.def("has_architected_sgprs", [](const std::string &arch) {
+    std::string error;
+    llvm::Triple triple(amdTargetTriple);
+    const llvm::Target *target =
+        llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
+    if (!target)
+      throw std::runtime_error("target lookup error: " + error);
+    std::unique_ptr<llvm::MCSubtargetInfo> sti(
+        target->createMCSubtargetInfo(amdTargetTriple, arch, ""));
+    return sti->checkFeatures("+architected-sgprs");
+  });
+
   m.def("need_extern_lib", [](llvm::Module *module, const std::string &lib) {
     for (llvm::Function &f : module->functions()) {
       if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
