Skip to content

Commit 496aaf5

Browse files
committed
[HLSL] Add separate handle for the counter
In DXIL, the counters associated with some types of structured buffers are accessed using the same handle as the buffer. That is not the case for SPIR-V. This PR creates a separate handle for the coutner buffer. Implements wg-hlsl proposal [#23](https://github.com/llvm/wg-hlsl/blob/main/proposals/0023-typed-buffer-counters.md).
1 parent b5045ae commit 496aaf5

18 files changed

+231
-90
lines changed

clang/include/clang/AST/Type.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6277,15 +6277,19 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
62776277
LLVM_PREFERRED_TYPE(bool)
62786278
uint8_t RawBuffer : 1;
62796279

6280+
LLVM_PREFERRED_TYPE(bool)
6281+
uint8_t Counter : 1;
6282+
62806283
Attributes(llvm::dxil::ResourceClass ResourceClass, bool IsROV = false,
62816284
bool RawBuffer = false)
62826285
: ResourceClass(ResourceClass), IsROV(IsROV), RawBuffer(RawBuffer) {}
62836286

62846287
Attributes() : Attributes(llvm::dxil::ResourceClass::UAV, false, false) {}
62856288

62866289
friend bool operator==(const Attributes &LHS, const Attributes &RHS) {
6287-
return std::tie(LHS.ResourceClass, LHS.IsROV, LHS.RawBuffer) ==
6288-
std::tie(RHS.ResourceClass, RHS.IsROV, RHS.RawBuffer);
6290+
return std::tie(LHS.ResourceClass, LHS.IsROV, LHS.RawBuffer,
6291+
LHS.Counter) ==
6292+
std::tie(RHS.ResourceClass, RHS.IsROV, RHS.RawBuffer, LHS.Counter);
62896293
}
62906294
friend bool operator!=(const Attributes &LHS, const Attributes &RHS) {
62916295
return !(LHS == RHS);
@@ -6326,6 +6330,7 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
63266330
ID.AddInteger(static_cast<uint32_t>(Attrs.ResourceClass));
63276331
ID.AddBoolean(Attrs.IsROV);
63286332
ID.AddBoolean(Attrs.RawBuffer);
6333+
ID.AddBoolean(Attrs.Counter);
63296334
}
63306335

63316336
static bool classof(const Type *T) {

clang/include/clang/Basic/Attr.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4845,6 +4845,12 @@ def HLSLRawBuffer : TypeAttr {
48454845
let Documentation = [InternalOnly];
48464846
}
48474847

4848+
def HLSLCounter : TypeAttr {
4849+
let Spellings = [CXX11<"hlsl", "counter">];
4850+
let LangOpts = [HLSL];
4851+
let Documentation = [InternalOnly];
4852+
}
4853+
48484854
def HLSLGroupSharedAddressSpace : TypeAttr {
48494855
let Spellings = [CustomKeyword<"groupshared">];
48504856
let Subjects = SubjectList<[Var]>;

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4659,6 +4659,8 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) {
46594659
Str += "_ROV";
46604660
if (Attrs.RawBuffer)
46614661
Str += "_Raw";
4662+
if (Attrs.Counter)
4663+
Str += "_CNT";
46624664
if (T->hasContainedType())
46634665
Str += "_CT";
46644666
mangleVendorQualifier(Str);

clang/lib/AST/TypePrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,6 +1981,7 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
19811981
case attr::HLSLROV:
19821982
case attr::HLSLRawBuffer:
19831983
case attr::HLSLContainedType:
1984+
case attr::HLSLCounter:
19841985
llvm_unreachable("HLSL resource type attributes handled separately");
19851986

19861987
case attr::OpenCLPrivateAddressSpace:
@@ -2125,6 +2126,8 @@ void TypePrinter::printHLSLAttributedResourceAfter(
21252126
OS << " [[hlsl::is_rov]]";
21262127
if (Attrs.RawBuffer)
21272128
OS << " [[hlsl::raw_buffer]]";
2129+
if (Attrs.Counter)
2130+
OS << " [[hlsl::counter]]";
21282131

21292132
QualType ContainedTy = T->getContainedType();
21302133
if (!ContainedTy.isNull()) {

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -494,52 +494,74 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
494494
}
495495
}
496496

497+
static Value *createHandle(CGBuilderTy &Builder, CodeGenModule &CGM,
498+
llvm::Type *ReturnTy, unsigned int Space,
499+
unsigned int Slot, uint32_t ArraySize,
500+
uint32_t ArrayIndex, bool IsNonUniform) {
501+
llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Builder.getContext());
502+
Value *Args[] = {
503+
ConstantInt::get(CGM.IntTy, Space), /* reg_space */
504+
ConstantInt::get(CGM.IntTy, Slot), /* lower_bound */
505+
ConstantInt::get(CGM.IntTy, ArraySize), /* range_size */
506+
ConstantInt::get(CGM.IntTy, ArrayIndex), /* index */
507+
ConstantInt::get(Int1Ty, IsNonUniform) /* non-uniform */
508+
};
509+
Value *CreateHandle = Builder.CreateIntrinsic(
510+
ReturnTy, CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(),
511+
Args, nullptr);
512+
return CreateHandle;
513+
}
514+
515+
static void CreateAndStoreHandle(GlobalVariable *GV, unsigned int Slot,
516+
unsigned int Space, uint32_t structIdx,
517+
CodeGenModule &CGM, CGBuilderTy &Builder) {
518+
llvm::Type *HandleTy = GV->getValueType();
519+
assert((!HandleTy->isTargetExtTy() || structIdx == 0) &&
520+
"No struct to index into.");
521+
assert((HandleTy->isTargetExtTy() || HandleTy->isStructTy()) &&
522+
"Unexpected type.");
523+
assert((HandleTy->isTargetExtTy() ||
524+
HandleTy->getStructElementType(structIdx)->isTargetExtTy()) &&
525+
"Can not access handle.");
526+
527+
if (!HandleTy->isTargetExtTy())
528+
HandleTy = HandleTy->getStructElementType(structIdx);
529+
530+
// FIXME: resource arrays are not yet implemented. Using size 1.
531+
// FIXME: NonUniformResourceIndex bit is not yet implemented. Using false;
532+
const DataLayout &DL = CGM.getModule().getDataLayout();
533+
Value *CreateHandle =
534+
createHandle(Builder, CGM, HandleTy, Space, Slot, 1, 0, false);
535+
CreateHandle->setName(Twine(GV->getName()).concat("_h"));
536+
Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, structIdx);
537+
Builder.CreateAlignedStore(CreateHandle, HandleRef,
538+
HandleRef->getPointerAlignment(DL));
539+
}
540+
497541
static void createResourceInitFn(CodeGenModule &CGM, llvm::GlobalVariable *GV,
498542
unsigned Slot, unsigned Space) {
499-
LLVMContext &Ctx = CGM.getLLVMContext();
500-
llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);
501-
502543
llvm::Function *InitResFunc = llvm::Function::Create(
503544
llvm::FunctionType::get(CGM.VoidTy, false),
504545
llvm::GlobalValue::InternalLinkage,
505546
("_init_resource_" + GV->getName()).str(), CGM.getModule());
506547
InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);
507548

