Skip to content

[HLSL] Add support for fixed-size global resource arrays #152209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2724,6 +2724,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isHLSLAttributedResourceType() const;
bool isHLSLInlineSpirvType() const;
bool isHLSLResourceRecord() const;
bool isHLSLResourceRecordArray() const;
bool isHLSLIntangibleType()
const; // Any HLSL intangible type (builtin, array, class)

Expand Down
9 changes: 8 additions & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,17 @@ class SemaHLSL : public SemaBase {

void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);

bool initGlobalResourceDecl(VarDecl *VD);
uint32_t getNextImplicitBindingOrderID() {
return ImplicitBindingNextOrderID++;
}

bool initGlobalResourceDecl(VarDecl *VD);
bool initGlobalResourceArrayDecl(VarDecl *VD);
void createResourceRecordCtorArgs(const Type *ResourceTy, StringRef VarName,
HLSLResourceBindingAttr *RBA,
HLSLVkBindingAttr *VkBinding,
uint32_t ArrayIndex,
llvm::SmallVector<Expr *> &Args);
};

} // namespace clang
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5246,6 +5246,15 @@ bool Type::isHLSLResourceRecord() const {
return HLSLAttributedResourceType::findHandleTypeOnResource(this) != nullptr;
}

bool Type::isHLSLResourceRecordArray() const {
const Type *Ty = getUnqualifiedDesugaredType();
if (!Ty->isArrayType())
return false;
while (isa<ConstantArrayType>(Ty))
Ty = Ty->getArrayElementTypeNoTypeQual();
return Ty->isHLSLResourceRecord();
}

bool Type::isHLSLIntangibleType() const {
const Type *Ty = getUnqualifiedDesugaredType();

Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "CGCall.h"
#include "CGCleanup.h"
#include "CGDebugInfo.h"
#include "CGHLSLRuntime.h"
#include "CGObjCRuntime.h"
#include "CGOpenMPRuntime.h"
#include "CGRecordLayout.h"
Expand Down Expand Up @@ -4532,6 +4533,15 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
LHS.getBaseInfo(), TBAAAccessInfo());
}

// The HLSL runtime handle the subscript expression on global resource arrays.
if (getLangOpts().HLSL && (E->getType()->isHLSLResourceRecord() ||
E->getType()->isHLSLResourceRecordArray())) {
std::optional<LValue> LV =
CGM.getHLSLRuntime().emitResourceArraySubscriptExpr(E, *this);
if (LV.has_value())
return *LV;
}

// All the other cases basically behave like simple offsetting.

// Handle the extvector case we ignored above.
Expand Down
232 changes: 212 additions & 20 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "CodeGenModule.h"
#include "TargetInfo.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attrs.inc"
#include "clang/AST/Decl.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Type.h"
Expand All @@ -35,6 +36,7 @@
#include "llvm/Support/Alignment.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -84,6 +86,124 @@ void addRootSignature(llvm::dxbc::RootSignatureVersion RootSigVer,
RootSignatureValMD->addOperand(MDVals);
}

// If the specified expr is a simple decay from an array to pointer,
// return the array subexpression. Otherwise, return nullptr.
static const Expr *getSubExprFromArrayDecayOperand(const Expr *E) {
const auto *CE = dyn_cast<CastExpr>(E);
if (!CE || CE->getCastKind() != CK_ArrayToPointerDecay)
return nullptr;
return CE->getSubExpr();
}

// Find array variable declaration from nested array subscript AST nodes
static const ValueDecl *getArrayDecl(const ArraySubscriptExpr *ASE) {
const Expr *E = nullptr;
while (ASE != nullptr) {
E = getSubExprFromArrayDecayOperand(ASE->getBase());
if (!E)
return nullptr;
ASE = dyn_cast<ArraySubscriptExpr>(E);
}
if (const DeclRefExpr *DRE = dyn_cast_or_null<DeclRefExpr>(E))
return DRE->getDecl();
return nullptr;
}

// Get the total size of the array, or -1 if the array is unbounded.
static int getTotalArraySize(const clang::Type *Ty) {
assert(Ty->isArrayType() && "expected array type");
if (Ty->isIncompleteArrayType())
return -1;
int Size = 1;
while (const auto *CAT = dyn_cast<ConstantArrayType>(Ty)) {
Size *= CAT->getSExtSize();
Ty = CAT->getArrayElementTypeNoTypeQual();
}
return Size;
}

