Skip to content

Commit f212826

Browse files
authored
[HLSL][NFC] Remove RegisterBindingFlags struct (#108924)
When diagnosing register bindings we just need to make sure there is a resource that matches the provided register type. We can emit the diagnostics right away instead of collecting flags in the RegisterBindingFlags struct. That also enables early exit when scanning user defined types because we can return as soon as we find a matching resource for the given register type.
1 parent 86d2abe commit f212826

File tree

1 file changed

+119
-189
lines changed

1 file changed

+119
-189
lines changed

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 119 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,48 @@
4040
#include <utility>
4141

4242
using namespace clang;
43+
using llvm::dxil::ResourceClass;
44+
45+
enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
46+
47+
static RegisterType getRegisterType(ResourceClass RC) {
48+
switch (RC) {
49+
case ResourceClass::SRV:
50+
return RegisterType::SRV;
51+
case ResourceClass::UAV:
52+
return RegisterType::UAV;
53+
case ResourceClass::CBuffer:
54+
return RegisterType::CBuffer;
55+
case ResourceClass::Sampler:
56+
return RegisterType::Sampler;
57+
}
58+
llvm_unreachable("unexpected ResourceClass value");
59+
}
60+
61+
static RegisterType getRegisterType(StringRef Slot) {
62+
switch (Slot[0]) {
63+
case 't':
64+
case 'T':
65+
return RegisterType::SRV;
66+
case 'u':
67+
case 'U':
68+
return RegisterType::UAV;
69+
case 'b':
70+
case 'B':
71+
return RegisterType::CBuffer;
72+
case 's':
73+
case 'S':
74+
return RegisterType::Sampler;
75+
case 'c':
76+
case 'C':
77+
return RegisterType::C;
78+
case 'i':
79+
case 'I':
80+
return RegisterType::I;
81+
default:
82+
return RegisterType::Invalid;
83+
}
84+
}
4385

4486
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
4587

@@ -586,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType(
586628
LocEnd = A->getRange().getEnd();
587629
switch (A->getKind()) {
588630
case attr::HLSLResourceClass: {
589-
llvm::dxil::ResourceClass RC =
590-
cast<HLSLResourceClassAttr>(A)->getResourceClass();
631+
ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
591632
if (HasResourceClass) {
592633
S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
593634
? diag::warn_duplicate_attribute_exact
@@ -672,7 +713,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) {
672713
SourceLocation ArgLoc = Loc->Loc;
673714

674715
// Validate resource class value
675-
llvm::dxil::ResourceClass RC;
716+
ResourceClass RC;
676717
if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
677718
Diag(ArgLoc, diag::warn_attribute_type_not_supported)
678719
<< "ResourceClass" << Identifier;
@@ -750,28 +791,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
750791
return LocInfo;
751792
}
752793

753-
struct RegisterBindingFlags {
754-
bool Resource = false;
755-
bool UDT = false;
756-
bool Other = false;
757-
bool Basic = false;
758-
759-
bool SRV = false;
760-
bool UAV = false;
761-
bool CBV = false;
762-
bool Sampler = false;
763-
764-
bool ContainsNumeric = false;
765-
bool DefaultGlobals = false;
766-
767-
// used only when Resource == true
768-
std::optional<llvm::dxil::ResourceClass> ResourceClass;
769-
};
770-
771-
static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
772-
return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
773-
}
774-
775794
// get the record decl from a var decl that we expect
776795
// represents a resource
777796
static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
@@ -786,24 +805,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
786805
return TheRecordDecl;
787806
}
788807

789-
static void updateResourceClassFlagsFromDeclResourceClass(
790-
RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
791-
switch (DeclResourceClass) {
792-
case llvm::hlsl::ResourceClass::SRV:
793-
Flags.SRV = true;
794-
break;
795-
case llvm::hlsl::ResourceClass::UAV:
796-
Flags.UAV = true;
797-
break;
798-
case llvm::hlsl::ResourceClass::CBuffer:
799-
Flags.CBV = true;
800-
break;
801-
case llvm::hlsl::ResourceClass::Sampler:
802-
Flags.Sampler = true;
803-
break;
804-
}
805-
}
806-
807808
const HLSLAttributedResourceType *
808809
findAttributedResourceTypeOnField(VarDecl *VD) {
809810
assert(VD != nullptr && "expected VarDecl");
@@ -817,8 +818,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
817818
return nullptr;
818819
}
819820

820-
static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
821-
const RecordType *RT) {
821+
// Iterate over RecordType fields and return true if any of them matched the
822+
// register type
823+
static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
824+
RegisterType RegType) {
822825
llvm::SmallVector<const Type *> TypesToScan;
823826
TypesToScan.emplace_back(RT);
824827

@@ -827,8 +830,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
827830
while (T->isArrayType())
828831
T = T->getArrayElementTypeNoTypeQual();
829832
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
830-
Flags.ContainsNumeric = true;
831-
continue;
833+
if (RegType == RegisterType::C)
834+
return true;
832835
}
833836
const RecordType *RT = T->getAs<RecordType>();
834837
if (!RT)
@@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
839842
const Type *FieldTy = FD->getType().getTypePtr();
840843
if (const HLSLAttributedResourceType *AttrResType =
841844
dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
842-
updateResourceClassFlagsFromDeclResourceClass(
843-
Flags, AttrResType->getAttrs().ResourceClass);
844-
continue;
845+
ResourceClass RC = AttrResType->getAttrs().ResourceClass;
846+
if (getRegisterType(RC) == RegType)
847+
return true;
848+
} else {
849+
TypesToScan.emplace_back(FD->getType().getTypePtr());
845850
}
846-
TypesToScan.emplace_back(FD->getType().getTypePtr());
847851
}
848852
}
853+
return false;
849854
}
850855