549+
LLVMContext &Ctx = CGM.getLLVMContext();
508550
llvm::BasicBlock *EntryBB =
509551
llvm::BasicBlock::Create(Ctx, "entry", InitResFunc);
510552
CGBuilderTy Builder(CGM, Ctx);
511-
const DataLayout &DL = CGM.getModule().getDataLayout();
512553
Builder.SetInsertPoint(EntryBB);
513554

514-
// Make sure the global variable is resource handle (cbuffer) or
515-
// resource class (=class where the first element is a resource handle).
516-
llvm::Type *HandleTy = GV->getValueType();
517-
assert((HandleTy->isTargetExtTy() ||
518-
(HandleTy->isStructTy() &&
519-
HandleTy->getStructElementType(0)->isTargetExtTy())) &&
520-
"unexpected type of the global");
521-
if (!HandleTy->isTargetExtTy())
522-
HandleTy = HandleTy->getStructElementType(0);
523-
524-
llvm::Value *Args[] = {
525-
llvm::ConstantInt::get(CGM.IntTy, Space), /* reg_space */
526-
llvm::ConstantInt::get(CGM.IntTy, Slot), /* lower_bound */
527-
// FIXME: resource arrays are not yet implemented
528-
llvm::ConstantInt::get(CGM.IntTy, 1), /* range_size */
529-
llvm::ConstantInt::get(CGM.IntTy, 0), /* index */
530-
// FIXME: NonUniformResourceIndex bit is not yet implemented
531-
llvm::ConstantInt::get(Int1Ty, false) /* non-uniform */
532-
};
533-
llvm::Value *CreateHandle = Builder.CreateIntrinsic(
534-
/*ReturnType=*/HandleTy,
535-
CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(), Args, nullptr,
536-
Twine(GV->getName()).concat("_h"));
555+
CreateAndStoreHandle(GV, Slot, Space, 0, CGM, Builder);
556+
StructType *ST = dyn_cast<StructType>(GV->getValueType());
557+
if (ST && ST->getNumElements() > 1) {
558+
// TODO(124561): This needs to be updated to get the correct slot for
559+
// SPIR-V. Using a placeholder value for now.
560+
uint32_t CounterSlot = (CGM.getTriple().isSPIRV() ? Slot + 1 : Slot);
561+
CreateAndStoreHandle(GV, CounterSlot, Space, 0, CGM, Builder);
562+
}
537563

