Skip to content

Commit e310898

Browse files
farzonlaokblast
authored andcommitted
[HLSL] Add matrix constructors using initalizer lists (llvm#162743)
fixes llvm#159434 In HLSL matrices are matrix_type in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists. This supports the following HLSL syntax: (1) HLSL matrices support constructor syntax (2) HLSL matrices are expanded to constituate components in constructor using the same initalizer list behavior defined in transformInitList allows us to support struct element initalization via HLSLElementwiseCast
1 parent e9a129f commit e310898

File tree

8 files changed

+807
-26
lines changed

8 files changed

+807
-26
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,9 +2418,9 @@ def err_init_conversion_failed : Error<
24182418
"cannot initialize %select{a variable|a parameter|template parameter|"
24192419
"return object|statement expression result|an "
24202420
"exception object|a member subobject|an array element|a new value|a value|a "
2421-
"base class|a constructor delegation|a vector element|a block element|a "
2422-
"block element|a complex element|a lambda capture|a compound literal "
2423-
"initializer|a related result|a parameter of CF audited function|a "
2421+
"base class|a constructor delegation|a vector element|a matrix element|a "
2422+
"block element|a block element|a complex element|a lambda capture|a compound"
2423+
" literal initializer|a related result|a parameter of CF audited function|a "
24242424
"structured binding|a member subobject}0 "
24252425
"%diff{of type $ with an %select{rvalue|lvalue}2 of type $|"
24262426
"with an %select{rvalue|lvalue}2 of incompatible type}1,3"
@@ -6549,9 +6549,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
65496549
def err_variable_object_no_init : Error<
65506550
"variable-sized object may not be initialized">;
65516551
def err_excess_initializers : Error<
6552-
"excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
6552+
"excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
65536553
def ext_excess_initializers : ExtWarn<
6554-
"excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
6554+
"excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
65556555
InGroup<ExcessInitializers>;
65566556
def err_excess_initializers_for_sizeless_type : Error<
65576557
"excess elements in initializer for indivisible sizeless type %0">;

clang/include/clang/Sema/Initialization.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
9191
/// or vector.
9292
EK_VectorElement,
9393

94+
/// The entity being initialized is an element of a matrix.
95+
/// or matrix.
96+
EK_MatrixElement,
97+
9498
/// The entity being initialized is a field of block descriptor for
9599
/// the copied-in c++ object.
96100
EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
205209
/// virtual base.
206210
llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
207211

208-
/// When Kind == EK_ArrayElement, EK_VectorElement, or
209-
/// EK_ComplexElement, the index of the array or vector element being
212+
/// When Kind == EK_ArrayElement, EK_VectorElement, EK_MatrixElement,
213+
/// or EK_ComplexElement, the index of the array or vector element being
210214
/// initialized.
211215
unsigned Index;
212216

@@ -536,15 +540,15 @@ class alignas(8) InitializedEntity {
536540
/// element's index.
537541
unsigned getElementIndex() const {
538542
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
539-
getKind() == EK_ComplexElement);
543+
getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
540544
return Index;
541545
}
542546

543547
/// If this is already the initializer for an array or vector
544548
/// element, sets the element index.
545549
void setElementIndex(unsigned Index) {
546550
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
547-
getKind() == EK_ComplexElement);
551+
getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
548552
this->Index = Index;
549553
}
550554

clang/lib/Sema/CheckExprLifetime.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ getEntityLifetime(const InitializedEntity *Entity,
155155
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
156156
case InitializedEntity::EK_LambdaCapture:
157157
case InitializedEntity::EK_VectorElement:
158+
case InitializedEntity::EK_MatrixElement:
158159
case InitializedEntity::EK_ComplexElement:
159160
return {nullptr, LK_FullExpression};
160161

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "clang/AST/Expr.h"
2222
#include "clang/AST/HLSLResource.h"
2323
#include "clang/AST/Type.h"
24+
#include "clang/AST/TypeBase.h"
2425
#include "clang/AST/TypeLoc.h"
2526
#include "clang/Basic/Builtins.h"
2627
#include "clang/Basic/DiagnosticSema.h"
@@ -3432,6 +3433,11 @@ static void BuildFlattenedTypeList(QualType BaseTy,
34323433
List.insert(List.end(), VT->getNumElements(), VT->getElementType());
34333434
continue;
34343435
}
3436+
if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
3437+
List.insert(List.end(), MT->getNumElementsFlattened(),
3438+
MT->getElementType());
3439+
continue;
3440+
}
34353441
if (const auto *RD = T->getAsCXXRecordDecl()) {
34363442
if (RD->isStandardLayout())
34373443
RD = RD->getStandardLayoutBaseWithFields();
@@ -4230,6 +4236,32 @@ class InitListTransformer {
42304236
}
42314237
return true;
42324238
}
4239+
if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
4240+
unsigned Rows = MTy->getNumRows();
4241+
unsigned Cols = MTy->getNumColumns();
4242+
QualType ElemTy = MTy->getElementType();
4243+
4244+
for (unsigned C = 0; C < Cols; ++C) {
4245+
for (unsigned R = 0; R < Rows; ++R) {
4246+
// row index literal
4247+
Expr *RowIdx = IntegerLiteral::Create(
4248+
Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
4249+
E->getBeginLoc());
4250+
// column index literal
4251+
Expr *ColIdx = IntegerLiteral::Create(
4252+
Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
4253+
E->getBeginLoc());
4254+
ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
4255+
E, RowIdx, ColIdx, E->getEndLoc());
4256+
if (ElExpr.isInvalid())
4257+
return false;
4258+
if (!castInitializer(ElExpr.get()))
4259+
return false;
4260+
ElExpr.get()->setType(ElemTy);
4261+
}
4262+
}
4263+
return true;
4264+
}
42334265

