Skip to content

[HLSL] Add separate handle for the counter #134864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -6277,15 +6277,19 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
LLVM_PREFERRED_TYPE(bool)
uint8_t RawBuffer : 1;

LLVM_PREFERRED_TYPE(bool)
uint8_t Counter : 1;

Attributes(llvm::dxil::ResourceClass ResourceClass, bool IsROV = false,
bool RawBuffer = false)
: ResourceClass(ResourceClass), IsROV(IsROV), RawBuffer(RawBuffer) {}

Attributes() : Attributes(llvm::dxil::ResourceClass::UAV, false, false) {}

friend bool operator==(const Attributes &LHS, const Attributes &RHS) {
return std::tie(LHS.ResourceClass, LHS.IsROV, LHS.RawBuffer) ==
std::tie(RHS.ResourceClass, RHS.IsROV, RHS.RawBuffer);
return std::tie(LHS.ResourceClass, LHS.IsROV, LHS.RawBuffer,
LHS.Counter) ==
std::tie(RHS.ResourceClass, RHS.IsROV, RHS.RawBuffer, LHS.Counter);
}
friend bool operator!=(const Attributes &LHS, const Attributes &RHS) {
return !(LHS == RHS);
Expand Down Expand Up @@ -6326,6 +6330,7 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
ID.AddInteger(static_cast<uint32_t>(Attrs.ResourceClass));
ID.AddBoolean(Attrs.IsROV);
ID.AddBoolean(Attrs.RawBuffer);
ID.AddBoolean(Attrs.Counter);
}

static bool classof(const Type *T) {
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4845,6 +4845,12 @@ def HLSLRawBuffer : TypeAttr {
let Documentation = [InternalOnly];
}

def HLSLCounter : TypeAttr {
let Spellings = [CXX11<"hlsl", "counter">];
let LangOpts = [HLSL];
let Documentation = [InternalOnly];
}

def HLSLGroupSharedAddressSpace : TypeAttr {
let Spellings = [CustomKeyword<"groupshared">];
let Subjects = SubjectList<[Var]>;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4659,6 +4659,8 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) {
Str += "_ROV";
if (Attrs.RawBuffer)
Str += "_Raw";
if (Attrs.Counter)
Str += "_CNT";
if (T->hasContainedType())
Str += "_CT";
mangleVendorQualifier(Str);
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1981,6 +1981,7 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
case attr::HLSLROV:
case attr::HLSLRawBuffer:
case attr::HLSLContainedType:
case attr::HLSLCounter:
llvm_unreachable("HLSL resource type attributes handled separately");

case attr::OpenCLPrivateAddressSpace:
Expand Down Expand Up @@ -2125,6 +2126,8 @@ void TypePrinter::printHLSLAttributedResourceAfter(
OS << " [[hlsl::is_rov]]";
if (Attrs.RawBuffer)
OS << " [[hlsl::raw_buffer]]";
if (Attrs.Counter)
OS << " [[hlsl::counter]]";

QualType ContainedTy = T->getContainedType();
if (!ContainedTy.isNull()) {
Expand Down
84 changes: 53 additions & 31 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,52 +494,74 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
}
}

static Value *createHandle(CGBuilderTy &Builder, CodeGenModule &CGM,
llvm::Type *ReturnTy, unsigned int Space,
unsigned int Slot, uint32_t ArraySize,
uint32_t ArrayIndex, bool IsNonUniform) {
llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Builder.getContext());
Value *Args[] = {
ConstantInt::get(CGM.IntTy, Space), /* reg_space */
ConstantInt::get(CGM.IntTy, Slot), /* lower_bound */
ConstantInt::get(CGM.IntTy, ArraySize), /* range_size */
ConstantInt::get(CGM.IntTy, ArrayIndex), /* index */
ConstantInt::get(Int1Ty, IsNonUniform) /* non-uniform */
};
Value *CreateHandle = Builder.CreateIntrinsic(
ReturnTy, CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(),
Args, nullptr);
return CreateHandle;
}

static void createAndStoreHandle(GlobalVariable *GV, unsigned int Slot,
unsigned int Space, uint32_t StructIdx,
CodeGenModule &CGM, CGBuilderTy &Builder) {
llvm::Type *HandleTy = GV->getValueType();
assert((!HandleTy->isTargetExtTy() || StructIdx == 0) &&
"No struct to index into.");
assert((HandleTy->isTargetExtTy() || HandleTy->isStructTy()) &&
"Unexpected type.");
assert((HandleTy->isTargetExtTy() ||
HandleTy->getStructElementType(StructIdx)->isTargetExtTy()) &&
"Can not access handle.");

if (!HandleTy->isTargetExtTy())
HandleTy = HandleTy->getStructElementType(StructIdx);

// FIXME: resource arrays are not yet implemented. Using size 1.
// FIXME: NonUniformResourceIndex bit is not yet implemented. Using false;
const DataLayout &DL = CGM.getModule().getDataLayout();
Value *CreateHandle =
createHandle(Builder, CGM, HandleTy, Space, Slot, 1, 0, false);
CreateHandle->setName(Twine(GV->getName()).concat("_h"));
Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, StructIdx);
Builder.CreateAlignedStore(CreateHandle, HandleRef,
HandleRef->getPointerAlignment(DL));
}