538-
llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0);
539-
Builder.CreateAlignedStore(CreateHandle, HandleRef,
540-
HandleRef->getPointerAlignment(DL));
541564
Builder.CreateRetVoid();
542-
543565
CGM.AddCXXGlobalInit(InitResFunc);
544566
}
545567

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,18 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
403403
return getSPIRVImageTypeFromHLSLResource(ResAttrs, ElemType, Ctx);
404404
}
405405

406-
llvm::ArrayType *RuntimeArrayType = llvm::ArrayType::get(ElemType, 0);
407-
uint32_t StorageClass = /* StorageBuffer storage class */ 12;
406+
llvm::Type *BufferType = nullptr;
407+
uint32_t StorageClass = 0;
408+
if (ResAttrs.Counter) {
409+
BufferType = llvm::Type::getInt32Ty(CGM.getLLVMContext());
410+
StorageClass = /* Uniform */ 2;
411+
} else {
412+
BufferType = llvm::ArrayType::get(ElemType, 0);
413+
StorageClass = /* StorageBuffer */ 12;
414+
}
415+
408416
bool IsWritable = ResAttrs.ResourceClass == llvm::dxil::ResourceClass::UAV;
409-
return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer",
410-
{RuntimeArrayType},
417+
return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer", {BufferType},
411418
{StorageClass, IsWritable});
412419
}
413420
case llvm::dxil::ResourceClass::CBuffer:

clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,15 @@ struct BuiltinTypeMethodBuilder {
120120
// LastStmt - refers to the last statement in the method body; referencing
121121
// LastStmt will remove the statement from the method body since
122122
// it will be linked from the new expression being constructed.
123-
enum class PlaceHolder { _0, _1, _2, _3, Handle = 128, LastStmt };
123+
enum class PlaceHolder {
124+
_0,
125+
_1,
126+
_2,
127+
_3,
128+
Handle = 128,
129+
CounterHandle,
130+
LastStmt
131+
};
124132

125133
Expr *convertPlaceholder(PlaceHolder PH);
126134
Expr *convertPlaceholder(Expr *E) { return E; }
@@ -155,6 +163,7 @@ struct BuiltinTypeMethodBuilder {
155163
template <typename T> BuiltinTypeMethodBuilder &dereference(T Ptr);
156164
BuiltinTypeDeclBuilder &finalize();
157165
Expr *getResourceHandleExpr();
166+
Expr *getResourceCounterHandleExpr();
158167

159168
private:
160169
void createDecl();
@@ -322,6 +331,8 @@ TemplateParameterListBuilder::finalizeTemplateArgs(ConceptDecl *CD) {
322331
Expr *BuiltinTypeMethodBuilder::convertPlaceholder(PlaceHolder PH) {
323332
if (PH == PlaceHolder::Handle)
324333
return getResourceHandleExpr();
334+
if (PH == PlaceHolder::CounterHandle)
335+
return getResourceCounterHandleExpr();
325336

326337
if (PH == PlaceHolder::LastStmt) {
327338
assert(!StmtsList.empty() && "no statements in the list");
@@ -432,6 +443,18 @@ Expr *BuiltinTypeMethodBuilder::getResourceHandleExpr() {
432443
OK_Ordinary);
433444
}
434445

446+
Expr *BuiltinTypeMethodBuilder::getResourceCounterHandleExpr() {
447+
ensureCompleteDecl();
448+
449+
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
450+
CXXThisExpr *This = CXXThisExpr::Create(
451+
AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
452+
FieldDecl *HandleField = DeclBuilder.getResourceCounterHandleField();
453+
return MemberExpr::CreateImplicit(AST, This, false, HandleField,
454+
HandleField->getType(), VK_LValue,
455+
OK_Ordinary);
456+
}
457+
435458
template <typename... Ts>
436459
BuiltinTypeMethodBuilder &
437460
BuiltinTypeMethodBuilder::callBuiltin(StringRef BuiltinName,
@@ -626,6 +649,30 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addHandleMember(
626649
return *this;
627650
}
628651

652+
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCounterHandleMember(
653+
ResourceClass RC, bool IsROV, bool RawBuffer, AccessSpecifier Access) {
654+
assert(!Record->isCompleteDefinition() && "record is already complete");
655+
656+
ASTContext &Ctx = SemaRef.getASTContext();
657+
TypeSourceInfo *ElementTypeInfo =
658+
Ctx.getTrivialTypeSourceInfo(getHandleElementType(), SourceLocation());
659+
660+
// add handle member with resource type attributes
661+
QualType AttributedResTy = QualType();
662+
SmallVector<const Attr *> Attrs = {
663+
HLSLResourceClassAttr::CreateImplicit(Ctx, RC),
664+
IsROV ? HLSLROVAttr::CreateImplicit(Ctx) : nullptr,
665+
RawBuffer ? HLSLRawBufferAttr::CreateImplicit(Ctx) : nullptr,
666+
ElementTypeInfo
667+
? HLSLContainedTypeAttr::CreateImplicit(Ctx, ElementTypeInfo)
668+
: nullptr,
669+
HLSLCounterAttr::CreateImplicit(Ctx)};
670+
if (CreateHLSLAttributedResourceType(SemaRef, Ctx.HLSLResourceTy, Attrs,
671+
AttributedResTy))
672+
addMemberVariable("__counter_handle", AttributedResTy, {}, Access);
673+
return *this;
674+
}
675+
629676
// Adds default constructor to the resource class:
630677
// Resource::Resource()
631678
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDefaultHandleConstructor() {
@@ -671,6 +718,14 @@ FieldDecl *BuiltinTypeDeclBuilder::getResourceHandleField() {
671718
return I->second;
672719
}
673720

721+
FieldDecl *BuiltinTypeDeclBuilder::getResourceCounterHandleField() {
722+
auto I = Fields.find("__counter_handle");
723+
assert(I != Fields.end() &&
724+
I->second->getType()->isHLSLAttributedResourceType() &&
725+
"record does not have resource handle field");
726+
return I->second;
727+
}
728+
674729
QualType BuiltinTypeDeclBuilder::getFirstTemplateTypeParam() {
675730
assert(Template && "record it not a template");
676731
if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
@@ -730,7 +785,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addIncrementCounterMethod() {
730785
return BuiltinTypeMethodBuilder(*this, "IncrementCounter",
731786
SemaRef.getASTContext().UnsignedIntTy)
732787
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
733-
PH::Handle, getConstantIntExpr(1))
788+
PH::CounterHandle, getConstantIntExpr(1))
734789
.finalize();
735790
}
736791

@@ -739,7 +794,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDecrementCounterMethod() {
739794
return BuiltinTypeMethodBuilder(*this, "DecrementCounter",
740795
SemaRef.getASTContext().UnsignedIntTy)
741796
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
742-
PH::Handle, getConstantIntExpr(-1))
797+
PH::CounterHandle, getConstantIntExpr(-1))
743798
.finalize();
744799
}
745800

@@ -774,7 +829,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
774829
return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy)
775830
.addParam("value", ElemTy)
776831
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
777-
PH::Handle, getConstantIntExpr(1))
832+
PH::CounterHandle, getConstantIntExpr(1))
778833
.callBuiltin("__builtin_hlsl_resource_getpointer",
779834
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
780835
.dereference(PH::LastStmt)
@@ -788,7 +843,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() {
788843
QualType ElemTy = getHandleElementType();
789844
return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy)
790845
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
791-
PH::Handle, getConstantIntExpr(-1))
846+
PH::CounterHandle, getConstantIntExpr(-1))
792847
.callBuiltin("__builtin_hlsl_resource_getpointer",
793848
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
794849
.dereference(PH::LastStmt)

clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class BuiltinTypeDeclBuilder {
7474
BuiltinTypeDeclBuilder &
7575
addHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
7676
AccessSpecifier Access = AccessSpecifier::AS_private);
77+
BuiltinTypeDeclBuilder &
78+
addCounterHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
79+
AccessSpecifier Access = AccessSpecifier::AS_private);
7780
BuiltinTypeDeclBuilder &addArraySubscriptOperators();
7881