42344266
if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
42354267
uint64_t Size = ArrTy->getZExtSize();
@@ -4283,14 +4315,17 @@ class InitListTransformer {
42834315
return *(ArgIt++);
42844316

42854317
llvm::SmallVector<Expr *> Inits;
4286-
assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
42874318
Ty = Ty.getDesugaredType(Ctx);
4288-
if (Ty->isVectorType() || Ty->isConstantArrayType()) {
4319+
if (Ty->isVectorType() || Ty->isConstantArrayType() ||
4320+
Ty->isConstantMatrixType()) {
42894321
QualType ElTy;
42904322
uint64_t Size = 0;
42914323
if (auto *ATy = Ty->getAs<VectorType>()) {
42924324
ElTy = ATy->getElementType();
42934325
Size = ATy->getNumElements();
4326+
} else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
4327+
ElTy = CMTy->getElementType();
4328+
Size = CMTy->getNumElementsFlattened();
42944329
} else {
42954330
auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
42964331
ElTy = VTy->getElementType();

clang/lib/Sema/SemaInit.cpp

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "clang/AST/ExprCXX.h"
1818
#include "clang/AST/ExprObjC.h"
1919
#include "clang/AST/IgnoreExpr.h"
20+
#include "clang/AST/TypeBase.h"
2021
#include "clang/AST/TypeLoc.h"
2122
#include "clang/Basic/SourceManager.h"
2223
#include "clang/Basic/Specifiers.h"
@@ -403,6 +404,9 @@ class InitListChecker {
403404
unsigned &Index,
404405
InitListExpr *StructuredList,
405406
unsigned &StructuredIndex);
407+
void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList,
408+
QualType DeclType, unsigned &Index,
409+
InitListExpr *StructuredList, unsigned &StructuredIndex);
406410
void CheckVectorType(const InitializedEntity &Entity,
407411
InitListExpr *IList, QualType DeclType, unsigned &Index,
408412
InitListExpr *StructuredList,
@@ -1004,7 +1008,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
10041008
return;
10051009

10061010
if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement ||
1007-
ElementEntity.getKind() == InitializedEntity::EK_VectorElement)
1011+
ElementEntity.getKind() == InitializedEntity::EK_VectorElement ||
1012+
ElementEntity.getKind() == InitializedEntity::EK_MatrixElement)
10081013
ElementEntity.setElementIndex(Init);
10091014

10101015
if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks))
@@ -1274,6 +1279,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity,
12741279

12751280
switch (Entity.getKind()) {
12761281
case InitializedEntity::EK_VectorElement:
1282+
case InitializedEntity::EK_MatrixElement:
12771283
case InitializedEntity::EK_ComplexElement:
12781284
case InitializedEntity::EK_ArrayElement:
12791285
case InitializedEntity::EK_Parameter:
@@ -1373,11 +1379,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity,
13731379
SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK)
13741380
<< T << IList->getInit(Index)->getSourceRange();
13751381
} else {
1376-
int initKind = T->isArrayType() ? 0 :
1377-
T->isVectorType() ? 1 :
1378-
T->isScalarType() ? 2 :
1379-
T->isUnionType() ? 3 :
1380-
4;
1382+
int initKind = T->isArrayType() ? 0
1383+
: T->isVectorType() ? 1
1384+
: T->isMatrixType() ? 2
1385+
: T->isScalarType() ? 3
1386+
: T->isUnionType() ? 4
1387+
: 5;
13811388

13821389
unsigned DK = ExtraInitsIsError ? diag::err_excess_initializers
13831390
: diag::ext_excess_initializers;
@@ -1431,6 +1438,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity,
14311438
} else if (DeclType->isVectorType()) {
14321439
CheckVectorType(Entity, IList, DeclType, Index,
14331440
StructuredList, StructuredIndex);
1441+
} else if (DeclType->isMatrixType()) {
1442+
CheckMatrixType(Entity, IList, DeclType, Index, StructuredList,
1443+
StructuredIndex);
14341444
} else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) {
14351445
auto Bases =
14361446
CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(),
@@ -1878,6 +1888,37 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity,
18781888
AggrDeductionCandidateParamTypes->push_back(DeclType);
18791889
}
18801890