static void createResourceInitFn(CodeGenModule &CGM, llvm::GlobalVariable *GV,
unsigned Slot, unsigned Space) {
LLVMContext &Ctx = CGM.getLLVMContext();
llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);

llvm::Function *InitResFunc = llvm::Function::Create(
llvm::FunctionType::get(CGM.VoidTy, false),
llvm::GlobalValue::InternalLinkage,
("_init_resource_" + GV->getName()).str(), CGM.getModule());
InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);

LLVMContext &Ctx = CGM.getLLVMContext();
llvm::BasicBlock *EntryBB =
llvm::BasicBlock::Create(Ctx, "entry", InitResFunc);
CGBuilderTy Builder(CGM, Ctx);
const DataLayout &DL = CGM.getModule().getDataLayout();
Builder.SetInsertPoint(EntryBB);

// Make sure the global variable is resource handle (cbuffer) or
// resource class (=class where the first element is a resource handle).
llvm::Type *HandleTy = GV->getValueType();
assert((HandleTy->isTargetExtTy() ||
(HandleTy->isStructTy() &&
HandleTy->getStructElementType(0)->isTargetExtTy())) &&
"unexpected type of the global");
if (!HandleTy->isTargetExtTy())
HandleTy = HandleTy->getStructElementType(0);

llvm::Value *Args[] = {
llvm::ConstantInt::get(CGM.IntTy, Space), /* reg_space */
llvm::ConstantInt::get(CGM.IntTy, Slot), /* lower_bound */
// FIXME: resource arrays are not yet implemented
llvm::ConstantInt::get(CGM.IntTy, 1), /* range_size */
llvm::ConstantInt::get(CGM.IntTy, 0), /* index */
// FIXME: NonUniformResourceIndex bit is not yet implemented
llvm::ConstantInt::get(Int1Ty, false) /* non-uniform */
};
llvm::Value *CreateHandle = Builder.CreateIntrinsic(
/*ReturnType=*/HandleTy,
CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(), Args, nullptr,
Twine(GV->getName()).concat("_h"));
createAndStoreHandle(GV, Slot, Space, 0, CGM, Builder);
StructType *ST = dyn_cast<StructType>(GV->getValueType());
if (ST && ST->getNumElements() > 1) {
// TODO(124561): This needs to be updated to get the correct slot for
// SPIR-V. Using a placeholder value for now.
uint32_t CounterSlot = (CGM.getTriple().isSPIRV() ? Slot + 1 : Slot);
createAndStoreHandle(GV, CounterSlot, Space, 1, CGM, Builder);
}

llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0);
Builder.CreateAlignedStore(CreateHandle, HandleRef,
HandleRef->getPointerAlignment(DL));
Builder.CreateRetVoid();

CGM.AddCXXGlobalInit(InitResFunc);
}

Expand Down
15 changes: 11 additions & 4 deletions clang/lib/CodeGen/Targets/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,18 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
return getSPIRVImageTypeFromHLSLResource(ResAttrs, ElemType, Ctx);
}

llvm::ArrayType *RuntimeArrayType = llvm::ArrayType::get(ElemType, 0);
uint32_t StorageClass = /* StorageBuffer storage class */ 12;
llvm::Type *BufferType = nullptr;
uint32_t StorageClass = 0;
if (ResAttrs.Counter) {
BufferType = llvm::Type::getInt32Ty(CGM.getLLVMContext());
StorageClass = /* Uniform */ 2;
} else {
BufferType = llvm::ArrayType::get(ElemType, 0);
StorageClass = /* StorageBuffer */ 12;
}