// Find constructor decl for a specific resource record type and binding
// (implicit vs. explicit). The constructor has 6 parameters.
// For explicit binding the signature is:
// void(unsigned, unsigned, int, unsigned, const char *).
// For implicit binding the signature is:
// void(unsigned, int, unsigned, unsigned, const char *).
static CXXConstructorDecl *findResourceConstructorDecl(ASTContext &AST,
QualType ResTy,
bool ExplicitBinding) {
SmallVector<QualType> ExpParmTypes = {
AST.UnsignedIntTy, AST.UnsignedIntTy, AST.UnsignedIntTy,
AST.UnsignedIntTy, AST.getPointerType(AST.CharTy.withConst())};
ExpParmTypes[ExplicitBinding ? 2 : 1] = AST.IntTy;

CXXRecordDecl *ResDecl = ResTy->getAsCXXRecordDecl();
for (auto *Ctor : ResDecl->ctors()) {
if (Ctor->getNumParams() != ExpParmTypes.size())
continue;
ParmVarDecl **ParmIt = Ctor->param_begin();
QualType *ExpTyIt = ExpParmTypes.begin();
for (; ParmIt != Ctor->param_end() && ExpTyIt != ExpParmTypes.end();
++ParmIt, ++ExpTyIt) {
if ((*ParmIt)->getType() != *ExpTyIt)
break;
}
if (ParmIt == Ctor->param_end())
return Ctor;
}
llvm_unreachable("did not find constructor for resource class");
}

static Value *buildNameForResource(llvm::StringRef BaseName,
CodeGenModule &CGM) {
std::string Str(BaseName);
std::string GlobalName(Str + ".str");
return CGM.GetAddrOfConstantCString(Str, GlobalName.c_str()).getPointer();
}

static void createResourceCtorArgs(CodeGenModule &CGM, CXXConstructorDecl *CD,
llvm::Value *ThisPtr, llvm::Value *Range,
llvm::Value *Index, StringRef Name,
HLSLResourceBindingAttr *RBA,
HLSLVkBindingAttr *VkBinding,
CallArgList &Args) {
assert((VkBinding || RBA) && "at least one a binding attribute expected");

std::optional<uint32_t> RegisterSlot;
uint32_t SpaceNo = 0;
if (VkBinding) {
RegisterSlot = VkBinding->getBinding();
SpaceNo = VkBinding->getSet();
} else if (RBA) {
if (RBA->hasRegisterSlot())
RegisterSlot = RBA->getSlotNumber();
SpaceNo = RBA->getSpaceNumber();
}

ASTContext &AST = CD->getASTContext();
Value *NameStr = buildNameForResource(Name, CGM);
Value *Space = llvm::ConstantInt::get(CGM.IntTy, SpaceNo);

Args.add(RValue::get(ThisPtr), CD->getThisType());
if (RegisterSlot.has_value()) {
// explicit binding
auto *RegSlot = llvm::ConstantInt::get(CGM.IntTy, RegisterSlot.value());
Args.add(RValue::get(RegSlot), AST.UnsignedIntTy);
Args.add(RValue::get(Space), AST.UnsignedIntTy);
Args.add(RValue::get(Range), AST.IntTy);
Args.add(RValue::get(Index), AST.UnsignedIntTy);

} else {
// implicit binding
auto *OrderID =
llvm::ConstantInt::get(CGM.IntTy, RBA->getImplicitBindingOrderID());
Args.add(RValue::get(Space), AST.UnsignedIntTy);
Args.add(RValue::get(Range), AST.IntTy);
Args.add(RValue::get(Index), AST.UnsignedIntTy);
Args.add(RValue::get(OrderID), AST.UnsignedIntTy);
}
Args.add(RValue::get(NameStr), AST.getPointerType(AST.CharTy.withConst()));
}

} // namespace

llvm::Type *
Expand All @@ -103,13 +223,6 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() {
return CGM.getTarget().getTriple().getArch();
}

