Skip to content

Commit 3ae1dc3

Browse files
committed
[HLSL] Add matrix constructors using initalizer lists
fixes #159434 In HLSL matrices are matrix_type in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists. This supports the following HLSL syntax: (1) HLSL matrices support constructor syntax (2) HLSL matrices are expanded to constituate components in constructor using the same initalizer list behavior defined in transformInitList allows us to support struct element initalization via HLSLElementwiseCast
1 parent c7d776b commit 3ae1dc3

File tree

7 files changed

+608
-31
lines changed

7 files changed

+608
-31
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,9 +2418,9 @@ def err_init_conversion_failed : Error<
24182418
"cannot initialize %select{a variable|a parameter|template parameter|"
24192419
"return object|statement expression result|an "
24202420
"exception object|a member subobject|an array element|a new value|a value|a "
2421-
"base class|a constructor delegation|a vector element|a block element|a "
2422-
"block element|a complex element|a lambda capture|a compound literal "
2423-
"initializer|a related result|a parameter of CF audited function|a "
2421+
"base class|a constructor delegation|a vector element|a matrix element|a "
2422+
"block element|a block element|a complex element|a lambda capture|a compound"
2423+
" literal initializer|a related result|a parameter of CF audited function|a "
24242424
"structured binding|a member subobject}0 "
24252425
"%diff{of type $ with an %select{rvalue|lvalue}2 of type $|"
24262426
"with an %select{rvalue|lvalue}2 of incompatible type}1,3"
@@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
65436543
def err_variable_object_no_init : Error<
65446544
"variable-sized object may not be initialized">;
65456545
def err_excess_initializers : Error<
6546-
"excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
6546+
"excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
65476547
def ext_excess_initializers : ExtWarn<
6548-
"excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
6548+
"excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
65496549
InGroup<ExcessInitializers>;
65506550
def err_excess_initializers_for_sizeless_type : Error<
65516551
"excess elements in initializer for indivisible sizeless type %0">;
@@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
1108611086
def err_second_argument_to_cwsc_not_pointer : Error<
1108711087
"second argument to __builtin_call_with_static_chain must be of pointer type">;
1108811088

11089-
def err_vector_incorrect_num_elements : Error<
11090-
"%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
11089+
def err_tensor_incorrect_num_elements : Error<
11090+
"%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
1109111091
def err_altivec_empty_initializer : Error<"expected initializer">;
1109211092

1109311093
def err_vector_incorrect_bit_count : Error<

clang/include/clang/Sema/Initialization.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
9191
/// or vector.
9292
EK_VectorElement,
9393

94+
/// The entity being initialized is an element of a matrix.
95+
/// or matrix.
96+
EK_MatrixElement,
97+
9498
/// The entity being initialized is a field of block descriptor for
9599
/// the copied-in c++ object.
96100
EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
205209
/// virtual base.
206210
llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
207211

208-
/// When Kind == EK_ArrayElement, EK_VectorElement, or
209-
/// EK_ComplexElement, the index of the array or vector element being
212+
/// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement,
213+
/// or EK_ComplexElement, the index of the array or vector element being
210214
/// initialized.
211215
unsigned Index;
212216

@@ -536,15 +540,15 @@ class alignas(8) InitializedEntity {
536540
/// element's index.
537541
unsigned getElementIndex() const {
538542
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
539-
getKind() == EK_ComplexElement);
543+
getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
540544
return Index;
541545
}
542546

543547
/// If this is already the initializer for an array or vector
544548
/// element, sets the element index.
545549
void setElementIndex(unsigned Index) {
546550
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
547-
getKind() == EK_ComplexElement);
551+
getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
548552
this->Index = Index;
549553
}
550554

clang/lib/Sema/CheckExprLifetime.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ getEntityLifetime(const InitializedEntity *Entity,
155155
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
156156
case InitializedEntity::EK_LambdaCapture:
157157
case InitializedEntity::EK_VectorElement:
158+
case InitializedEntity::EK_MatrixElement:
158159
case InitializedEntity::EK_ComplexElement:
159160
return {nullptr, LK_FullExpression};
160161

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "clang/AST/Expr.h"
2222
#include "clang/AST/HLSLResource.h"
2323
#include "clang/AST/Type.h"
24+
#include "clang/AST/TypeBase.h"
2425
#include "clang/AST/TypeLoc.h"
2526
#include "clang/Basic/Builtins.h"
2627
#include "clang/Basic/DiagnosticSema.h"
@@ -3323,6 +3324,11 @@ static void BuildFlattenedTypeList(QualType BaseTy,
33233324
List.insert(List.end(), VT->getNumElements(), VT->getElementType());
33243325
continue;
33253326
}
3327+
if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) {
3328+
List.insert(List.end(), MT->getNumElementsFlattened(),
3329+
MT->getElementType());
3330+
continue;
3331+
}
33263332
if (const auto *RD = T->getAsCXXRecordDecl()) {
33273333
if (RD->isStandardLayout())
33283334
RD = RD->getStandardLayoutBaseWithFields();
@@ -4124,6 +4130,32 @@ class InitListTransformer {
41244130
}
41254131
return true;
41264132
}
4133+
if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
4134+
unsigned Rows = MTy->getNumRows();
4135+
unsigned Cols = MTy->getNumColumns();
4136+
QualType ElemTy = MTy->getElementType();
4137+
4138+
for (unsigned C = 0; C < Cols; ++C) {
4139+
for (unsigned R = 0; R < Rows; ++R) {
4140+
// row index literal
4141+
Expr *RowIdx = IntegerLiteral::Create(
4142+
Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
4143+
E->getBeginLoc());
4144+
// column index literal
4145+
Expr *ColIdx = IntegerLiteral::Create(
4146+
Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
4147+
E->getBeginLoc());
4148+
ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
4149+
E, RowIdx, ColIdx, E->getEndLoc());
4150+
if (ElExpr.isInvalid())
4151+
return false;
4152+
if (!buildInitializerListImpl(ElExpr.get()))
4153+
return false;
4154+
ElExpr.get()->setType(ElemTy);
4155+
}
4156+
}
4157+
return true;
4158+
}
41274159

41284160
if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
41294161
uint64_t Size = ArrTy->getZExtSize();
@@ -4177,14 +4209,17 @@ class InitListTransformer {
41774209
return *(ArgIt++);
41784210

41794211
llvm::SmallVector<Expr *> Inits;
4180-
assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
41814212
Ty = Ty.getDesugaredType(Ctx);
4182-
if (Ty->isVectorType() || Ty->isConstantArrayType()) {
4213+
if (Ty->isVectorType() || Ty->isConstantArrayType() ||
4214+
Ty->isConstantMatrixType()) {
41834215
QualType ElTy;
41844216
uint64_t Size = 0;
41854217
if (auto *ATy = Ty->getAs<VectorType>()) {
41864218
ElTy = ATy->getElementType();
41874219
Size = ATy->getNumElements();
4220+
} else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
4221+
ElTy = CMTy->getElementType();
4222+
Size = CMTy->getNumElementsFlattened();
41884223
} else {
41894224
auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
41904225
ElTy = VTy->getElementType();

0 commit comments

Comments
 (0)