bool IsWritable = ResAttrs.ResourceClass == llvm::dxil::ResourceClass::UAV;
return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer",
{RuntimeArrayType},
return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer", {BufferType},
{StorageClass, IsWritable});
}
case llvm::dxil::ResourceClass::CBuffer:
Expand Down
65 changes: 60 additions & 5 deletions clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,15 @@ struct BuiltinTypeMethodBuilder {
// LastStmt - refers to the last statement in the method body; referencing
// LastStmt will remove the statement from the method body since
// it will be linked from the new expression being constructed.
enum class PlaceHolder { _0, _1, _2, _3, Handle = 128, LastStmt };
enum class PlaceHolder {
_0,
_1,
_2,
_3,
Handle = 128,
CounterHandle,
LastStmt
};

Expr *convertPlaceholder(PlaceHolder PH);
Expr *convertPlaceholder(Expr *E) { return E; }
Expand Down Expand Up @@ -155,6 +163,7 @@ struct BuiltinTypeMethodBuilder {
template <typename T> BuiltinTypeMethodBuilder &dereference(T Ptr);
BuiltinTypeDeclBuilder &finalize();
Expr *getResourceHandleExpr();
Expr *getResourceCounterHandleExpr();

private:
void createDecl();
Expand Down Expand Up @@ -322,6 +331,8 @@ TemplateParameterListBuilder::finalizeTemplateArgs(ConceptDecl *CD) {
Expr *BuiltinTypeMethodBuilder::convertPlaceholder(PlaceHolder PH) {
if (PH == PlaceHolder::Handle)
return getResourceHandleExpr();
if (PH == PlaceHolder::CounterHandle)
return getResourceCounterHandleExpr();

if (PH == PlaceHolder::LastStmt) {
assert(!StmtsList.empty() && "no statements in the list");
Expand Down Expand Up @@ -432,6 +443,18 @@ Expr *BuiltinTypeMethodBuilder::getResourceHandleExpr() {
OK_Ordinary);
}

Expr *BuiltinTypeMethodBuilder::getResourceCounterHandleExpr() {
ensureCompleteDecl();

ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
CXXThisExpr *This = CXXThisExpr::Create(
AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
FieldDecl *HandleField = DeclBuilder.getResourceCounterHandleField();
return MemberExpr::CreateImplicit(AST, This, false, HandleField,
HandleField->getType(), VK_LValue,
OK_Ordinary);
}

template <typename... Ts>
BuiltinTypeMethodBuilder &
BuiltinTypeMethodBuilder::callBuiltin(StringRef BuiltinName,
Expand Down Expand Up @@ -626,6 +649,30 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addHandleMember(
return *this;
}

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCounterHandleMember(
ResourceClass RC, bool IsROV, bool RawBuffer, AccessSpecifier Access) {
assert(!Record->isCompleteDefinition() && "record is already complete");

ASTContext &Ctx = SemaRef.getASTContext();
TypeSourceInfo *ElementTypeInfo =
Ctx.getTrivialTypeSourceInfo(getHandleElementType(), SourceLocation());

// add handle member with resource type attributes
QualType AttributedResTy = QualType();
SmallVector<const Attr *> Attrs = {
HLSLResourceClassAttr::CreateImplicit(Ctx, RC),
IsROV ? HLSLROVAttr::CreateImplicit(Ctx) : nullptr,
RawBuffer ? HLSLRawBufferAttr::CreateImplicit(Ctx) : nullptr,
ElementTypeInfo
? HLSLContainedTypeAttr::CreateImplicit(Ctx, ElementTypeInfo)
: nullptr,
HLSLCounterAttr::CreateImplicit(Ctx)};
if (CreateHLSLAttributedResourceType(SemaRef, Ctx.HLSLResourceTy, Attrs,
AttributedResTy))
addMemberVariable("__counter_handle", AttributedResTy, {}, Access);
return *this;
}

// Adds default constructor to the resource class:
// Resource::Resource()
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDefaultHandleConstructor() {
Expand Down Expand Up @@ -671,6 +718,14 @@ FieldDecl *BuiltinTypeDeclBuilder::getResourceHandleField() {
return I->second;
}

FieldDecl *BuiltinTypeDeclBuilder::getResourceCounterHandleField() {
auto I = Fields.find("__counter_handle");
assert(I != Fields.end() &&
I->second->getType()->isHLSLAttributedResourceType() &&
"record does not have resource handle field");
return I->second;
}

QualType BuiltinTypeDeclBuilder::getFirstTemplateTypeParam() {
assert(Template && "record it not a template");
if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
Expand Down Expand Up @@ -730,7 +785,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addIncrementCounterMethod() {
return BuiltinTypeMethodBuilder(*this, "IncrementCounter",
SemaRef.getASTContext().UnsignedIntTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
PH::Handle, getConstantIntExpr(1))
PH::CounterHandle, getConstantIntExpr(1))
.finalize();
}

Expand All @@ -739,7 +794,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDecrementCounterMethod() {
return BuiltinTypeMethodBuilder(*this, "DecrementCounter",
SemaRef.getASTContext().UnsignedIntTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
PH::Handle, getConstantIntExpr(-1))
PH::CounterHandle, getConstantIntExpr(-1))
.finalize();
}

Expand Down Expand Up @@ -774,7 +829,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy)
.addParam("value", ElemTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
PH::Handle, getConstantIntExpr(1))
PH::CounterHandle, getConstantIntExpr(1))
.callBuiltin("__builtin_hlsl_resource_getpointer",
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
.dereference(PH::LastStmt)
Expand All @@ -788,7 +843,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() {
QualType ElemTy = getHandleElementType();
return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
PH::Handle, getConstantIntExpr(-1))
PH::CounterHandle, getConstantIntExpr(-1))
.callBuiltin("__builtin_hlsl_resource_getpointer",
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
.dereference(PH::LastStmt)
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class BuiltinTypeDeclBuilder {
BuiltinTypeDeclBuilder &
addHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
AccessSpecifier Access = AccessSpecifier::AS_private);
BuiltinTypeDeclBuilder &
addCounterHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
AccessSpecifier Access = AccessSpecifier::AS_private);
BuiltinTypeDeclBuilder &addArraySubscriptOperators();

// Builtin types methods
Expand All @@ -90,6 +93,7 @@ class BuiltinTypeDeclBuilder {

private:
FieldDecl *getResourceHandleField();
FieldDecl *getResourceCounterHandleField();
QualType getFirstTemplateTypeParam();
QualType getHandleElementType();
Expr *getConstantIntExpr(int value);
Expand Down
11 changes: 8 additions & 3 deletions clang/lib/Sema/HLSLExternalSemaSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,14 @@ void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
ResourceClass RC, bool IsROV,
bool RawBuffer) {
return BuiltinTypeDeclBuilder(S, Decl)
.addHandleMember(RC, IsROV, RawBuffer)
.addDefaultHandleConstructor();
BuiltinTypeDeclBuilder BTB(S, Decl);
BTB.addHandleMember(RC, IsROV, RawBuffer);

if (RC == ResourceClass::UAV && RawBuffer)
BTB.addCounterHandleMember(RC, IsROV, RawBuffer);

BTB.addDefaultHandleConstructor();
return BTB;
}

// This function is responsible for constructing the constraint expression for
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,13 @@ bool clang::CreateHLSLAttributedResourceType(
ContainedTyInfo = CTAttr->getTypeLoc();
break;
}
case attr::HLSLCounter:
if (ResAttrs.Counter) {
S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
return false;
}
ResAttrs.Counter = true;
break;
default:
llvm_unreachable("unhandled resource attribute type");
}
Expand Down
5 changes: 5 additions & 0 deletions clang/test/AST/HLSL/ByteAddressBuffers-AST.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ RESOURCE Buffer;
// CHECK-UAV-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
// CHECK-SAME{LITERAL}: [[hlsl::raw_buffer]]
// CHECK-SAME{LITERAL}: [[hlsl::contained_type(char8_t)]]
// CHECK-UAV: FieldDecl {{.*}} implicit __counter_handle '__hlsl_resource_t
// CHECK-UAV-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
// CHECK-UAV-SAME{LITERAL}: [[hlsl::raw_buffer]]
// CHECK-UAV-SAME{LITERAL}: [[hlsl::counter]]
// CHECK-UAV-SAME{LITERAL}: [[hlsl::contained_type(char8_t)]]

// CHECK-NOSUBSCRIPT-NOT: CXXMethodDecl {{.*}} operator[] 'const element_type &(unsigned int) const'
// CHECK-NOSUBSCRIPT-NOT: CXXMethodDecl {{.*}} operator[] 'element_type &(unsigned int)'
Loading
Loading