// Returns true if the type is an HLSL resource class or an array of them
static bool isResourceRecordTypeOrArrayOf(const clang::Type *Ty) {
while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty))
Ty = CAT->getArrayElementTypeNoTypeQual();
return Ty->isHLSLResourceRecord();
}

// Emits constant global variables for buffer constants declarations
// and creates metadata linking the constant globals with the buffer global.
void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
Expand Down Expand Up @@ -146,7 +259,7 @@ void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
if (VDTy.getAddressSpace() != LangAS::hlsl_constant) {
if (VD->getStorageClass() == SC_Static ||
VDTy.getAddressSpace() == LangAS::hlsl_groupshared ||
isResourceRecordTypeOrArrayOf(VDTy.getTypePtr())) {
VDTy->isHLSLResourceRecord() || VDTy->isHLSLResourceRecordArray()) {
// Emit static and groupshared variables and resource classes inside
// cbuffer as regular globals
CGM.EmitGlobal(VD);
Expand Down Expand Up @@ -597,13 +710,6 @@ static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV,
CGM.AddCXXGlobalInit(InitResFunc);
}

static Value *buildNameForResource(llvm::StringRef BaseName,
CodeGenModule &CGM) {
std::string Str(BaseName);
std::string GlobalName(Str + ".str");
return CGM.GetAddrOfConstantCString(Str, GlobalName.c_str()).getPointer();
}

void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
llvm::GlobalVariable *GV,
HLSLVkBindingAttr *VkBinding) {
Expand Down Expand Up @@ -631,17 +737,13 @@ void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);
auto *RangeSize = llvm::ConstantInt::get(CGM.IntTy, 1);
auto *Space = llvm::ConstantInt::get(CGM.IntTy, RBA->getSpaceNumber());
Value *Name = nullptr;
Value *Name = buildNameForResource(BufDecl->getName(), CGM);

llvm::Intrinsic::ID IntrinsicID =
RBA->hasRegisterSlot()
? CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic()
: CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();

std::string Str(BufDecl->getName());
std::string GlobalName(Str + ".str");
Name = CGM.GetAddrOfConstantCString(Str, GlobalName.c_str()).getPointer();

// buffer with explicit binding
if (RBA->hasRegisterSlot()) {
auto *RegSlot = llvm::ConstantInt::get(CGM.IntTy, RBA->getSlotNumber());
Expand Down Expand Up @@ -708,3 +810,93 @@ void CGHLSLRuntime::emitInitListOpaqueValues(CodeGenFunction &CGF,
}
}
}