851-
static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
852-
Decl *TheDecl) {
853-
RegisterBindingFlags Flags;
856+
static void CheckContainsResourceForRegisterType(Sema &S,
857+
SourceLocation &ArgLoc,
858+
Decl *D, RegisterType RegType,
859+
bool SpecifiedSpace) {
860+
int RegTypeNum = static_cast<int>(RegType);
854861

855862
// check if the decl type is groupshared
856-
if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
857-
Flags.Other = true;
858-
return Flags;
863+
if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
864+
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
865+
return;
859866
}
860867

861868
// Cbuffers and Tbuffers are HLSLBufferDecl types
862-
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
863-
Flags.Resource = true;
864-
Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
865-
? llvm::dxil::ResourceClass::CBuffer
866-
: llvm::dxil::ResourceClass::SRV;
869+
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
870+
ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
871+
: ResourceClass::SRV;
872+
if (RegType != getRegisterType(RC))
873+
S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
874+
<< RegTypeNum;
875+
return;
867876
}
877+
868878
// Samplers, UAVs, and SRVs are VarDecl types
869-
else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
870-
if (const HLSLAttributedResourceType *AttrResType =
871-
findAttributedResourceTypeOnField(TheVarDecl)) {
872-
Flags.Resource = true;
873-
Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
874-
} else {
875-
const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
876-
while (TheBaseType->isArrayType())
877-
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
878-
879-
if (TheBaseType->isArithmeticType()) {
880-
Flags.Basic = true;
881-
if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
882-
(TheBaseType->isIntegralType(S.getASTContext()) ||
883-
TheBaseType->isFloatingType()))
884-
Flags.DefaultGlobals = true;
885-
} else if (TheBaseType->isRecordType()) {
886-
Flags.UDT = true;
887-
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
888-
updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
889-
} else
890-
Flags.Other = true;
891-
}
892-
} else {
893-
llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
879+
assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
880+
VarDecl *VD = cast<VarDecl>(D);
881+
882+
// Resource
883+
if (const HLSLAttributedResourceType *AttrResType =
884+
findAttributedResourceTypeOnField(VD)) {
885+
if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
886+
S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
887+
<< RegTypeNum;
888+
return;
894889
}
895-
return Flags;
896-
}
897890

898-
enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
891+
const clang::Type *Ty = VD->getType().getTypePtr();
892+
while (Ty->isArrayType())
893+
Ty = Ty->getArrayElementTypeNoTypeQual();
899894