1891+
void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
1892+
InitListExpr *IList, QualType DeclType,
1893+
unsigned &Index,
1894+
InitListExpr *StructuredList,
1895+
unsigned &StructuredIndex) {
1896+
if (!SemaRef.getLangOpts().HLSL)
1897+
return;
1898+
1899+
const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>();
1900+
QualType ElemTy = MT->getElementType();
1901+
const unsigned MaxElts = MT->getNumElementsFlattened();
1902+
1903+
unsigned NumEltsInit = 0;
1904+
InitializedEntity ElemEnt =
1905+
InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
1906+
1907+
while (NumEltsInit < MaxElts && Index < IList->getNumInits()) {
1908+
// Not a sublist: just consume directly.
1909+
ElemEnt.setElementIndex(Index);
1910+
CheckSubElementType(ElemEnt, IList, ElemTy, Index, StructuredList,
1911+
StructuredIndex);
1912+
++NumEltsInit;
1913+
}
1914+
1915+
// For HLSL The error for this case is handled in SemaHLSL's initializer
1916+
// list diagnostics, That means the execution should require NumEltsInit
1917+
// to equal Max initializers. In other words execution should never
1918+
// reach this point if this condition is not true".
1919+
assert(NumEltsInit == MaxElts && "NumEltsInit must equal MaxElts");
1920+
}
1921+
18811922
void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
18821923
InitListExpr *IList, QualType DeclType,
18831924
unsigned &Index,
@@ -3640,6 +3681,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index,
36403681
} else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) {
36413682
Kind = EK_VectorElement;
36423683
Type = VT->getElementType();
3684+
} else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) {
3685+
Kind = EK_MatrixElement;
3686+
Type = MT->getElementType();
36433687
} else {
36443688
const ComplexType *CT = Parent.getType()->getAs<ComplexType>();
36453689
assert(CT && "Unexpected type");
@@ -3688,6 +3732,7 @@ DeclarationName InitializedEntity::getName() const {
36883732
case EK_Delegating:
36893733
case EK_ArrayElement:
36903734
case EK_VectorElement:
3735+
case EK_MatrixElement:
36913736
case EK_ComplexElement:
36923737
case EK_BlockElement:
36933738
case EK_LambdaToBlockConversionBlockElement:
@@ -3721,6 +3766,7 @@ ValueDecl *InitializedEntity::getDecl() const {
37213766
case EK_Delegating:
37223767
case EK_ArrayElement:
37233768
case EK_VectorElement:
3769+
case EK_MatrixElement:
37243770
case EK_ComplexElement:
37253771
case EK_BlockElement:
37263772
case EK_LambdaToBlockConversionBlockElement:
@@ -3754,6 +3800,7 @@ bool InitializedEntity::allowsNRVO() const {
37543800
case EK_Delegating:
37553801
case EK_ArrayElement:
37563802
case EK_VectorElement:
3803+
case EK_MatrixElement:
37573804
case EK_ComplexElement:
37583805
case EK_BlockElement:
37593806
case EK_LambdaToBlockConversionBlockElement:
@@ -3793,6 +3840,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const {
37933840
case EK_Delegating: OS << "Delegating"; break;
37943841
case EK_ArrayElement: OS << "ArrayElement " << Index; break;
37953842
case EK_VectorElement: OS << "VectorElement " << Index; break;
3843+
case EK_MatrixElement:
3844+
OS << "MatrixElement " << Index;
3845+
break;
37963846
case EK_ComplexElement: OS << "ComplexElement " << Index; break;
37973847
case EK_BlockElement: OS << "Block"; break;
37983848
case EK_LambdaToBlockConversionBlockElement:
@@ -6030,7 +6080,7 @@ static void TryOrBuildParenListInitialization(
60306080
Sequence.SetFailed(InitializationSequence::FK_ParenthesizedListInitFailed);
60316081
if (!VerifyOnly) {
60326082
QualType T = Entity.getType();
6033-
int InitKind = T->isArrayType() ? 0 : T->isUnionType() ? 3 : 4;
6083+
int InitKind = T->isArrayType() ? 0 : T->isUnionType() ? 4 : 5;
60346084
SourceRange ExcessInitSR(Args[EntityIndexToProcess]->getBeginLoc(),
60356085
Args.back()->getEndLoc());
60366086
S.Diag(Kind.getLocation(), diag::err_excess_initializers)
@@ -6823,7 +6873,8 @@ void InitializationSequence::InitializeFrom(Sema &S,
68236873
// For HLSL ext vector types we allow list initialization behavior for C++
68246874
// functional cast expressions which look like constructor syntax. This is
68256875
// accomplished by converting initialization arguments to InitListExpr.
6826-
if (S.getLangOpts().HLSL && Args.size() > 1 && DestType->isExtVectorType() &&
6876+
if (S.getLangOpts().HLSL && Args.size() > 1 &&
6877+
(DestType->isExtVectorType() || DestType->isConstantMatrixType()) &&
68276878
(SourceType.isNull() ||
68286879
!Context.hasSameUnqualifiedType(SourceType, DestType))) {
68296880
InitListExpr *ILE = new (Context)
@@ -6988,6 +7039,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity,
69887039
case InitializedEntity::EK_Binding:
69897040
case InitializedEntity::EK_ArrayElement:
69907041
case InitializedEntity::EK_VectorElement:
7042+
case InitializedEntity::EK_MatrixElement:
69917043
case InitializedEntity::EK_ComplexElement:
69927044
case InitializedEntity::EK_BlockElement:
69937045
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7013,6 +7065,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) {
70137065
case InitializedEntity::EK_Base:
70147066
case InitializedEntity::EK_Delegating:
70157067
case InitializedEntity::EK_VectorElement:
7068+
case InitializedEntity::EK_MatrixElement:
70167069
case InitializedEntity::EK_ComplexElement:
70177070
case InitializedEntity::EK_Exception:
70187071
case InitializedEntity::EK_BlockElement:
@@ -7043,6 +7096,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) {
70437096
case InitializedEntity::EK_Base:
70447097
case InitializedEntity::EK_Delegating:
70457098
case InitializedEntity::EK_VectorElement:
7099+
case InitializedEntity::EK_MatrixElement:
70467100
case InitializedEntity::EK_ComplexElement:
70477101
case InitializedEntity::EK_BlockElement:
70487102
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7096,6 +7150,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity,
70967150
case InitializedEntity::EK_Base:
70977151
case InitializedEntity::EK_Delegating:
70987152
case InitializedEntity::EK_VectorElement:
7153+
case InitializedEntity::EK_MatrixElement:
70997154
case InitializedEntity::EK_ComplexElement:
71007155
case InitializedEntity::EK_BlockElement:
71017156
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7845,11 +7900,13 @@ ExprResult InitializationSequence::Perform(Sema &S,
78457900
ExprResult CurInit((Expr *)nullptr);
78467901
SmallVector<Expr*, 4> ArrayLoopCommonExprs;
78477902

7848-
// HLSL allows vector initialization to function like list initialization, but
7849-
// use the syntax of a C++-like constructor.
7850-
bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() &&
7851-
isa<InitListExpr>(Args[0]);
7852-
(void)IsHLSLVectorInit;
7903+
// HLSL allows vector/matrix initialization to function like list
7904+
// initialization, but use the syntax of a C++-like constructor.
7905+
bool IsHLSLVectorOrMatrixInit =
7906+
S.getLangOpts().HLSL &&
7907+
(DestType->isExtVectorType() || DestType->isConstantMatrixType()) &&
7908+
isa<InitListExpr>(Args[0]);
7909+
(void)IsHLSLVectorOrMatrixInit;
78537910

78547911
// For initialization steps that start with a single initializer,
78557912
// grab the only argument out the Args and place it into the "current"
@@ -7888,7 +7945,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
78887945
case SK_StdInitializerList:
78897946
case SK_OCLSamplerInit:
78907947
case SK_OCLZeroOpaqueType: {
7891-
assert(Args.size() == 1 || IsHLSLVectorInit);
7948+
assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit);
78927949
CurInit = Args[0];
78937950
if (!CurInit.get()) return ExprError();
78947951
break;
@@ -9105,7 +9162,7 @@ bool InitializationSequence::Diagnose(Sema &S,
91059162
<< R;
91069163
else
91079164
S.Diag(Kind.getLocation(), diag::err_excess_initializers)
9108-
<< /*scalar=*/2 << R;
9165+
<< /*scalar=*/3 << R;
91099166
break;
91109167
}
91119168

0 commit comments

Comments
 (0)