Skip to content

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Sep 26, 2025

fixes #159434

In HLSL matrices are matrix_type in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists.

This supports the following HLSL syntax:
(1) HLSL matrices support constructor syntax
(2) HLSL matrices are expanded to constituate components in constructor

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Sep 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 26, 2025

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

Changes

fixes #159434

In HLSL matrices are matrix_type in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists.

This supports the following HLSL syntax:
(1) HLSL matrices support constructor syntax
(2) HLSL matrices are expanded to constituate components in constructor


Patch is 47.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160960.diff

6 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4-4)
  • (modified) clang/include/clang/Sema/Initialization.h (+8-4)
  • (modified) clang/lib/Sema/CheckExprLifetime.cpp (+1)
  • (modified) clang/lib/Sema/SemaInit.cpp (+186-13)
  • (added) clang/test/AST/HLSL/matrix-constructors.hlsl (+338)
  • (added) clang/test/SemaHLSL/BuiltIns/matrix-constructors-errors.hlsl (+24)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index bd0e53d3086b0..19e4c548d0208 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
 def err_variable_object_no_init : Error<
   "variable-sized object may not be initialized">;
 def err_excess_initializers : Error<
-  "excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
+  "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
 def ext_excess_initializers : ExtWarn<
-  "excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
+  "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
   InGroup<ExcessInitializers>;
 def err_excess_initializers_for_sizeless_type : Error<
   "excess elements in initializer for indivisible sizeless type %0">;
@@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
 def err_second_argument_to_cwsc_not_pointer : Error<
   "second argument to __builtin_call_with_static_chain must be of pointer type">;
 
-def err_vector_incorrect_num_elements : Error<
-  "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
+def err_tensor_incorrect_num_elements : Error<
+  "%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
 def err_altivec_empty_initializer : Error<"expected initializer">;
 
 def err_vector_incorrect_bit_count : Error<
diff --git a/clang/include/clang/Sema/Initialization.h b/clang/include/clang/Sema/Initialization.h
index d7675ea153afb..865b6428f3081 100644
--- a/clang/include/clang/Sema/Initialization.h
+++ b/clang/include/clang/Sema/Initialization.h
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
     /// or vector.
     EK_VectorElement,
 
+    /// The entity being initialized is an element of a matrix.
+    /// or matrix.
+    EK_MatrixElement,
+
     /// The entity being initialized is a field of block descriptor for
     /// the copied-in c++ object.
     EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
     /// virtual base.
     llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
 
-    /// When Kind == EK_ArrayElement, EK_VectorElement, or
-    /// EK_ComplexElement, the index of the array or vector element being
+    /// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement,
+    /// or EK_ComplexElement, the index of the array or vector element being
     /// initialized.
     unsigned Index;
 
@@ -536,7 +540,7 @@ class alignas(8) InitializedEntity {
   /// element's index.
   unsigned getElementIndex() const {
     assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
-           getKind() == EK_ComplexElement);
+           getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
     return Index;
   }
 
@@ -544,7 +548,7 @@ class alignas(8) InitializedEntity {
   /// element, sets the element index.
   void setElementIndex(unsigned Index) {
     assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
-           getKind() == EK_ComplexElement);
+           getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
     this->Index = Index;
   }
 
diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp
index e02e00231e58e..d647fbf007838 100644
--- a/clang/lib/Sema/CheckExprLifetime.cpp
+++ b/clang/lib/Sema/CheckExprLifetime.cpp
@@ -154,6 +154,7 @@ getEntityLifetime(const InitializedEntity *Entity,
   case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
   case InitializedEntity::EK_LambdaCapture:
   case InitializedEntity::EK_VectorElement:
+  case clang::InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
     return {nullptr, LK_FullExpression};
 
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index c97129336736b..0f06d715600b3 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -17,6 +17,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/IgnoreExpr.h"
+#include "clang/AST/TypeBase.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Basic/Specifiers.h"
@@ -403,6 +404,9 @@ class InitListChecker {
                           unsigned &Index,
                           InitListExpr *StructuredList,
                           unsigned &StructuredIndex);
+  void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList,
+                       QualType DeclType, unsigned &Index,
+                       InitListExpr *StructuredList, unsigned &StructuredIndex);
   void CheckVectorType(const InitializedEntity &Entity,
                        InitListExpr *IList, QualType DeclType, unsigned &Index,
                        InitListExpr *StructuredList,
@@ -1003,7 +1007,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
       return;
 
     if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement ||
-        ElementEntity.getKind() == InitializedEntity::EK_VectorElement)
+        ElementEntity.getKind() == InitializedEntity::EK_VectorElement ||
+        ElementEntity.getKind() == InitializedEntity::EK_MatrixElement)
       ElementEntity.setElementIndex(Init);
 
     if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks))
