Skip to content

Commit 80eb648

Browse files
jmmartinezgit-crd
authored andcommitted
[SPIRV][SPIRVPrepareGlobals] Map AMD's dynamic LDS 0-element globals to arrays with UINT32_MAX elements (llvm#166952)
In HIP, dynamic LDS variables are represented using `0-element` global arrays in the `__shared__` language address-space. ```cpp extern __shared__ int LDS[]; ``` These are not representable in SPIRV directly. To represent them, for AMD, we use an array with `UINT32_MAX`-elements. These are reverse translated to 0-element arrays later in AMD's SPIRV runtime pipeline (in [SPIRVReader.cpp](https://github.com/ROCm/SPIRV-LLVM-Translator/blob/8cb74e264ddcde89f62354544803dc8cdbac148d/lib/SPIRV/SPIRVReader.cpp#L358)).
1 parent 645e335 commit 80eb648

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVPrepareGlobals.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "SPIRV.h"
15+
#include "SPIRVUtils.h"
1516

17+
#include "llvm/ADT/STLExtras.h"
1618
#include "llvm/IR/Module.h"
1719

1820
using namespace llvm;
@@ -43,6 +45,38 @@ bool tryExtendLLVMBitcodeMarker(GlobalVariable &Bitcode) {
4345
return true;
4446
}
4547

48+
// In HIP, dynamic LDS variables are represented using 0-element global arrays
49+
// in the __shared__ language address-space.
50+
//
51+
// extern __shared__ int LDS[];
52+
//
53+
// These are not representable in SPIRV directly.
54+
// To represent them, for AMD, we use an array with UINT32_MAX-elements.
55+
// These are reverse translated to 0-element arrays.
56+
bool tryExtendDynamicLDSGlobal(GlobalVariable &GV) {
57+
constexpr unsigned WorkgroupAS =
58+
storageClassToAddressSpace(SPIRV::StorageClass::Workgroup);
59+
const bool IsWorkgroupExternal =
60+
GV.hasExternalLinkage() && GV.getAddressSpace() == WorkgroupAS;
61+
if (!IsWorkgroupExternal)
62+
return false;
63+
64+
const ArrayType *AT = dyn_cast<ArrayType>(GV.getValueType());
65+
if (!AT || AT->getNumElements() != 0)
66+
return false;
67+
68+
constexpr auto UInt32Max = std::numeric_limits<uint32_t>::max();
69+
ArrayType *NewAT = ArrayType::get(AT->getElementType(), UInt32Max);
70+
GlobalVariable *NewGV = new GlobalVariable(
71+
*GV.getParent(), NewAT, GV.isConstant(), GV.getLinkage(), nullptr, "",
72+
&GV, GV.getThreadLocalMode(), WorkgroupAS, GV.isExternallyInitialized());
73+
NewGV->takeName(&GV);
74+
GV.replaceAllUsesWith(NewGV);
75+
GV.eraseFromParent();
76+
77+
return true;
78+
}
79+
4680
bool SPIRVPrepareGlobals::runOnModule(Module &M) {
4781
const bool IsAMD = M.getTargetTriple().getVendor() == Triple::AMD;
4882
if (!IsAMD)
@@ -52,6 +86,9 @@ bool SPIRVPrepareGlobals::runOnModule(Module &M) {
5286
if (GlobalVariable *Bitcode = M.getNamedGlobal("llvm.embedded.module"))
5387
Changed |= tryExtendLLVMBitcodeMarker(*Bitcode);
5488

89+
for (GlobalVariable &GV : make_early_inc_range(M.globals()))
90+
Changed |= tryExtendDynamicLDSGlobal(GV);
91+
5592
return Changed;
5693
}
5794
char SPIRVPrepareGlobals::ID = 0;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: llc -verify-machineinstrs -mtriple=spirv64-amd-amdhsa %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -mtriple=spirv64-amd-amdhsa %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK: OpName %[[#LDS:]] "lds"
5+
; CHECK: OpDecorate %[[#LDS]] LinkageAttributes "lds" Import
6+
; CHECK: %[[#UINT:]] = OpTypeInt 32 0
7+
; CHECK: %[[#UINT_MAX:]] = OpConstant %[[#UINT]] 4294967295
8+
; CHECK: %[[#LDS_ARR_TY:]] = OpTypeArray %[[#UINT]] %[[#UINT_MAX]]
9+
; CHECK: %[[#LDS_ARR_PTR_WG:]] = OpTypePointer Workgroup %[[#LDS_ARR_TY]]
10+
; CHECK: %[[#LDS]] = OpVariable %[[#LDS_ARR_PTR_WG]] Workgroup
11+
12+
@lds = external addrspace(3) global [0 x i32]
13+
14+
define spir_kernel void @foo(ptr addrspace(4) %in, ptr addrspace(4) %out) {
15+
entry:
16+
%val = load i32, ptr addrspace(4) %in
17+
%add = add i32 %val, 1
18+
store i32 %add, ptr addrspace(4) %out
19+
ret void
20+
}

0 commit comments

Comments
 (0)