40
40
#include < utility>
41
41
42
42
using namespace clang ;
43
+ using llvm::dxil::ResourceClass;
44
+
45
+ enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
46
+
47
+ static RegisterType getRegisterType (ResourceClass RC) {
48
+ switch (RC) {
49
+ case ResourceClass::SRV:
50
+ return RegisterType::SRV;
51
+ case ResourceClass::UAV:
52
+ return RegisterType::UAV;
53
+ case ResourceClass::CBuffer:
54
+ return RegisterType::CBuffer;
55
+ case ResourceClass::Sampler:
56
+ return RegisterType::Sampler;
57
+ }
58
+ llvm_unreachable (" unexpected ResourceClass value" );
59
+ }
60
+
61
+ static RegisterType getRegisterType (StringRef Slot) {
62
+ switch (Slot[0 ]) {
63
+ case ' t' :
64
+ case ' T' :
65
+ return RegisterType::SRV;
66
+ case ' u' :
67
+ case ' U' :
68
+ return RegisterType::UAV;
69
+ case ' b' :
70
+ case ' B' :
71
+ return RegisterType::CBuffer;
72
+ case ' s' :
73
+ case ' S' :
74
+ return RegisterType::Sampler;
75
+ case ' c' :
76
+ case ' C' :
77
+ return RegisterType::C;
78
+ case ' i' :
79
+ case ' I' :
80
+ return RegisterType::I;
81
+ default :
82
+ return RegisterType::Invalid;
83
+ }
84
+ }
43
85
44
86
SemaHLSL::SemaHLSL (Sema &S) : SemaBase(S) {}
45
87
@@ -586,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType(
586
628
LocEnd = A->getRange ().getEnd ();
587
629
switch (A->getKind ()) {
588
630
case attr::HLSLResourceClass: {
589
- llvm::dxil::ResourceClass RC =
590
- cast<HLSLResourceClassAttr>(A)->getResourceClass ();
631
+ ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass ();
591
632
if (HasResourceClass) {
592
633
S.Diag (A->getLocation (), ResAttrs.ResourceClass == RC
593
634
? diag::warn_duplicate_attribute_exact
@@ -672,7 +713,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) {
672
713
SourceLocation ArgLoc = Loc->Loc ;
673
714
674
715
// Validate resource class value
675
- llvm::dxil:: ResourceClass RC;
716
+ ResourceClass RC;
676
717
if (!HLSLResourceClassAttr::ConvertStrToResourceClass (Identifier, RC)) {
677
718
Diag (ArgLoc, diag::warn_attribute_type_not_supported)
678
719
<< " ResourceClass" << Identifier;
@@ -750,28 +791,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
750
791
return LocInfo;
751
792
}
752
793
753
- struct RegisterBindingFlags {
754
- bool Resource = false ;
755
- bool UDT = false ;
756
- bool Other = false ;
757
- bool Basic = false ;
758
-
759
- bool SRV = false ;
760
- bool UAV = false ;
761
- bool CBV = false ;
762
- bool Sampler = false ;
763
-
764
- bool ContainsNumeric = false ;
765
- bool DefaultGlobals = false ;
766
-
767
- // used only when Resource == true
768
- std::optional<llvm::dxil::ResourceClass> ResourceClass;
769
- };
770
-
771
- static bool isDeclaredWithinCOrTBuffer (const Decl *TheDecl) {
772
- return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext ());
773
- }
774
-
775
794
// get the record decl from a var decl that we expect
776
795
// represents a resource
777
796
static CXXRecordDecl *getRecordDeclFromVarDecl (VarDecl *VD) {
@@ -786,24 +805,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
786
805
return TheRecordDecl;
787
806
}
788
807
789
- static void updateResourceClassFlagsFromDeclResourceClass (
790
- RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
791
- switch (DeclResourceClass) {
792
- case llvm::hlsl::ResourceClass::SRV:
793
- Flags.SRV = true ;
794
- break ;
795
- case llvm::hlsl::ResourceClass::UAV:
796
- Flags.UAV = true ;
797
- break ;
798
- case llvm::hlsl::ResourceClass::CBuffer:
799
- Flags.CBV = true ;
800
- break ;
801
- case llvm::hlsl::ResourceClass::Sampler:
802
- Flags.Sampler = true ;
803
- break ;
804
- }
805
- }
806
-
807
808
const HLSLAttributedResourceType *
808
809
findAttributedResourceTypeOnField (VarDecl *VD) {
809
810
assert (VD != nullptr && " expected VarDecl" );
@@ -817,8 +818,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
817
818
return nullptr ;
818
819
}
819
820
820
- static void updateResourceClassFlagsFromRecordType (RegisterBindingFlags &Flags,
821
- const RecordType *RT) {
821
+ // Iterate over RecordType fields and return true if any of them matched the
822
+ // register type
823
+ static bool ContainsResourceForRegisterType (Sema &S, const RecordType *RT,
824
+ RegisterType RegType) {
822
825
llvm::SmallVector<const Type *> TypesToScan;
823
826
TypesToScan.emplace_back (RT);
824
827
@@ -827,8 +830,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
827
830
while (T->isArrayType ())
828
831
T = T->getArrayElementTypeNoTypeQual ();
829
832
if (T->isIntegralOrEnumerationType () || T->isFloatingType ()) {
830
- Flags. ContainsNumeric = true ;
831
- continue ;
833
+ if (RegType == RegisterType::C)
834
+ return true ;
832
835
}
833
836
const RecordType *RT = T->getAs <RecordType>();
834
837
if (!RT)
@@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
839
842
const Type *FieldTy = FD->getType ().getTypePtr ();
840
843
if (const HLSLAttributedResourceType *AttrResType =
841
844
dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
842
- updateResourceClassFlagsFromDeclResourceClass (
843
- Flags, AttrResType->getAttrs ().ResourceClass );
844
- continue ;
845
+ ResourceClass RC = AttrResType->getAttrs ().ResourceClass ;
846
+ if (getRegisterType (RC) == RegType)
847
+ return true ;
848
+ } else {
849
+ TypesToScan.emplace_back (FD->getType ().getTypePtr ());
845
850
}
846
- TypesToScan.emplace_back (FD->getType ().getTypePtr ());
847
851
}
848
852
}
853
+ return false ;
849
854
}
850
855
851
- static RegisterBindingFlags HLSLFillRegisterBindingFlags (Sema &S,
852
- Decl *TheDecl) {
853
- RegisterBindingFlags Flags;
856
+ static void CheckContainsResourceForRegisterType (Sema &S,
857
+ SourceLocation &ArgLoc,
858
+ Decl *D, RegisterType RegType,
859
+ bool SpecifiedSpace) {
860
+ int RegTypeNum = static_cast <int >(RegType);
854
861
855
862
// check if the decl type is groupshared
856
- if (TheDecl ->hasAttr <HLSLGroupSharedAddressSpaceAttr>()) {
857
- Flags. Other = true ;
858
- return Flags ;
863
+ if (D ->hasAttr <HLSLGroupSharedAddressSpaceAttr>()) {
864
+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum ;
865
+ return ;
859
866
}
860
867
861
868
// Cbuffers and Tbuffers are HLSLBufferDecl types
862
- if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
863
- Flags.Resource = true ;
864
- Flags.ResourceClass = CBufferOrTBuffer->isCBuffer ()
865
- ? llvm::dxil::ResourceClass::CBuffer
866
- : llvm::dxil::ResourceClass::SRV;
869
+ if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
870
+ ResourceClass RC = CBufferOrTBuffer->isCBuffer () ? ResourceClass::CBuffer
871
+ : ResourceClass::SRV;
872
+ if (RegType != getRegisterType (RC))
873
+ S.Diag (D->getLocation (), diag::err_hlsl_binding_type_mismatch)
874
+ << RegTypeNum;
875
+ return ;
867
876
}
877
+
868
878
// Samplers, UAVs, and SRVs are VarDecl types
869
- else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
870
- if (const HLSLAttributedResourceType *AttrResType =
871
- findAttributedResourceTypeOnField (TheVarDecl)) {
872
- Flags.Resource = true ;
873
- Flags.ResourceClass = AttrResType->getAttrs ().ResourceClass ;
874
- } else {
875
- const clang::Type *TheBaseType = TheVarDecl->getType ().getTypePtr ();
876
- while (TheBaseType->isArrayType ())
877
- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
878
-
879
- if (TheBaseType->isArithmeticType ()) {
880
- Flags.Basic = true ;
881
- if (!isDeclaredWithinCOrTBuffer (TheDecl) &&
882
- (TheBaseType->isIntegralType (S.getASTContext ()) ||
883
- TheBaseType->isFloatingType ()))
884
- Flags.DefaultGlobals = true ;
885
- } else if (TheBaseType->isRecordType ()) {
886
- Flags.UDT = true ;
887
- const RecordType *TheRecordTy = TheBaseType->getAs <RecordType>();
888
- updateResourceClassFlagsFromRecordType (Flags, TheRecordTy);
889
- } else
890
- Flags.Other = true ;
891
- }
892
- } else {
893
- llvm_unreachable (" expected be VarDecl or HLSLBufferDecl" );
879
+ assert (isa<VarDecl>(D) && " D is expected to be VarDecl or HLSLBufferDecl" );
880
+ VarDecl *VD = cast<VarDecl>(D);
881
+
882
+ // Resource
883
+ if (const HLSLAttributedResourceType *AttrResType =
884
+ findAttributedResourceTypeOnField (VD)) {
885
+ if (RegType != getRegisterType (AttrResType->getAttrs ().ResourceClass ))
886
+ S.Diag (D->getLocation (), diag::err_hlsl_binding_type_mismatch)
887
+ << RegTypeNum;
888
+ return ;
894
889
}
895
- return Flags;
896
- }
897
890
898
- enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
891
+ const clang::Type *Ty = VD->getType ().getTypePtr ();
892
+ while (Ty->isArrayType ())
893
+ Ty = Ty->getArrayElementTypeNoTypeQual ();
899
894
900
- static RegisterType getRegisterType (llvm::dxil::ResourceClass RC) {
901
- switch (RC) {
902
- case llvm::dxil::ResourceClass::SRV:
903
- return RegisterType::SRV;
904
- case llvm::dxil::ResourceClass::UAV:
905
- return RegisterType::UAV;
906
- case llvm::dxil::ResourceClass::CBuffer:
907
- return RegisterType::CBuffer;
908
- case llvm::dxil::ResourceClass::Sampler:
909
- return RegisterType::Sampler;
910
- }
911
- llvm_unreachable (" unexpected ResourceClass value" );
912
- }
895
+ // Basic types
896
+ if (Ty->isArithmeticType ()) {
897
+ bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext ());
898
+ if (SpecifiedSpace && !DeclaredInCOrTBuffer)
899
+ S.Diag (ArgLoc, diag::err_hlsl_space_on_global_constant);
913
900
914
- static RegisterType getRegisterType (StringRef Slot) {
915
- switch (Slot[ 0 ] ) {
916
- case ' t ' :
917
- case ' T ' :
918
- return RegisterType::SRV ;
919
- case ' u ' :
920
- case ' U ' :
921
- return RegisterType::UAV;
922
- case ' b ' :
923
- case ' B ' :
924
- return RegisterType::CBuffer;
925
- case ' s ' :
926
- case ' S ' :
927
- return RegisterType::Sampler;
928
- case ' c ' :
929
- case ' C ' :
930
- return RegisterType::C;
931
- case ' i ' :
932
- case ' I ' :
933
- return RegisterType::I;
934
- default :
935
- return RegisterType::Invalid ;
901
+ if (!DeclaredInCOrTBuffer &&
902
+ (Ty-> isIntegralType (S. getASTContext ()) || Ty-> isFloatingType ()) ) {
903
+ // Default Globals
904
+ if (RegType == RegisterType::CBuffer)
905
+ S. Diag (ArgLoc, diag::warn_hlsl_deprecated_register_type_b) ;
906
+ else if (RegType != RegisterType::C)
907
+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
908
+ } else {
909
+ if (RegType == RegisterType::C)
910
+ S. Diag (ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
911
+ else
912
+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
913
+ }
914
+ } else if (Ty-> isRecordType ()) {
915
+ // Class/struct types - walk the declaration and check each field and
916
+ // subclass
917
+ if (! ContainsResourceForRegisterType (S, Ty-> getAs <RecordType>(), RegType))
918
+ S. Diag (D-> getLocation (), diag::warn_hlsl_user_defined_type_missing_member)
919
+ << RegTypeNum;
920
+ } else {
921
+ // Anything else is an error
922
+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum ;
936
923
}
937
924
}
938
925
@@ -969,76 +956,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
969
956
}
970
957
971
958
static void DiagnoseHLSLRegisterAttribute (Sema &S, SourceLocation &ArgLoc,
972
- Decl *TheDecl , RegisterType RegType,
973
- const bool SpecifiedSpace) {
959
+ Decl *D , RegisterType RegType,
960
+ bool SpecifiedSpace) {
974
961
975
962
// exactly one of these two types should be set
976
- assert (((isa<VarDecl>(TheDecl ) && !isa<HLSLBufferDecl>(TheDecl )) ||
977
- (!isa<VarDecl>(TheDecl ) && isa<HLSLBufferDecl>(TheDecl ))) &&
963
+ assert (((isa<VarDecl>(D ) && !isa<HLSLBufferDecl>(D )) ||
964
+ (!isa<VarDecl>(D ) && isa<HLSLBufferDecl>(D ))) &&
978
965
" expecting VarDecl or HLSLBufferDecl" );
979
966
980
- RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags (S, TheDecl);
981
- assert ((int )Flags.Other + (int )Flags.Resource + (int )Flags.Basic +
982
- (int )Flags.UDT ==
983
- 1 &&
984
- " only one resource analysis result should be expected" );
985
-
986
- int RegTypeNum = static_cast <int >(RegType);
987
-
988
- // first, if "other" is set, emit an error
989
- if (Flags.Other ) {
990
- S.Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
991
- return ;
992
- }
967
+ // check if the declaration contains resource matching the register type
968
+ CheckContainsResourceForRegisterType (S, ArgLoc, D, RegType, SpecifiedSpace);
993
969
994
970
// next, if multiple register annotations exist, check that none conflict.
995
- ValidateMultipleRegisterAnnotations (S, TheDecl, RegType);
996
-
997
- // next, if resource is set, make sure the register type in the register
998
- // annotation is compatible with the variable's resource type.
999
- if (Flags.Resource ) {
1000
- RegisterType ExpRegType = getRegisterType (Flags.ResourceClass .value ());
1001
- if (RegType != ExpRegType) {
1002
- S.Diag (TheDecl->getLocation (), diag::err_hlsl_binding_type_mismatch)
1003
- << RegTypeNum;
1004
- }
1005
-
1006
- return ;
1007
- }
1008
-
1009
- // next, handle diagnostics for when the "basic" flag is set
1010
- if (Flags.Basic ) {
1011
- if (SpecifiedSpace && !isDeclaredWithinCOrTBuffer (TheDecl))
1012
- S.Diag (ArgLoc, diag::err_hlsl_space_on_global_constant);
1013
-
1014
- if (Flags.DefaultGlobals ) {
1015
- if (RegType == RegisterType::CBuffer)
1016
- S.Diag (ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
1017
- else if (RegType != RegisterType::C)
1018
- S.Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1019
- return ;
1020
- }
1021
-
1022
- if (RegType == RegisterType::C)
1023
- S.Diag (ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
1024
- else
1025
- S.Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1026
-
1027
- return ;
1028
- }
1029
-
1030
- // finally, we handle the udt case
1031
- if (Flags.UDT ) {
1032
- const bool ExpectedRegisterTypesForUDT[] = {
1033
- Flags.SRV , Flags.UAV , Flags.CBV , Flags.Sampler , Flags.ContainsNumeric };
1034
- assert ((size_t )RegTypeNum < std::size (ExpectedRegisterTypesForUDT) &&
1035
- " regType has unexpected value" );
1036
-
1037
- if (!ExpectedRegisterTypesForUDT[RegTypeNum])
1038
- S.Diag (TheDecl->getLocation (),
1039
- diag::warn_hlsl_user_defined_type_missing_member)
1040
- << RegTypeNum;
1041
- }
971
+ ValidateMultipleRegisterAnnotations (S, D, RegType);
1042
972
}
1043
973
1044
974
void SemaHLSL::handleResourceBindingAttr (Decl *TheDecl, const ParsedAttr &AL) {
0 commit comments