@@ -1273,6 +1278,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity,
 
   switch (Entity.getKind()) {
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_ArrayElement:
   case InitializedEntity::EK_Parameter:
@@ -1372,11 +1378,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity,
       SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK)
           << T << IList->getInit(Index)->getSourceRange();
     } else {
-      int initKind = T->isArrayType() ? 0 :
-                     T->isVectorType() ? 1 :
-                     T->isScalarType() ? 2 :
-                     T->isUnionType() ? 3 :
-                     4;
+      int initKind = T->isArrayType()    ? 0
+                     : T->isVectorType() ? 1
+                     : T->isMatrixType() ? 2
+                     : T->isScalarType() ? 3
+                     : T->isUnionType()  ? 4
+                                         : 5;
 
       unsigned DK = ExtraInitsIsError ? diag::err_excess_initializers
                                       : diag::ext_excess_initializers;
@@ -1430,6 +1437,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity,
   } else if (DeclType->isVectorType()) {
     CheckVectorType(Entity, IList, DeclType, Index,
                     StructuredList, StructuredIndex);
+  } else if (DeclType->isMatrixType()) {
+    CheckMatrixType(Entity, IList, DeclType, Index, StructuredList,
+                    StructuredIndex);
   } else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) {
     auto Bases =
         CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(),
@@ -1877,6 +1887,93 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity,
     AggrDeductionCandidateParamTypes->push_back(DeclType);
 }
 
+void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
+                                      InitListExpr *IList, QualType DeclType,
+                                      unsigned &Index,
+                                      InitListExpr *StructuredList,
+                                      unsigned &StructuredIndex) {
+  if (!SemaRef.getLangOpts().HLSL)
+    return;
+
+  const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>();
+  QualType ElemTy = MT->getElementType();
+  const unsigned Rows = MT->getNumRows();
+  const unsigned Cols = MT->getNumColumns();
+  const unsigned MaxElts = Rows * Cols;
+
+  unsigned NumEltsInit = 0;
+  InitializedEntity ElemEnt =
+      InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
+
+  // A Matrix initalizer should be able to take scalars, vectors, and matrices.
+  auto HandleInit = [&](InitListExpr *List, unsigned &Idx) {
+    Expr *Init = List->getInit(Idx);
+    QualType ITy = Init->getType();
+
+    if (ITy->isVectorType()) {
+      const VectorType *IVT = ITy->castAs<VectorType>();
+      unsigned N = IVT->getNumElements();
+      QualType VTy =
+          ITy->isExtVectorType()
+              ? SemaRef.Context.getExtVectorType(ElemTy, N)
+              : SemaRef.Context.getVectorType(ElemTy, N, IVT->getVectorKind());
+      ElemEnt.setElementIndex(Idx);
+      CheckSubElementType(ElemEnt, List, VTy, Idx, StructuredList,
+                          StructuredIndex);
+      NumEltsInit += N;
+      return;
+    }
+
+    if (ITy->isMatrixType()) {
+      const ConstantMatrixType *IMT = ITy->castAs<ConstantMatrixType>();
+      unsigned N = IMT->getNumRows() * IMT->getNumColumns();
+      QualType MTy = SemaRef.Context.getConstantMatrixType(
+          ElemTy, IMT->getNumRows(), IMT->getNumColumns());
+      ElemEnt.setElementIndex(Idx);
+      CheckSubElementType(ElemEnt, List, MTy, Idx, StructuredList,
+                          StructuredIndex);
+      NumEltsInit += N;
+      return;
+    }
+
+    // Scalar element
+    ElemEnt.setElementIndex(Idx);
+    CheckSubElementType(ElemEnt, List, ElemTy, Idx, StructuredList,
+                        StructuredIndex);
+    ++NumEltsInit;
+  };
+
+  // Column-major: each top-level sublist is treated as a column.
+  while (NumEltsInit < MaxElts && Index < IList->getNumInits()) {
+    Expr *Init = IList->getInit(Index);
+
+    if (auto *SubList = dyn_cast<InitListExpr>(Init)) {
+      unsigned SubIdx = 0;
+      unsigned Row = 0;
+      while (Row < Rows && SubIdx < SubList->getNumInits() &&
+             NumEltsInit < MaxElts) {
+        HandleInit(SubList, SubIdx);
+        ++Row;
+      }
+      ++Index; // advance past this column sublist
+      continue;
+    }
+
+    // Not a sublist: just consume directly.
+    HandleInit(IList, Index);
+  }
+
+  // HLSL requires exactly Rows*Cols initializers after flattening.
+  if (NumEltsInit != MaxElts) {
+    if (!VerifyOnly)
+      SemaRef.Diag(IList->getBeginLoc(),
+                   diag::err_tensor_incorrect_num_elements)
+          << (NumEltsInit < MaxElts) << /*matrix*/ 1 << MaxElts << NumEltsInit
+          << /*initialization*/ 0;
+    hadError = true;
+  }
+}
+
 void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
                                       InitListExpr *IList, QualType DeclType,
                                       unsigned &Index,
@@ -2026,9 +2123,9 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
   if (numEltsInit != maxElements) {
     if (!VerifyOnly)
       SemaRef.Diag(IList->getBeginLoc(),
-                   diag::err_vector_incorrect_num_elements)
-          << (numEltsInit < maxElements) << maxElements << numEltsInit
-          << /*initialization*/ 0;
+                   diag::err_tensor_incorrect_num_elements)
+          << (numEltsInit < maxElements) << /*vector*/ 0 << maxElements
+          << numEltsInit << /*initialization*/ 0;
     hadError = true;
   }
 }