7982
// Builtin types methods
@@ -90,6 +93,7 @@ class BuiltinTypeDeclBuilder {
9093

9194
private:
9295
FieldDecl *getResourceHandleField();
96+
FieldDecl *getResourceCounterHandleField();
9397
QualType getFirstTemplateTypeParam();
9498
QualType getHandleElementType();
9599
Expr *getConstantIntExpr(int value);

clang/lib/Sema/HLSLExternalSemaSource.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,14 @@ void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
129129
static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
130130
ResourceClass RC, bool IsROV,
131131
bool RawBuffer) {
132-
return BuiltinTypeDeclBuilder(S, Decl)
133-
.addHandleMember(RC, IsROV, RawBuffer)
134-
.addDefaultHandleConstructor();
132+
BuiltinTypeDeclBuilder BTB(S, Decl);
133+
BTB.addHandleMember(RC, IsROV, RawBuffer);
134+
135+
if (RC == ResourceClass::UAV && RawBuffer)
136+
BTB.addCounterHandleMember(RC, IsROV, RawBuffer);
137+
138+
BTB.addDefaultHandleConstructor();
139+
return BTB;
135140
}
136141

137142
// This function is responsible for constructing the constraint expression for

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,13 @@ bool clang::CreateHLSLAttributedResourceType(
12271227
ContainedTyInfo = CTAttr->getTypeLoc();
12281228
break;
12291229
}
1230+
case attr::HLSLCounter:
1231+
if (ResAttrs.Counter) {
1232+
S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1233+
return false;
1234+
}
1235+
ResAttrs.Counter = true;
1236+
break;
12301237
default:
12311238
llvm_unreachable("unhandled resource attribute type");
12321239
}

0 commit comments

Comments
 (0)