900-
static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
901-
switch (RC) {
902-
case llvm::dxil::ResourceClass::SRV:
903-
return RegisterType::SRV;
904-
case llvm::dxil::ResourceClass::UAV:
905-
return RegisterType::UAV;
906-
case llvm::dxil::ResourceClass::CBuffer:
907-
return RegisterType::CBuffer;
908-
case llvm::dxil::ResourceClass::Sampler:
909-
return RegisterType::Sampler;
910-
}
911-
llvm_unreachable("unexpected ResourceClass value");
912-
}
895+
// Basic types
896+
if (Ty->isArithmeticType()) {
897+
bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
898+
if (SpecifiedSpace && !DeclaredInCOrTBuffer)
899+
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
913900

914-
static RegisterType getRegisterType(StringRef Slot) {
915-
switch (Slot[0]) {
916-
case 't':
917-
case 'T':
918-
return RegisterType::SRV;
919-
case 'u':
920-
case 'U':
921-
return RegisterType::UAV;
922-
case 'b':
923-
case 'B':
924-
return RegisterType::CBuffer;
925-
case 's':
926-
case 'S':
927-
return RegisterType::Sampler;
928-
case 'c':
929-
case 'C':
930-
return RegisterType::C;
931-
case 'i':
932-
case 'I':
933-
return RegisterType::I;
934-
default:
935-
return RegisterType::Invalid;
901+
if (!DeclaredInCOrTBuffer &&
902+
(Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
903+
// Default Globals
904+
if (RegType == RegisterType::CBuffer)
905+
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
906+
else if (RegType != RegisterType::C)
907+
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
908+
} else {
909+
if (RegType == RegisterType::C)
910+
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
911+
else
912+
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
913+
}
914+
} else if (Ty->isRecordType()) {
915+
// Class/struct types - walk the declaration and check each field and
916+
// subclass
917+
if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
918+
S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
919+
<< RegTypeNum;
920+
} else {
921+
// Anything else is an error
922+
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
936923
}
937924
}
938925

@@ -969,76 +956,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
969956
}
970957

971958
static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
972-
Decl *TheDecl, RegisterType RegType,
973-
const bool SpecifiedSpace) {
959+
Decl *D, RegisterType RegType,
960+
bool SpecifiedSpace) {
974961

975962
// exactly one of these two types should be set
976-
assert(((isa<VarDecl>(TheDecl) && !isa<HLSLBufferDecl>(TheDecl)) ||
977-
(!isa<VarDecl>(TheDecl) && isa<HLSLBufferDecl>(TheDecl))) &&
963+
assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
964+
(!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
978965
"expecting VarDecl or HLSLBufferDecl");
979966

980-
RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
981-
assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
982-
(int)Flags.UDT ==
983-
1 &&
984-
"only one resource analysis result should be expected");
985-
986-
int RegTypeNum = static_cast<int>(RegType);
987-
988-
// first, if "other" is set, emit an error
989-
if (Flags.Other) {
990-
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
991-
return;
992-
}
967+
// check if the declaration contains resource matching the register type
968+
CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
993969

994970
// next, if multiple register annotations exist, check that none conflict.
995-
ValidateMultipleRegisterAnnotations(S, TheDecl, RegType);
996-
997-
// next, if resource is set, make sure the register type in the register
998-
// annotation is compatible with the variable's resource type.
999-
if (Flags.Resource) {
1000-
RegisterType ExpRegType = getRegisterType(Flags.ResourceClass.value());
1001-
if (RegType != ExpRegType) {
1002-
S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
1003-
<< RegTypeNum;
1004-
}
1005-
1006-
return;
1007-
}
1008-
1009-
// next, handle diagnostics for when the "basic" flag is set
1010-
if (Flags.Basic) {
1011-
if (SpecifiedSpace && !isDeclaredWithinCOrTBuffer(TheDecl))
1012-
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
1013-
1014-
if (Flags.DefaultGlobals) {
1015-
if (RegType == RegisterType::CBuffer)
1016-
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
1017-
else if (RegType != RegisterType::C)
1018-
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1019-
return;
1020-
}
1021-
1022-
if (RegType == RegisterType::C)
1023-
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
1024-
else
1025-
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1026-
1027-
return;
1028-
}
1029-
1030-
// finally, we handle the udt case
1031-
if (Flags.UDT) {
1032-
const bool ExpectedRegisterTypesForUDT[] = {
1033-
Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric};
1034-
assert((size_t)RegTypeNum < std::size(ExpectedRegisterTypesForUDT) &&
1035-
"regType has unexpected value");
1036-
1037-
if (!ExpectedRegisterTypesForUDT[RegTypeNum])
1038-
S.Diag(TheDecl->getLocation(),
1039-
diag::warn_hlsl_user_defined_type_missing_member)
1040-
<< RegTypeNum;
1041-
}
971+
ValidateMultipleRegisterAnnotations(S, D, RegType);
1042972
}
1043973

1044974
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {

0 commit comments

Comments
 (0)