@@ -3639,6 +3736,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index,
   } else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) {
     Kind = EK_VectorElement;
     Type = VT->getElementType();
+  } else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) {
+    Kind = EK_MatrixElement;
+    Type = MT->getElementType();
   } else {
     const ComplexType *CT = Parent.getType()->getAs<ComplexType>();
     assert(CT && "Unexpected type");
@@ -3687,6 +3787,7 @@ DeclarationName InitializedEntity::getName() const {
   case EK_Delegating:
   case EK_ArrayElement:
   case EK_VectorElement:
+  case EK_MatrixElement:
   case EK_ComplexElement:
   case EK_BlockElement:
   case EK_LambdaToBlockConversionBlockElement:
@@ -3720,6 +3821,7 @@ ValueDecl *InitializedEntity::getDecl() const {
   case EK_Delegating:
   case EK_ArrayElement:
   case EK_VectorElement:
+  case EK_MatrixElement:
   case EK_ComplexElement:
   case EK_BlockElement:
   case EK_LambdaToBlockConversionBlockElement:
@@ -3753,6 +3855,7 @@ bool InitializedEntity::allowsNRVO() const {
   case EK_Delegating:
   case EK_ArrayElement:
   case EK_VectorElement:
+  case EK_MatrixElement:
   case EK_ComplexElement:
   case EK_BlockElement:
   case EK_LambdaToBlockConversionBlockElement:
@@ -3792,6 +3895,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const {
   case EK_Delegating: OS << "Delegating"; break;
   case EK_ArrayElement: OS << "ArrayElement " << Index; break;
   case EK_VectorElement: OS << "VectorElement " << Index; break;
+  case EK_MatrixElement:
+    OS << "MatrixElement " << Index;
+    break;
   case EK_ComplexElement: OS << "ComplexElement " << Index; break;
   case EK_BlockElement: OS << "Block"; break;
   case EK_LambdaToBlockConversionBlockElement:
@@ -6847,6 +6953,67 @@ void InitializationSequence::InitializeFrom(Sema &S,
     return;
   }
 
+  if (S.getLangOpts().HLSL && DestType->isMatrixType() &&
+      (SourceType.isNull() ||
+       !Context.hasSameUnqualifiedType(SourceType, DestType))) {
+
+    llvm::SmallVector<Expr *> InitArgs;
+
+    for (Expr *Arg : Args) {
+      QualType AT = Arg->getType();
+
+      // Expand matrix arguments element-by-element (col-major).
+      if (AT->isMatrixType()) {
+        const auto *MTy = AT->castAs<ConstantMatrixType>();
+        unsigned Rows = MTy->getNumRows();
+        unsigned Cols = MTy->getNumColumns();
+        QualType ElemTy = MTy->getElementType();
+
+        for (unsigned c = 0; c < Cols; ++c) {
+          for (unsigned r = 0; r < Rows; ++r) {
+            // row index literal
+            Expr *RowIdx = IntegerLiteral::Create(
+                Context, llvm::APInt(Context.getIntWidth(Context.IntTy), r),
+                Context.IntTy, Arg->getBeginLoc());
+            // column index literal
+            Expr *ColIdx = IntegerLiteral::Create(
+                Context, llvm::APInt(Context.getIntWidth(Context.IntTy), c),
+                Context.IntTy, Arg->getBeginLoc());
+
+            InitArgs.emplace_back(new (Context) MatrixSubscriptExpr(
+                Arg, RowIdx, ColIdx, ElemTy, Arg->getEndLoc()));
+          }
+        }
+
+        // Keep your vector expansion, in case vectors appear in the argument
+        // list.
+      } else if (AT->isExtVectorType()) {
+        const auto *VTy = AT->castAs<ExtVectorType>();
+        unsigned Elm = VTy->getNumElements();
+        for (unsigned Idx = 0; Idx < Elm; ++Idx) {
+          InitArgs.emplace_back(new (Context) ArraySubscriptExpr(
+              Arg,
+              IntegerLiteral::Create(
+                  Context, llvm::APInt(Context.getIntWidth(Context.IntTy), Idx),
+                  Context.IntTy, Arg->getBeginLoc()),
+              VTy->getElementType(), Arg->getValueKind(), Arg->getObjectKind(),
+              Arg->getEndLoc()));
+        }
+
+      } else {
+        // Scalar or other: forward as-is
+        InitArgs.emplace_back(Arg);
+      }
+    }
+
+    InitListExpr *ILE = new (Context) InitListExpr(
+        S.getASTContext(), SourceLocation(), InitArgs, SourceLocation());
+
+    Args[0] = ILE;
+    AddListInitializationStep(DestType);
+    return;
+  }
+
   // The remaining cases all need a source type.
   if (Args.size() > 1) {
     SetFailed(FK_TooManyInitsForScalar);
@@ -6999,6 +7166,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity,
   case InitializedEntity::EK_Binding:
   case InitializedEntity::EK_ArrayElement:
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_BlockElement:
   case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7024,6 +7192,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) {
   case InitializedEntity::EK_Base:
   case InitializedEntity::EK_Delegating:
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_Exception:
   case InitializedEntity::EK_BlockElement:
@@ -7054,6 +7223,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) {
     case InitializedEntity::EK_Base:
     case InitializedEntity::EK_Delegating:
     case InitializedEntity::EK_VectorElement:
+    case InitializedEntity::EK_MatrixElement:
     case InitializedEntity::EK_ComplexElement:
     case InitializedEntity::EK_BlockElement:
     case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7107,6 +7277,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity,
   case InitializedEntity::EK_Base:
   case InitializedEntity::EK_Delegating:
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_BlockElement:
   case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7858,9 +8029,11 @@ ExprResult InitializationSequence::Perform(Sema &S,
 
   // HLSL allows vector initialization to function like list initialization, but
   // use the syntax of a C++-like constructor.
-  bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() &&
-                          isa<InitListExpr>(Args[0]);
-  (void)IsHLSLVectorInit;
+  bool IsHLSLVectorOrMatrixInit =
+      S.getLangOpts().HLSL &&
+      (DestType->isExtVectorType() || DestType->isMatrixType()) &&
+      isa<InitListExpr>(Args[0]);
+  (void)IsHLSLVectorOrMatrixInit;
 
   // For initialization steps that start with a single initializer,
   // grab the only argument out the Args and place it into the "current"
@@ -7899,7 +8072,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
   case SK_StdInitializerList:
   case SK_OCLSamplerInit:
   case SK_OCLZeroOpaqueType: {
-    assert(Args.size() == 1 || IsHLSLVectorInit);
+    assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit);
     CurInit = Args[0];
     if (!CurInit.get()) return ExprError();
     break;
diff --git a/clang/test/AST/HLSL/matrix-constructors.hlsl b/clang/test/AST/HLSL/matrix-constructors.hlsl
new file mode 100644
index 0000000000000..faee0162a314b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-constructors.hlsl
@@ -0,0 +1,338 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
+
+typedef float float2x1 __attribute__((matrix_type(2,1)));
+typedef float float2x3 __attribute__((matrix_type(2,3)));
+typedef float float2x2 __attribute__((matrix_type(2,2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+[numthreads(1,1,1)]
+void ok() {
+
+  // CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:36> col:12 A 'float2x3':'matrix<float, 2, 3>' cinit
+  // CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} <col:16, col:36> 'float2x3':'matrix<float, 2, 3>' functional cast to float2x3 <NoOp>
+  // CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} <col:25, col:35> 'float2x3':'matrix<float, 2, 3>'
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:25> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:27> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:27> 'int' 2
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:29> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:29> 'int' 3
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:31> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:31> 'int' 4
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:33> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:33> ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 26, 2025

@llvm/pr-subscribers-hlsl

Author: Farzon Lotfi (farzonl)

Changes

fixes #159434

In HLSL matrices are matrix_type in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists.

This supports the following HLSL syntax:
(1) HLSL matrices support constructor syntax
(2) HLSL matrices are expanded to constituate components in constructor


Patch is 47.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160960.diff

6 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4-4)
  • (modified) clang/include/clang/Sema/Initialization.h (+8-4)
  • (modified) clang/lib/Sema/CheckExprLifetime.cpp (+1)
  • (modified) clang/lib/Sema/SemaInit.cpp (+186-13)
  • (added) clang/test/AST/HLSL/matrix-constructors.hlsl (+338)
  • (added) clang/test/SemaHLSL/BuiltIns/matrix-constructors-errors.hlsl (+24)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index bd0e53d3086b0..19e4c548d0208 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
 def err_variable_object_no_init : Error<
   "variable-sized object may not be initialized">;
 def err_excess_initializers : Error<
-  "excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
+  "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
 def ext_excess_initializers : ExtWarn<
-  "excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
+  "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
   InGroup<ExcessInitializers>;
 def err_excess_initializers_for_sizeless_type : Error<
   "excess elements in initializer for indivisible sizeless type %0">;
@@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
 def err_second_argument_to_cwsc_not_pointer : Error<
   "second argument to __builtin_call_with_static_chain must be of pointer type">;
 
-def err_vector_incorrect_num_elements : Error<
-  "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
+def err_tensor_incorrect_num_elements : Error<
+  "%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
 def err_altivec_empty_initializer : Error<"expected initializer">;
 
 def err_vector_incorrect_bit_count : Error<
diff --git a/clang/include/clang/Sema/Initialization.h b/clang/include/clang/Sema/Initialization.h
index d7675ea153afb..865b6428f3081 100644
--- a/clang/include/clang/Sema/Initialization.h
+++ b/clang/include/clang/Sema/Initialization.h
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
     /// or vector.
     EK_VectorElement,
 
+    /// The entity being initialized is an element of a matrix.
+    /// or matrix.
+    EK_MatrixElement,
+
     /// The entity being initialized is a field of block descriptor for
     /// the copied-in c++ object.
     EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
     /// virtual base.
     llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
 
-    /// When Kind == EK_ArrayElement, EK_VectorElement, or
-    /// EK_ComplexElement, the index of the array or vector element being
+    /// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement,
+    /// or EK_ComplexElement, the index of the array or vector element being
     /// initialized.
     unsigned Index;
 
@@ -536,7 +540,7 @@ class alignas(8) InitializedEntity {
   /// element's index.
   unsigned getElementIndex() const {
     assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
-           getKind() == EK_ComplexElement);
+           getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
     return Index;
   }
 
@@ -544,7 +548,7 @@ class alignas(8) InitializedEntity {
   /// element, sets the element index.
   void setElementIndex(unsigned Index) {
     assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
-           getKind() == EK_ComplexElement);
+           getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
     this->Index = Index;
   }
 
diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp
index e02e00231e58e..d647fbf007838 100644
--- a/clang/lib/Sema/CheckExprLifetime.cpp
+++ b/clang/lib/Sema/CheckExprLifetime.cpp
@@ -154,6 +154,7 @@ getEntityLifetime(const InitializedEntity *Entity,
   case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
   case InitializedEntity::EK_LambdaCapture:
   case InitializedEntity::EK_VectorElement:
+  case clang::InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
     return {nullptr, LK_FullExpression};
 
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index c97129336736b..0f06d715600b3 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -17,6 +17,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/IgnoreExpr.h"
+#include "clang/AST/TypeBase.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Basic/Specifiers.h"
@@ -403,6 +404,9 @@ class InitListChecker {
                           unsigned &Index,
                           InitListExpr *StructuredList,
                           unsigned &StructuredIndex);
+  void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList,
+                       QualType DeclType, unsigned &Index,
+                       InitListExpr *StructuredList, unsigned &StructuredIndex);
   void CheckVectorType(const InitializedEntity &Entity,
                        InitListExpr *IList, QualType DeclType, unsigned &Index,
                        InitListExpr *StructuredList,
@@ -1003,7 +1007,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
       return;
 
     if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement ||
-        ElementEntity.getKind() == InitializedEntity::EK_VectorElement)
+        ElementEntity.getKind() == InitializedEntity::EK_VectorElement ||
+        ElementEntity.getKind() == InitializedEntity::EK_MatrixElement)
       ElementEntity.setElementIndex(Init);
 
     if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks))
@@ -1273,6 +1278,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity,
 
   switch (Entity.getKind()) {
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_ArrayElement:
   case InitializedEntity::EK_Parameter:
@@ -1372,11 +1378,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity,
       SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK)
           << T << IList->getInit(Index)->getSourceRange();
     } else {
-      int initKind = T->isArrayType() ? 0 :
-                     T->isVectorType() ? 1 :
-                     T->isScalarType() ? 2 :
-                     T->isUnionType() ? 3 :
-                     4;
+      int initKind = T->isArrayType()    ? 0
+                     : T->isVectorType() ? 1
+                     : T->isMatrixType() ? 2
+                     : T->isScalarType() ? 3
+                     : T->isUnionType()  ? 4
+                                         : 5;
 
       unsigned DK = ExtraInitsIsError ? diag::err_excess_initializers
                                       : diag::ext_excess_initializers;
@@ -1430,6 +1437,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity,
   } else if (DeclType->isVectorType()) {
     CheckVectorType(Entity, IList, DeclType, Index,
                     StructuredList, StructuredIndex);
+  } else if (DeclType->isMatrixType()) {
+    CheckMatrixType(Entity, IList, DeclType, Index, StructuredList,
+                    StructuredIndex);
   } else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) {
     auto Bases =
         CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(),
@@ -1877,6 +1887,93 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity,
     AggrDeductionCandidateParamTypes->push_back(DeclType);
 }
 
+void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
+                                      InitListExpr *IList, QualType DeclType,
+                                      unsigned &Index,
+                                      InitListExpr *StructuredList,
+                                      unsigned &StructuredIndex) {
+  if (!SemaRef.getLangOpts().HLSL)
+    return;
+
+  const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>();
+  QualType ElemTy = MT->getElementType();
+  const unsigned Rows = MT->getNumRows();
+  const unsigned Cols = MT->getNumColumns();
+  const unsigned MaxElts = Rows * Cols;
+
+  unsigned NumEltsInit = 0;
+  InitializedEntity ElemEnt =
+      InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
+
+  // A Matrix initalizer should be able to take scalars, vectors, and matrices.
+  auto HandleInit = [&](InitListExpr *List, unsigned &Idx) {
+    Expr *Init = List->getInit(Idx);
+    QualType ITy = Init->getType();
+
+    if (ITy->isVectorType()) {
+      const VectorType *IVT = ITy->castAs<VectorType>();
+      unsigned N = IVT->getNumElements();
+      QualType VTy =
+          ITy->isExtVectorType()
+              ? SemaRef.Context.getExtVectorType(ElemTy, N)
+              : SemaRef.Context.getVectorType(ElemTy, N, IVT->getVectorKind());
+      ElemEnt.setElementIndex(Idx);
+      CheckSubElementType(ElemEnt, List, VTy, Idx, StructuredList,
+                          StructuredIndex);
+      NumEltsInit += N;
+      return;
+    }
+
+    if (ITy->isMatrixType()) {
+      const ConstantMatrixType *IMT = ITy->castAs<ConstantMatrixType>();
+      unsigned N = IMT->getNumRows() * IMT->getNumColumns();
+      QualType MTy = SemaRef.Context.getConstantMatrixType(
+          ElemTy, IMT->getNumRows(), IMT->getNumColumns());
+      ElemEnt.setElementIndex(Idx);
+      CheckSubElementType(ElemEnt, List, MTy, Idx, StructuredList,
+                          StructuredIndex);
+      NumEltsInit += N;
+      return;
+    }
+
+    // Scalar element
+    ElemEnt.setElementIndex(Idx);
+    CheckSubElementType(ElemEnt, List, ElemTy, Idx, StructuredList,
+                        StructuredIndex);
+    ++NumEltsInit;
+  };
+
+  // Column-major: each top-level sublist is treated as a column.
+  while (NumEltsInit < MaxElts && Index < IList->getNumInits()) {
+    Expr *Init = IList->getInit(Index);
+
+    if (auto *SubList = dyn_cast<InitListExpr>(Init)) {
+      unsigned SubIdx = 0;
+      unsigned Row = 0;
+      while (Row < Rows && SubIdx < SubList->getNumInits() &&
+             NumEltsInit < MaxElts) {
+        HandleInit(SubList, SubIdx);
+        ++Row;
+      }
+      ++Index; // advance past this column sublist
+      continue;
+    }
+
+    // Not a sublist: just consume directly.
+    HandleInit(IList, Index);
+  }
+
+  // HLSL requires exactly Rows*Cols initializers after flattening.
+  if (NumEltsInit != MaxElts) {
+    if (!VerifyOnly)
+      SemaRef.Diag(IList->getBeginLoc(),
+                   diag::err_tensor_incorrect_num_elements)
+          << (NumEltsInit < MaxElts) << /*matrix*/ 1 << MaxElts << NumEltsInit
+          << /*initialization*/ 0;
+    hadError = true;
+  }
+}
+
 void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
                                       InitListExpr *IList, QualType DeclType,
                                       unsigned &Index,
@@ -2026,9 +2123,9 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
   if (numEltsInit != maxElements) {
     if (!VerifyOnly)
       SemaRef.Diag(IList->getBeginLoc(),
-                   diag::err_vector_incorrect_num_elements)
-          << (numEltsInit < maxElements) << maxElements << numEltsInit
-          << /*initialization*/ 0;
+                   diag::err_tensor_incorrect_num_elements)
+          << (numEltsInit < maxElements) << /*vector*/ 0 << maxElements
+          << numEltsInit << /*initialization*/ 0;
     hadError = true;
   }
 }
@@ -3639,6 +3736,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index,
   } else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) {
     Kind = EK_VectorElement;
     Type = VT->getElementType();
+  } else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) {
+    Kind = EK_MatrixElement;
+    Type = MT->getElementType();
   } else {
     const ComplexType *CT = Parent.getType()->getAs<ComplexType>();
     assert(CT && "Unexpected type");
@@ -3687,6 +3787,7 @@ DeclarationName InitializedEntity::getName() const {
   case EK_Delegating:
   case EK_ArrayElement:
   case EK_VectorElement:
+  case EK_MatrixElement:
   case EK_ComplexElement:
   case EK_BlockElement:
   case EK_LambdaToBlockConversionBlockElement:
@@ -3720,6 +3821,7 @@ ValueDecl *InitializedEntity::getDecl() const {
   case EK_Delegating:
   case EK_ArrayElement:
   case EK_VectorElement:
+  case EK_MatrixElement:
   case EK_ComplexElement:
   case EK_BlockElement:
   case EK_LambdaToBlockConversionBlockElement:
@@ -3753,6 +3855,7 @@ bool InitializedEntity::allowsNRVO() const {
   case EK_Delegating:
   case EK_ArrayElement:
   case EK_VectorElement:
+  case EK_MatrixElement:
   case EK_ComplexElement:
   case EK_BlockElement:
   case EK_LambdaToBlockConversionBlockElement:
@@ -3792,6 +3895,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const {
   case EK_Delegating: OS << "Delegating"; break;
   case EK_ArrayElement: OS << "ArrayElement " << Index; break;
   case EK_VectorElement: OS << "VectorElement " << Index; break;
+  case EK_MatrixElement:
+    OS << "MatrixElement " << Index;
+    break;
   case EK_ComplexElement: OS << "ComplexElement " << Index; break;
   case EK_BlockElement: OS << "Block"; break;
   case EK_LambdaToBlockConversionBlockElement:
@@ -6847,6 +6953,67 @@ void InitializationSequence::InitializeFrom(Sema &S,
     return;
   }
 
+  if (S.getLangOpts().HLSL && DestType->isMatrixType() &&
+      (SourceType.isNull() ||
+       !Context.hasSameUnqualifiedType(SourceType, DestType))) {
+
+    llvm::SmallVector<Expr *> InitArgs;
+
+    for (Expr *Arg : Args) {
+      QualType AT = Arg->getType();
+
+      // Expand matrix arguments element-by-element (col-major).
+      if (AT->isMatrixType()) {
+        const auto *MTy = AT->castAs<ConstantMatrixType>();
+        unsigned Rows = MTy->getNumRows();
+        unsigned Cols = MTy->getNumColumns();
+        QualType ElemTy = MTy->getElementType();
+
+        for (unsigned c = 0; c < Cols; ++c) {
+          for (unsigned r = 0; r < Rows; ++r) {
+            // row index literal
+            Expr *RowIdx = IntegerLiteral::Create(
+                Context, llvm::APInt(Context.getIntWidth(Context.IntTy), r),
+                Context.IntTy, Arg->getBeginLoc());
+            // column index literal
+            Expr *ColIdx = IntegerLiteral::Create(
+                Context, llvm::APInt(Context.getIntWidth(Context.IntTy), c),
+                Context.IntTy, Arg->getBeginLoc());
+
+            InitArgs.emplace_back(new (Context) MatrixSubscriptExpr(
+                Arg, RowIdx, ColIdx, ElemTy, Arg->getEndLoc()));
+          }
+        }
+
+        // Keep your vector expansion, in case vectors appear in the argument
+        // list.
+      } else if (AT->isExtVectorType()) {
+        const auto *VTy = AT->castAs<ExtVectorType>();
+        unsigned Elm = VTy->getNumElements();
+        for (unsigned Idx = 0; Idx < Elm; ++Idx) {
+          InitArgs.emplace_back(new (Context) ArraySubscriptExpr(
+              Arg,
+              IntegerLiteral::Create(
+                  Context, llvm::APInt(Context.getIntWidth(Context.IntTy), Idx),
+                  Context.IntTy, Arg->getBeginLoc()),
+              VTy->getElementType(), Arg->getValueKind(), Arg->getObjectKind(),
+              Arg->getEndLoc()));
+        }
+
+      } else {
+        // Scalar or other: forward as-is
+        InitArgs.emplace_back(Arg);
+      }
+    }
+
+    InitListExpr *ILE = new (Context) InitListExpr(
+        S.getASTContext(), SourceLocation(), InitArgs, SourceLocation());
+
+    Args[0] = ILE;
+    AddListInitializationStep(DestType);
+    return;
+  }
+
   // The remaining cases all need a source type.
   if (Args.size() > 1) {
     SetFailed(FK_TooManyInitsForScalar);
@@ -6999,6 +7166,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity,
   case InitializedEntity::EK_Binding:
   case InitializedEntity::EK_ArrayElement:
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_BlockElement:
   case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7024,6 +7192,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) {
   case InitializedEntity::EK_Base:
   case InitializedEntity::EK_Delegating:
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_Exception:
   case InitializedEntity::EK_BlockElement:
@@ -7054,6 +7223,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) {
     case InitializedEntity::EK_Base:
     case InitializedEntity::EK_Delegating:
     case InitializedEntity::EK_VectorElement:
+    case InitializedEntity::EK_MatrixElement:
     case InitializedEntity::EK_ComplexElement:
     case InitializedEntity::EK_BlockElement:
     case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7107,6 +7277,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity,
   case InitializedEntity::EK_Base:
   case InitializedEntity::EK_Delegating:
   case InitializedEntity::EK_VectorElement:
+  case InitializedEntity::EK_MatrixElement:
   case InitializedEntity::EK_ComplexElement:
   case InitializedEntity::EK_BlockElement:
   case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7858,9 +8029,11 @@ ExprResult InitializationSequence::Perform(Sema &S,
 
   // HLSL allows vector initialization to function like list initialization, but
   // use the syntax of a C++-like constructor.
-  bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() &&
-                          isa<InitListExpr>(Args[0]);
-  (void)IsHLSLVectorInit;
+  bool IsHLSLVectorOrMatrixInit =
+      S.getLangOpts().HLSL &&
+      (DestType->isExtVectorType() || DestType->isMatrixType()) &&
+      isa<InitListExpr>(Args[0]);
+  (void)IsHLSLVectorOrMatrixInit;
 
   // For initialization steps that start with a single initializer,
   // grab the only argument out the Args and place it into the "current"
@@ -7899,7 +8072,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
   case SK_StdInitializerList:
   case SK_OCLSamplerInit:
   case SK_OCLZeroOpaqueType: {
-    assert(Args.size() == 1 || IsHLSLVectorInit);
+    assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit);
     CurInit = Args[0];
     if (!CurInit.get()) return ExprError();
     break;
diff --git a/clang/test/AST/HLSL/matrix-constructors.hlsl b/clang/test/AST/HLSL/matrix-constructors.hlsl
new file mode 100644
index 0000000000000..faee0162a314b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-constructors.hlsl
@@ -0,0 +1,338 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
+
+typedef float float2x1 __attribute__((matrix_type(2,1)));
+typedef float float2x3 __attribute__((matrix_type(2,3)));
+typedef float float2x2 __attribute__((matrix_type(2,2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+[numthreads(1,1,1)]
+void ok() {
+
+  // CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:36> col:12 A 'float2x3':'matrix<float, 2, 3>' cinit
+  // CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} <col:16, col:36> 'float2x3':'matrix<float, 2, 3>' functional cast to float2x3 <NoOp>
+  // CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} <col:25, col:35> 'float2x3':'matrix<float, 2, 3>'
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:25> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:27> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:27> 'int' 2
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:29> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:29> 'int' 3
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:31> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:31> 'int' 4
+  // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:33> 'float' <IntegralToFloating>
+  // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:33> ...
[truncated]

@farzonl farzonl marked this pull request as draft September 26, 2025 22:47
@farzonl farzonl force-pushed the feature/matrix_constructor_issue-159434 branch 2 times, most recently from 0bf8626 to 3fbcceb Compare September 26, 2025 22:57
Copy link

github-actions bot commented Sep 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

fixes llvm#159434

In HLSL matrices are `matrix_type` in all respects except that they
support a constructor style syntax for initializing matrices. This
change adds a translation of vector constructor arguments into
initializer lists.

This supports the following HLSL syntax:
(1) HLSL matrices support constructor syntax
(2) HLSL matrices are expanded to constituate components in constructor
@farzonl farzonl force-pushed the feature/matrix_constructor_issue-159434 branch from 3fbcceb to 526c1cd Compare September 26, 2025 23:05
@farzonl farzonl self-assigned this Sep 26, 2025
@farzonl farzonl marked this pull request as ready for review September 26, 2025 23:10
Comment on lines +218 to +230
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:34> 'int' 1
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:36> 'float' <IntegralToFloating>
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:36> 'int' 2
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:38> 'float' <IntegralToFloating>
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:38> 'int' 3
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:40> 'float' <IntegralToFloating>
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:40> 'int' 4
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:44> 'float' <IntegralToFloating>
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:44> 'int' 5
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:47> 'float' <IntegralToFloating>
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:47> 'int' 6
Copy link
Member Author

@farzonl farzonl Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://godbolt.org/z/Wh1daeMYP
From the dxc godbolt seems like this needs to be 1,4,2,5,3,6

<6 x float> <float 1.000000e+00, float 4.000000e+00, float 2.000000e+00, float 5.000000e+00, float 3.000000e+00, float 6.000000e+00>

It seems like we are doing
[0,0], [1,0], [0,1], [1,1], 5, 6
1, 3, 2,4,5,6

@farzonl farzonl marked this pull request as draft September 26, 2025 23:27
@farzonl
Copy link
Member Author

farzonl commented Sep 26, 2025

Investigating failures revert to draft
Failed Tests (2):
Clang :: SemaCXX/err_init_conversion_failed.cpp
Clang :: SemaCXX/paren-list-agg-init.cpp

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any public documentation for the HLSL syntax that could be linked?

InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);

// A Matrix initalizer should be able to take scalars, vectors, and matrices.
auto HandleInit = [&](InitListExpr *List, unsigned &Idx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to also support initialization syntax for regular matrix types, addressing the TODO from https://clang.llvm.org/docs/MatrixTypes.html#todos

Copy link
Member Author

@farzonl farzonl Sep 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. is there a spec for how clang matrix_type constructors are intended to work that I can compare against what HLSL does? the todos section is pretty light and doesn't indicate what is expected.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this approach probably needs some adjustment. HLSL allows constructor-like syntax but it behaves the same way the HLSL initializer list syntax does (note: It’s a CXX functional cast, not a constructor), so matrix initialization should be done through SemaHLSL’s initialization list handling.

Here’s an example of an odd case that really makes this special:


struct F {
    float f[16];
};

export void fn() {
    F f;
    float4x4 M = float4x4(f);
}

https://godbolt.org/z/z9GoGvnrK

The language spec that describes aggregate initialization is available here.

@farzonl
Copy link
Member Author

farzonl commented Sep 27, 2025

Is there any public documentation for the HLSL syntax that could be linked?

Let me see what I can find. I'm basing my implementation off of what works in DXC.

@farzonl
Copy link
Member Author

farzonl commented Sep 27, 2025

I think this approach probably needs some adjustment. HLSL allows constructor-like syntax but it behaves the same way the HLSL initializer list syntax does, so matrix initialization should be done through SemaHLSL’s initialization list handling.

Will look into this thanks for spec reference. If I do things through SemaHLSL for matrices how do we want to support this request?
#160960 (comment)

Also why did we do vectors via SemaInit? Are vectors also using SemaHLSL’s initialization list handling? I'm assuming if so it came after the commit linked below

9f499d9

@llvm-beanz
Copy link
Collaborator

I think this approach probably needs some adjustment. HLSL allows constructor-like syntax but it behaves the same way the HLSL initializer list syntax does, so matrix initialization should be done through SemaHLSL’s initialization list handling.

Will look into this thanks for spec reference. If I do things through SemaHLSL for matrices how do we want to support this request? #160960 (comment)

I think we likely will want to handle C/C++ matrix initializers in a more sane way than the HLSL initializers need to work (at least for current HLSL). More on this below...

Also why did we do vectors via SemaInit? Are vectors also using SemaHLSL’s initialization list handling? I'm assuming if so it came after the commit linked below

9f499d9

This commit predates our efforts to have spec language drafted before we make changes in Clang and is a pretty great example of why we are emphasizing testing and spec writing.

Amusingly, that code actually does eventually call down to the SemaHLSL code for initialization lists, so it does actually support the edge cases it should and "do the right thing", but you today you should be able to delete this whole loop (https://github.com/llvm/llvm-project/blob/main/clang/lib/Sema/SemaInit.cpp#L6827), because everything it does is done in a more complete way during list initialization.

A similar simplification is likely what you need for your PR here, which would have the effect of just passing the arguments through to the list initialization behavior for the underlying language. This is probably the right solution for both HLSL and C++ as it will keep the special grossness of HLSL to itself and allow C++ to have initializers that behave just like initializer lists.

@farzonl farzonl closed this Oct 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL] Support HLSL matrix initializers

4 participants