std::optional<LValue> CGHLSLRuntime::emitResourceArraySubscriptExpr(
const ArraySubscriptExpr *ArraySubsExpr, CodeGenFunction &CGF) {
assert(ArraySubsExpr->getType()->isHLSLResourceRecord() ||
ArraySubsExpr->getType()->isHLSLResourceRecordArray() &&
"expected resource array subscript expression");

// let clang codegen handle local resource array subscrips
const VarDecl *ArrayDecl = dyn_cast<VarDecl>(getArrayDecl(ArraySubsExpr));
if (!ArrayDecl || !ArrayDecl->hasGlobalStorage())
return std::nullopt;

// FIXME: this is not yet implemented (llvm/llvm-project#145426)
assert(!ArraySubsExpr->getType()->isArrayType() &&
"indexing of array subsets it not supported yet");

// get total array size (= range size)
const Type *ResArrayTy = ArrayDecl->getType().getTypePtr();
assert(ResArrayTy->isHLSLResourceRecordArray() &&
"expected array of resource classes");
llvm::Value *Range =
llvm::ConstantInt::get(CGM.IntTy, getTotalArraySize(ResArrayTy));

// Iterate through all nested array subscript expressions to calculate
// the index in the flattened resource array (if this is a multi-
// dimensional array). The index is calculated as a sum of all indices
// multiplied by the total size of the array at that level.
Value *Index = nullptr;
Value *Multiplier = nullptr;
const ArraySubscriptExpr *ASE = ArraySubsExpr;
while (ASE != nullptr) {
Value *SubIndex = CGF.EmitScalarExpr(ASE->getIdx());
if (const auto *ArrayTy =
dyn_cast<ConstantArrayType>(ASE->getType().getTypePtr())) {
Value *SubMultiplier =
llvm::ConstantInt::get(CGM.IntTy, ArrayTy->getSExtSize());
Multiplier = Multiplier ? CGF.Builder.CreateMul(Multiplier, SubMultiplier)
: SubMultiplier;
SubIndex = CGF.Builder.CreateMul(SubIndex, Multiplier);
}

Index = Index ? CGF.Builder.CreateAdd(Index, SubIndex) : SubIndex;
ASE = dyn_cast<ArraySubscriptExpr>(
getSubExprFromArrayDecayOperand(ASE->getBase()));
}

// find binding info for the resource array
// (for implicit binding an HLSLResourceBindingAttr should have been added by SemaHLSL)
QualType ResourceTy = ArraySubsExpr->getType();
HLSLVkBindingAttr *VkBinding = ArrayDecl->getAttr<HLSLVkBindingAttr>();
HLSLResourceBindingAttr *RBA = ArrayDecl->getAttr<HLSLResourceBindingAttr>();
assert((VkBinding || RBA) && "resource array must have a binding attribute");

// lookup the resource class constructor based on the resource type and
// binding
CXXConstructorDecl *CD = findResourceConstructorDecl(
ArrayDecl->getASTContext(), ResourceTy, VkBinding || RBA->hasRegisterSlot());

// create a temporary variable for the resource class instance (we need to
// return an LValue)
RawAddress TmpVar = CGF.CreateMemTemp(ResourceTy);
if (auto *Size = CGF.EmitLifetimeStart(
CGM.getDataLayout().getTypeAllocSize(TmpVar.getElementType()),
TmpVar.getPointer())) {
CGF.pushFullExprCleanup<CodeGenFunction::CallLifetimeEnd>(
NormalEHLifetimeMarker, TmpVar, Size);
}
AggValueSlot ValueSlot = AggValueSlot::forAddr(
TmpVar, Qualifiers(), AggValueSlot::IsDestructed_t(true),
AggValueSlot::DoesNotNeedGCBarriers, AggValueSlot::IsAliased_t(false),
AggValueSlot::MayOverlap);

Address ThisAddress = ValueSlot.getAddress();
llvm::Value *ThisPtr = CGF.getAsNaturalPointerTo(
ThisAddress, CD->getThisType()->getPointeeType());

// assemble the constructor parameters
CallArgList Args;
createResourceCtorArgs(CGM, CD, ThisPtr, Range, Index, ArrayDecl->getName(),
RBA, VkBinding, Args);

// call the constructor
CGF.EmitCXXConstructorCall(CD, Ctor_Complete, false, false, ThisAddress, Args,
ValueSlot.mayOverlap(),
ArraySubsExpr->getExprLoc(),
ValueSlot.isSanitizerChecked());

return CGF.MakeAddrLValue(TmpVar, ArraySubsExpr->getType(),
AlignmentSource::Decl);
}
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ class Type;
class RecordType;
class DeclContext;
class HLSLPackOffsetAttr;
class ArraySubscriptExpr;

class FunctionDecl;

namespace CodeGen {

class CodeGenModule;
class CodeGenFunction;
class LValue;

class CGHLSLRuntime {
public:
Expand Down Expand Up @@ -164,6 +166,10 @@ class CGHLSLRuntime {
llvm::TargetExtType *LayoutTy);
void emitInitListOpaqueValues(CodeGenFunction &CGF, InitListExpr *E);

std::optional<LValue>
emitResourceArraySubscriptExpr(const ArraySubscriptExpr *E,
CodeGenFunction &CGF);

private:
void emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
llvm::GlobalVariable *BufGV);
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5775,8 +5775,8 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
if (D->getType()->isReferenceType())
T = D->getType();

if (getLangOpts().HLSL &&
D->getType().getTypePtr()->isHLSLResourceRecord()) {
if (getLangOpts().HLSL && (D->getType()->isHLSLResourceRecord() ||
D->getType()->isHLSLResourceRecordArray())) {
Init = llvm::PoisonValue::get(getTypes().ConvertType(ASTTy));
NeedsGlobalCtor = true;
} else if (getLangOpts().CPlusPlus) {
Expand Down
Loading
Loading