-
Notifications
You must be signed in to change notification settings - Fork 14.9k
Support HLSL matrix initializers #160960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support HLSL matrix initializers #160960
Conversation
@llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) Changesfixes #159434 In HLSL matrices are This supports the following HLSL syntax: Patch is 47.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160960.diff 6 Files Affected:
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index bd0e53d3086b0..19e4c548d0208 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
def err_variable_object_no_init : Error<
"variable-sized object may not be initialized">;
def err_excess_initializers : Error<
- "excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
+ "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
def ext_excess_initializers : ExtWarn<
- "excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
+ "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
InGroup<ExcessInitializers>;
def err_excess_initializers_for_sizeless_type : Error<
"excess elements in initializer for indivisible sizeless type %0">;
@@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
def err_second_argument_to_cwsc_not_pointer : Error<
"second argument to __builtin_call_with_static_chain must be of pointer type">;
-def err_vector_incorrect_num_elements : Error<
- "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
+def err_tensor_incorrect_num_elements : Error<
+ "%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
def err_altivec_empty_initializer : Error<"expected initializer">;
def err_vector_incorrect_bit_count : Error<
diff --git a/clang/include/clang/Sema/Initialization.h b/clang/include/clang/Sema/Initialization.h
index d7675ea153afb..865b6428f3081 100644
--- a/clang/include/clang/Sema/Initialization.h
+++ b/clang/include/clang/Sema/Initialization.h
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
/// or vector.
EK_VectorElement,
+ /// The entity being initialized is an element of a matrix.
+ /// or matrix.
+ EK_MatrixElement,
+
/// The entity being initialized is a field of block descriptor for
/// the copied-in c++ object.
EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
/// virtual base.
llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
- /// When Kind == EK_ArrayElement, EK_VectorElement, or
- /// EK_ComplexElement, the index of the array or vector element being
+ /// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement,
+ /// or EK_ComplexElement, the index of the array or vector element being
/// initialized.
unsigned Index;
@@ -536,7 +540,7 @@ class alignas(8) InitializedEntity {
/// element's index.
unsigned getElementIndex() const {
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
- getKind() == EK_ComplexElement);
+ getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
return Index;
}
@@ -544,7 +548,7 @@ class alignas(8) InitializedEntity {
/// element, sets the element index.
void setElementIndex(unsigned Index) {
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
- getKind() == EK_ComplexElement);
+ getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
this->Index = Index;
}
diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp
index e02e00231e58e..d647fbf007838 100644
--- a/clang/lib/Sema/CheckExprLifetime.cpp
+++ b/clang/lib/Sema/CheckExprLifetime.cpp
@@ -154,6 +154,7 @@ getEntityLifetime(const InitializedEntity *Entity,
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
case InitializedEntity::EK_LambdaCapture:
case InitializedEntity::EK_VectorElement:
+ case clang::InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
return {nullptr, LK_FullExpression};
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index c97129336736b..0f06d715600b3 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -17,6 +17,7 @@
#include "clang/AST/ExprCXX.h"
#include "clang/AST/ExprObjC.h"
#include "clang/AST/IgnoreExpr.h"
+#include "clang/AST/TypeBase.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/Specifiers.h"
@@ -403,6 +404,9 @@ class InitListChecker {
unsigned &Index,
InitListExpr *StructuredList,
unsigned &StructuredIndex);
+ void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList,
+ QualType DeclType, unsigned &Index,
+ InitListExpr *StructuredList, unsigned &StructuredIndex);
void CheckVectorType(const InitializedEntity &Entity,
InitListExpr *IList, QualType DeclType, unsigned &Index,
InitListExpr *StructuredList,
@@ -1003,7 +1007,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
return;
if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement ||
- ElementEntity.getKind() == InitializedEntity::EK_VectorElement)
+ ElementEntity.getKind() == InitializedEntity::EK_VectorElement ||
+ ElementEntity.getKind() == InitializedEntity::EK_MatrixElement)
ElementEntity.setElementIndex(Init);
if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks))
@@ -1273,6 +1278,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity,
switch (Entity.getKind()) {
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_ArrayElement:
case InitializedEntity::EK_Parameter:
@@ -1372,11 +1378,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity,
SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK)
<< T << IList->getInit(Index)->getSourceRange();
} else {
- int initKind = T->isArrayType() ? 0 :
- T->isVectorType() ? 1 :
- T->isScalarType() ? 2 :
- T->isUnionType() ? 3 :
- 4;
+ int initKind = T->isArrayType() ? 0
+ : T->isVectorType() ? 1
+ : T->isMatrixType() ? 2
+ : T->isScalarType() ? 3
+ : T->isUnionType() ? 4
+ : 5;
unsigned DK = ExtraInitsIsError ? diag::err_excess_initializers
: diag::ext_excess_initializers;
@@ -1430,6 +1437,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity,
} else if (DeclType->isVectorType()) {
CheckVectorType(Entity, IList, DeclType, Index,
StructuredList, StructuredIndex);
+ } else if (DeclType->isMatrixType()) {
+ CheckMatrixType(Entity, IList, DeclType, Index, StructuredList,
+ StructuredIndex);
} else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) {
auto Bases =
CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(),
@@ -1877,6 +1887,93 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity,
AggrDeductionCandidateParamTypes->push_back(DeclType);
}
+void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
+ InitListExpr *IList, QualType DeclType,
+ unsigned &Index,
+ InitListExpr *StructuredList,
+ unsigned &StructuredIndex) {
+ if (!SemaRef.getLangOpts().HLSL)
+ return;
+
+ const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>();
+ QualType ElemTy = MT->getElementType();
+ const unsigned Rows = MT->getNumRows();
+ const unsigned Cols = MT->getNumColumns();
+ const unsigned MaxElts = Rows * Cols;
+
+ unsigned NumEltsInit = 0;
+ InitializedEntity ElemEnt =
+ InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
+
+ // A Matrix initalizer should be able to take scalars, vectors, and matrices.
+ auto HandleInit = [&](InitListExpr *List, unsigned &Idx) {
+ Expr *Init = List->getInit(Idx);
+ QualType ITy = Init->getType();
+
+ if (ITy->isVectorType()) {
+ const VectorType *IVT = ITy->castAs<VectorType>();
+ unsigned N = IVT->getNumElements();
+ QualType VTy =
+ ITy->isExtVectorType()
+ ? SemaRef.Context.getExtVectorType(ElemTy, N)
+ : SemaRef.Context.getVectorType(ElemTy, N, IVT->getVectorKind());
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, VTy, Idx, StructuredList,
+ StructuredIndex);
+ NumEltsInit += N;
+ return;
+ }
+
+ if (ITy->isMatrixType()) {
+ const ConstantMatrixType *IMT = ITy->castAs<ConstantMatrixType>();
+ unsigned N = IMT->getNumRows() * IMT->getNumColumns();
+ QualType MTy = SemaRef.Context.getConstantMatrixType(
+ ElemTy, IMT->getNumRows(), IMT->getNumColumns());
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, MTy, Idx, StructuredList,
+ StructuredIndex);
+ NumEltsInit += N;
+ return;
+ }
+
+ // Scalar element
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, ElemTy, Idx, StructuredList,
+ StructuredIndex);
+ ++NumEltsInit;
+ };
+
+ // Column-major: each top-level sublist is treated as a column.
+ while (NumEltsInit < MaxElts && Index < IList->getNumInits()) {
+ Expr *Init = IList->getInit(Index);
+
+ if (auto *SubList = dyn_cast<InitListExpr>(Init)) {
+ unsigned SubIdx = 0;
+ unsigned Row = 0;
+ while (Row < Rows && SubIdx < SubList->getNumInits() &&
+ NumEltsInit < MaxElts) {
+ HandleInit(SubList, SubIdx);
+ ++Row;
+ }
+ ++Index; // advance past this column sublist
+ continue;
+ }
+
+ // Not a sublist: just consume directly.
+ HandleInit(IList, Index);
+ }
+
+ // HLSL requires exactly Rows*Cols initializers after flattening.
+ if (NumEltsInit != MaxElts) {
+ if (!VerifyOnly)
+ SemaRef.Diag(IList->getBeginLoc(),
+ diag::err_tensor_incorrect_num_elements)
+ << (NumEltsInit < MaxElts) << /*matrix*/ 1 << MaxElts << NumEltsInit
+ << /*initialization*/ 0;
+ hadError = true;
+ }
+}
+
void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
InitListExpr *IList, QualType DeclType,
unsigned &Index,
@@ -2026,9 +2123,9 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
if (numEltsInit != maxElements) {
if (!VerifyOnly)
SemaRef.Diag(IList->getBeginLoc(),
- diag::err_vector_incorrect_num_elements)
- << (numEltsInit < maxElements) << maxElements << numEltsInit
- << /*initialization*/ 0;
+ diag::err_tensor_incorrect_num_elements)
+ << (numEltsInit < maxElements) << /*vector*/ 0 << maxElements
+ << numEltsInit << /*initialization*/ 0;
hadError = true;
}
}
@@ -3639,6 +3736,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index,
} else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) {
Kind = EK_VectorElement;
Type = VT->getElementType();
+ } else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) {
+ Kind = EK_MatrixElement;
+ Type = MT->getElementType();
} else {
const ComplexType *CT = Parent.getType()->getAs<ComplexType>();
assert(CT && "Unexpected type");
@@ -3687,6 +3787,7 @@ DeclarationName InitializedEntity::getName() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3720,6 +3821,7 @@ ValueDecl *InitializedEntity::getDecl() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3753,6 +3855,7 @@ bool InitializedEntity::allowsNRVO() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3792,6 +3895,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const {
case EK_Delegating: OS << "Delegating"; break;
case EK_ArrayElement: OS << "ArrayElement " << Index; break;
case EK_VectorElement: OS << "VectorElement " << Index; break;
+ case EK_MatrixElement:
+ OS << "MatrixElement " << Index;
+ break;
case EK_ComplexElement: OS << "ComplexElement " << Index; break;
case EK_BlockElement: OS << "Block"; break;
case EK_LambdaToBlockConversionBlockElement:
@@ -6847,6 +6953,67 @@ void InitializationSequence::InitializeFrom(Sema &S,
return;
}
+ if (S.getLangOpts().HLSL && DestType->isMatrixType() &&
+ (SourceType.isNull() ||
+ !Context.hasSameUnqualifiedType(SourceType, DestType))) {
+
+ llvm::SmallVector<Expr *> InitArgs;
+
+ for (Expr *Arg : Args) {
+ QualType AT = Arg->getType();
+
+ // Expand matrix arguments element-by-element (col-major).
+ if (AT->isMatrixType()) {
+ const auto *MTy = AT->castAs<ConstantMatrixType>();
+ unsigned Rows = MTy->getNumRows();
+ unsigned Cols = MTy->getNumColumns();
+ QualType ElemTy = MTy->getElementType();
+
+ for (unsigned c = 0; c < Cols; ++c) {
+ for (unsigned r = 0; r < Rows; ++r) {
+ // row index literal
+ Expr *RowIdx = IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), r),
+ Context.IntTy, Arg->getBeginLoc());
+ // column index literal
+ Expr *ColIdx = IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), c),
+ Context.IntTy, Arg->getBeginLoc());
+
+ InitArgs.emplace_back(new (Context) MatrixSubscriptExpr(
+ Arg, RowIdx, ColIdx, ElemTy, Arg->getEndLoc()));
+ }
+ }
+
+ // Keep your vector expansion, in case vectors appear in the argument
+ // list.
+ } else if (AT->isExtVectorType()) {
+ const auto *VTy = AT->castAs<ExtVectorType>();
+ unsigned Elm = VTy->getNumElements();
+ for (unsigned Idx = 0; Idx < Elm; ++Idx) {
+ InitArgs.emplace_back(new (Context) ArraySubscriptExpr(
+ Arg,
+ IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), Idx),
+ Context.IntTy, Arg->getBeginLoc()),
+ VTy->getElementType(), Arg->getValueKind(), Arg->getObjectKind(),
+ Arg->getEndLoc()));
+ }
+
+ } else {
+ // Scalar or other: forward as-is
+ InitArgs.emplace_back(Arg);
+ }
+ }
+
+ InitListExpr *ILE = new (Context) InitListExpr(
+ S.getASTContext(), SourceLocation(), InitArgs, SourceLocation());
+
+ Args[0] = ILE;
+ AddListInitializationStep(DestType);
+ return;
+ }
+
// The remaining cases all need a source type.
if (Args.size() > 1) {
SetFailed(FK_TooManyInitsForScalar);
@@ -6999,6 +7166,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity,
case InitializedEntity::EK_Binding:
case InitializedEntity::EK_ArrayElement:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7024,6 +7192,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) {
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_Exception:
case InitializedEntity::EK_BlockElement:
@@ -7054,6 +7223,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) {
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7107,6 +7277,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity,
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7858,9 +8029,11 @@ ExprResult InitializationSequence::Perform(Sema &S,
// HLSL allows vector initialization to function like list initialization, but
// use the syntax of a C++-like constructor.
- bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() &&
- isa<InitListExpr>(Args[0]);
- (void)IsHLSLVectorInit;
+ bool IsHLSLVectorOrMatrixInit =
+ S.getLangOpts().HLSL &&
+ (DestType->isExtVectorType() || DestType->isMatrixType()) &&
+ isa<InitListExpr>(Args[0]);
+ (void)IsHLSLVectorOrMatrixInit;
// For initialization steps that start with a single initializer,
// grab the only argument out the Args and place it into the "current"
@@ -7899,7 +8072,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
case SK_StdInitializerList:
case SK_OCLSamplerInit:
case SK_OCLZeroOpaqueType: {
- assert(Args.size() == 1 || IsHLSLVectorInit);
+ assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit);
CurInit = Args[0];
if (!CurInit.get()) return ExprError();
break;
diff --git a/clang/test/AST/HLSL/matrix-constructors.hlsl b/clang/test/AST/HLSL/matrix-constructors.hlsl
new file mode 100644
index 0000000000000..faee0162a314b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-constructors.hlsl
@@ -0,0 +1,338 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+typedef float float2x1 __attribute__((matrix_type(2,1)));
+typedef float float2x3 __attribute__((matrix_type(2,3)));
+typedef float float2x2 __attribute__((matrix_type(2,2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+[numthreads(1,1,1)]
+void ok() {
+
+ // CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:36> col:12 A 'float2x3':'matrix<float, 2, 3>' cinit
+ // CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} <col:16, col:36> 'float2x3':'matrix<float, 2, 3>' functional cast to float2x3 <NoOp>
+ // CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} <col:25, col:35> 'float2x3':'matrix<float, 2, 3>'
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:25> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:27> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:27> 'int' 2
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:29> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:29> 'int' 3
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:31> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:31> 'int' 4
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:33> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:33> ...
[truncated]
|
@llvm/pr-subscribers-hlsl Author: Farzon Lotfi (farzonl) Changesfixes #159434 In HLSL matrices are This supports the following HLSL syntax: Patch is 47.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160960.diff 6 Files Affected:
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index bd0e53d3086b0..19e4c548d0208 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
def err_variable_object_no_init : Error<
"variable-sized object may not be initialized">;
def err_excess_initializers : Error<
- "excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
+ "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
def ext_excess_initializers : ExtWarn<
- "excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
+ "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
InGroup<ExcessInitializers>;
def err_excess_initializers_for_sizeless_type : Error<
"excess elements in initializer for indivisible sizeless type %0">;
@@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
def err_second_argument_to_cwsc_not_pointer : Error<
"second argument to __builtin_call_with_static_chain must be of pointer type">;
-def err_vector_incorrect_num_elements : Error<
- "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
+def err_tensor_incorrect_num_elements : Error<
+ "%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
def err_altivec_empty_initializer : Error<"expected initializer">;
def err_vector_incorrect_bit_count : Error<
diff --git a/clang/include/clang/Sema/Initialization.h b/clang/include/clang/Sema/Initialization.h
index d7675ea153afb..865b6428f3081 100644
--- a/clang/include/clang/Sema/Initialization.h
+++ b/clang/include/clang/Sema/Initialization.h
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
/// or vector.
EK_VectorElement,
+ /// The entity being initialized is an element of a matrix.
+ /// or matrix.
+ EK_MatrixElement,
+
/// The entity being initialized is a field of block descriptor for
/// the copied-in c++ object.
EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
/// virtual base.
llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
- /// When Kind == EK_ArrayElement, EK_VectorElement, or
- /// EK_ComplexElement, the index of the array or vector element being
+ /// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement,
+ /// or EK_ComplexElement, the index of the array or vector element being
/// initialized.
unsigned Index;
@@ -536,7 +540,7 @@ class alignas(8) InitializedEntity {
/// element's index.
unsigned getElementIndex() const {
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
- getKind() == EK_ComplexElement);
+ getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
return Index;
}
@@ -544,7 +548,7 @@ class alignas(8) InitializedEntity {
/// element, sets the element index.
void setElementIndex(unsigned Index) {
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
- getKind() == EK_ComplexElement);
+ getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
this->Index = Index;
}
diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp
index e02e00231e58e..d647fbf007838 100644
--- a/clang/lib/Sema/CheckExprLifetime.cpp
+++ b/clang/lib/Sema/CheckExprLifetime.cpp
@@ -154,6 +154,7 @@ getEntityLifetime(const InitializedEntity *Entity,
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
case InitializedEntity::EK_LambdaCapture:
case InitializedEntity::EK_VectorElement:
+ case clang::InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
return {nullptr, LK_FullExpression};
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index c97129336736b..0f06d715600b3 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -17,6 +17,7 @@
#include "clang/AST/ExprCXX.h"
#include "clang/AST/ExprObjC.h"
#include "clang/AST/IgnoreExpr.h"
+#include "clang/AST/TypeBase.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/Specifiers.h"
@@ -403,6 +404,9 @@ class InitListChecker {
unsigned &Index,
InitListExpr *StructuredList,
unsigned &StructuredIndex);
+ void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList,
+ QualType DeclType, unsigned &Index,
+ InitListExpr *StructuredList, unsigned &StructuredIndex);
void CheckVectorType(const InitializedEntity &Entity,
InitListExpr *IList, QualType DeclType, unsigned &Index,
InitListExpr *StructuredList,
@@ -1003,7 +1007,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
return;
if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement ||
- ElementEntity.getKind() == InitializedEntity::EK_VectorElement)
+ ElementEntity.getKind() == InitializedEntity::EK_VectorElement ||
+ ElementEntity.getKind() == InitializedEntity::EK_MatrixElement)
ElementEntity.setElementIndex(Init);
if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks))
@@ -1273,6 +1278,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity,
switch (Entity.getKind()) {
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_ArrayElement:
case InitializedEntity::EK_Parameter:
@@ -1372,11 +1378,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity,
SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK)
<< T << IList->getInit(Index)->getSourceRange();
} else {
- int initKind = T->isArrayType() ? 0 :
- T->isVectorType() ? 1 :
- T->isScalarType() ? 2 :
- T->isUnionType() ? 3 :
- 4;
+ int initKind = T->isArrayType() ? 0
+ : T->isVectorType() ? 1
+ : T->isMatrixType() ? 2
+ : T->isScalarType() ? 3
+ : T->isUnionType() ? 4
+ : 5;
unsigned DK = ExtraInitsIsError ? diag::err_excess_initializers
: diag::ext_excess_initializers;
@@ -1430,6 +1437,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity,
} else if (DeclType->isVectorType()) {
CheckVectorType(Entity, IList, DeclType, Index,
StructuredList, StructuredIndex);
+ } else if (DeclType->isMatrixType()) {
+ CheckMatrixType(Entity, IList, DeclType, Index, StructuredList,
+ StructuredIndex);
} else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) {
auto Bases =
CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(),
@@ -1877,6 +1887,93 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity,
AggrDeductionCandidateParamTypes->push_back(DeclType);
}
+void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
+ InitListExpr *IList, QualType DeclType,
+ unsigned &Index,
+ InitListExpr *StructuredList,
+ unsigned &StructuredIndex) {
+ if (!SemaRef.getLangOpts().HLSL)
+ return;
+
+ const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>();
+ QualType ElemTy = MT->getElementType();
+ const unsigned Rows = MT->getNumRows();
+ const unsigned Cols = MT->getNumColumns();
+ const unsigned MaxElts = Rows * Cols;
+
+ unsigned NumEltsInit = 0;
+ InitializedEntity ElemEnt =
+ InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
+
+ // A Matrix initalizer should be able to take scalars, vectors, and matrices.
+ auto HandleInit = [&](InitListExpr *List, unsigned &Idx) {
+ Expr *Init = List->getInit(Idx);
+ QualType ITy = Init->getType();
+
+ if (ITy->isVectorType()) {
+ const VectorType *IVT = ITy->castAs<VectorType>();
+ unsigned N = IVT->getNumElements();
+ QualType VTy =
+ ITy->isExtVectorType()
+ ? SemaRef.Context.getExtVectorType(ElemTy, N)
+ : SemaRef.Context.getVectorType(ElemTy, N, IVT->getVectorKind());
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, VTy, Idx, StructuredList,
+ StructuredIndex);
+ NumEltsInit += N;
+ return;
+ }
+
+ if (ITy->isMatrixType()) {
+ const ConstantMatrixType *IMT = ITy->castAs<ConstantMatrixType>();
+ unsigned N = IMT->getNumRows() * IMT->getNumColumns();
+ QualType MTy = SemaRef.Context.getConstantMatrixType(
+ ElemTy, IMT->getNumRows(), IMT->getNumColumns());
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, MTy, Idx, StructuredList,
+ StructuredIndex);
+ NumEltsInit += N;
+ return;
+ }
+
+ // Scalar element
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, ElemTy, Idx, StructuredList,
+ StructuredIndex);
+ ++NumEltsInit;
+ };
+
+ // Column-major: each top-level sublist is treated as a column.
+ while (NumEltsInit < MaxElts && Index < IList->getNumInits()) {
+ Expr *Init = IList->getInit(Index);
+
+ if (auto *SubList = dyn_cast<InitListExpr>(Init)) {
+ unsigned SubIdx = 0;
+ unsigned Row = 0;
+ while (Row < Rows && SubIdx < SubList->getNumInits() &&
+ NumEltsInit < MaxElts) {
+ HandleInit(SubList, SubIdx);
+ ++Row;
+ }
+ ++Index; // advance past this column sublist
+ continue;
+ }
+
+ // Not a sublist: just consume directly.
+ HandleInit(IList, Index);
+ }
+
+ // HLSL requires exactly Rows*Cols initializers after flattening.
+ if (NumEltsInit != MaxElts) {
+ if (!VerifyOnly)
+ SemaRef.Diag(IList->getBeginLoc(),
+ diag::err_tensor_incorrect_num_elements)
+ << (NumEltsInit < MaxElts) << /*matrix*/ 1 << MaxElts << NumEltsInit
+ << /*initialization*/ 0;
+ hadError = true;
+ }
+}
+
void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
InitListExpr *IList, QualType DeclType,
unsigned &Index,
@@ -2026,9 +2123,9 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
if (numEltsInit != maxElements) {
if (!VerifyOnly)
SemaRef.Diag(IList->getBeginLoc(),
- diag::err_vector_incorrect_num_elements)
- << (numEltsInit < maxElements) << maxElements << numEltsInit
- << /*initialization*/ 0;
+ diag::err_tensor_incorrect_num_elements)
+ << (numEltsInit < maxElements) << /*vector*/ 0 << maxElements
+ << numEltsInit << /*initialization*/ 0;
hadError = true;
}
}
@@ -3639,6 +3736,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index,
} else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) {
Kind = EK_VectorElement;
Type = VT->getElementType();
+ } else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) {
+ Kind = EK_MatrixElement;
+ Type = MT->getElementType();
} else {
const ComplexType *CT = Parent.getType()->getAs<ComplexType>();
assert(CT && "Unexpected type");
@@ -3687,6 +3787,7 @@ DeclarationName InitializedEntity::getName() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3720,6 +3821,7 @@ ValueDecl *InitializedEntity::getDecl() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3753,6 +3855,7 @@ bool InitializedEntity::allowsNRVO() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3792,6 +3895,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const {
case EK_Delegating: OS << "Delegating"; break;
case EK_ArrayElement: OS << "ArrayElement " << Index; break;
case EK_VectorElement: OS << "VectorElement " << Index; break;
+ case EK_MatrixElement:
+ OS << "MatrixElement " << Index;
+ break;
case EK_ComplexElement: OS << "ComplexElement " << Index; break;
case EK_BlockElement: OS << "Block"; break;
case EK_LambdaToBlockConversionBlockElement:
@@ -6847,6 +6953,67 @@ void InitializationSequence::InitializeFrom(Sema &S,
return;
}
+ if (S.getLangOpts().HLSL && DestType->isMatrixType() &&
+ (SourceType.isNull() ||
+ !Context.hasSameUnqualifiedType(SourceType, DestType))) {
+
+ llvm::SmallVector<Expr *> InitArgs;
+
+ for (Expr *Arg : Args) {
+ QualType AT = Arg->getType();
+
+ // Expand matrix arguments element-by-element (col-major).
+ if (AT->isMatrixType()) {
+ const auto *MTy = AT->castAs<ConstantMatrixType>();
+ unsigned Rows = MTy->getNumRows();
+ unsigned Cols = MTy->getNumColumns();
+ QualType ElemTy = MTy->getElementType();
+
+ for (unsigned c = 0; c < Cols; ++c) {
+ for (unsigned r = 0; r < Rows; ++r) {
+ // row index literal
+ Expr *RowIdx = IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), r),
+ Context.IntTy, Arg->getBeginLoc());
+ // column index literal
+ Expr *ColIdx = IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), c),
+ Context.IntTy, Arg->getBeginLoc());
+
+ InitArgs.emplace_back(new (Context) MatrixSubscriptExpr(
+ Arg, RowIdx, ColIdx, ElemTy, Arg->getEndLoc()));
+ }
+ }
+
+ // Keep your vector expansion, in case vectors appear in the argument
+ // list.
+ } else if (AT->isExtVectorType()) {
+ const auto *VTy = AT->castAs<ExtVectorType>();
+ unsigned Elm = VTy->getNumElements();
+ for (unsigned Idx = 0; Idx < Elm; ++Idx) {
+ InitArgs.emplace_back(new (Context) ArraySubscriptExpr(
+ Arg,
+ IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), Idx),
+ Context.IntTy, Arg->getBeginLoc()),
+ VTy->getElementType(), Arg->getValueKind(), Arg->getObjectKind(),
+ Arg->getEndLoc()));
+ }
+
+ } else {
+ // Scalar or other: forward as-is
+ InitArgs.emplace_back(Arg);
+ }
+ }
+
+ InitListExpr *ILE = new (Context) InitListExpr(
+ S.getASTContext(), SourceLocation(), InitArgs, SourceLocation());
+
+ Args[0] = ILE;
+ AddListInitializationStep(DestType);
+ return;
+ }
+
// The remaining cases all need a source type.
if (Args.size() > 1) {
SetFailed(FK_TooManyInitsForScalar);
@@ -6999,6 +7166,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity,
case InitializedEntity::EK_Binding:
case InitializedEntity::EK_ArrayElement:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7024,6 +7192,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) {
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_Exception:
case InitializedEntity::EK_BlockElement:
@@ -7054,6 +7223,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) {
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7107,6 +7277,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity,
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7858,9 +8029,11 @@ ExprResult InitializationSequence::Perform(Sema &S,
// HLSL allows vector initialization to function like list initialization, but
// use the syntax of a C++-like constructor.
- bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() &&
- isa<InitListExpr>(Args[0]);
- (void)IsHLSLVectorInit;
+ bool IsHLSLVectorOrMatrixInit =
+ S.getLangOpts().HLSL &&
+ (DestType->isExtVectorType() || DestType->isMatrixType()) &&
+ isa<InitListExpr>(Args[0]);
+ (void)IsHLSLVectorOrMatrixInit;
// For initialization steps that start with a single initializer,
// grab the only argument out the Args and place it into the "current"
@@ -7899,7 +8072,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
case SK_StdInitializerList:
case SK_OCLSamplerInit:
case SK_OCLZeroOpaqueType: {
- assert(Args.size() == 1 || IsHLSLVectorInit);
+ assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit);
CurInit = Args[0];
if (!CurInit.get()) return ExprError();
break;
diff --git a/clang/test/AST/HLSL/matrix-constructors.hlsl b/clang/test/AST/HLSL/matrix-constructors.hlsl
new file mode 100644
index 0000000000000..faee0162a314b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-constructors.hlsl
@@ -0,0 +1,338 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+typedef float float2x1 __attribute__((matrix_type(2,1)));
+typedef float float2x3 __attribute__((matrix_type(2,3)));
+typedef float float2x2 __attribute__((matrix_type(2,2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+[numthreads(1,1,1)]
+void ok() {
+
+ // CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:36> col:12 A 'float2x3':'matrix<float, 2, 3>' cinit
+ // CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} <col:16, col:36> 'float2x3':'matrix<float, 2, 3>' functional cast to float2x3 <NoOp>
+ // CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} <col:25, col:35> 'float2x3':'matrix<float, 2, 3>'
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:25> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:27> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:27> 'int' 2
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:29> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:29> 'int' 3
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:31> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:31> 'int' 4
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:33> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:33> ...
[truncated]
|
0bf8626
to
3fbcceb
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
fixes llvm#159434 In HLSL matrices are `matrix_type` in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists. This supports the following HLSL syntax: (1) HLSL matrices support constructor syntax (2) HLSL matrices are expanded to constituate components in constructor
3fbcceb
to
526c1cd
Compare
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:34> 'int' 1 | ||
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:36> 'float' <IntegralToFloating> | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:36> 'int' 2 | ||
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:38> 'float' <IntegralToFloating> | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:38> 'int' 3 | ||
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:40> 'float' <IntegralToFloating> | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:40> 'int' 4 | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1 | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1 | ||
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:44> 'float' <IntegralToFloating> | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:44> 'int' 5 | ||
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:47> 'float' <IntegralToFloating> | ||
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:47> 'int' 6 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://godbolt.org/z/Wh1daeMYP
From the dxc godbolt seems like this needs to be 1,4,2,5,3,6
<6 x float> <float 1.000000e+00, float 4.000000e+00, float 2.000000e+00, float 5.000000e+00, float 3.000000e+00, float 6.000000e+00>
It seems like we are doing
[0,0], [1,0], [0,1], [1,1], 5, 6
1, 3, 2,4,5,6
Investigating failures revert to draft |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any public documentation for the HLSL syntax that could be linked?
InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity); | ||
|
||
// A Matrix initalizer should be able to take scalars, vectors, and matrices. | ||
auto HandleInit = [&](InitListExpr *List, unsigned &Idx) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to also support initialization syntax for regular matrix types, addressing the TODO from https://clang.llvm.org/docs/MatrixTypes.html#todos
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good idea. is there a spec for how clang matrix_type constructors are intended to work that I can compare against what HLSL does? the todos section is pretty light and doesn't indicate what is expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this approach probably needs some adjustment. HLSL allows constructor-like syntax but it behaves the same way the HLSL initializer list syntax does (note: It’s a CXX functional cast, not a constructor), so matrix initialization should be done through SemaHLSL’s initialization list handling.
Here’s an example of an odd case that really makes this special:
struct F {
float f[16];
};
export void fn() {
F f;
float4x4 M = float4x4(f);
}
https://godbolt.org/z/z9GoGvnrK
The language spec that describes aggregate initialization is available here.
Let me see what I can find. I'm basing my implementation off of what works in DXC. |
Will look into this thanks for spec reference. If I do things through SemaHLSL for matrices how do we want to support this request? Also why did we do vectors via SemaInit? Are vectors also using SemaHLSL’s initialization list handling? I'm assuming if so it came after the commit linked below |
I think we likely will want to handle C/C++ matrix initializers in a more sane way than the HLSL initializers need to work (at least for current HLSL). More on this below...
This commit predates our efforts to have spec language drafted before we make changes in Clang and is a pretty great example of why we are emphasizing testing and spec writing. Amusingly, that code actually does eventually call down to the SemaHLSL code for initialization lists, so it does actually support the edge cases it should and "do the right thing", but you today you should be able to delete this whole loop (https://github.com/llvm/llvm-project/blob/main/clang/lib/Sema/SemaInit.cpp#L6827), because everything it does is done in a more complete way during list initialization. A similar simplification is likely what you need for your PR here, which would have the effect of just passing the arguments through to the list initialization behavior for the underlying language. This is probably the right solution for both HLSL and C++ as it will keep the special grossness of HLSL to itself and allow C++ to have initializers that behave just like initializer lists. |
fixes #159434
In HLSL matrices are
matrix_type
in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists.This supports the following HLSL syntax:
(1) HLSL matrices support constructor syntax
(2) HLSL matrices are expanded to constituate components in constructor