Skip to content

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Sep 17, 2025

fixes #109839

This change is really simple. It creates a matrix alias that will let HLSL use the existing clang matrix_type infra.

The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL.

Testing therefore is limited to exercising the alias, sema errors, and basic codegen.
future work will add things like constructors and accessors.

The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing matrix_type testing like cast behavior.

fixes llvm#109839

This change is really simple. It creates a matrix alias
that will let HLSL  use the existing clang `matrix_type` infra.

The only additional change was to add explict alias for the
typed  dimensions of 1-4 inclusive matricies available in HLSL.

Testing therefore is limited to exercising the alias.

The main difference in this attempt is the type printer.
@farzonl farzonl requested a review from fhahn September 17, 2025 20:21
@farzonl farzonl self-assigned this Sep 17, 2025
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics HLSL HLSL Language Support labels Sep 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

Changes

fixes #109839

This change is really simple. It creates a matrix alias that will let HLSL use the existing clang matrix_type infra.

The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL.

Testing therefore is limited to exercising the alias, sema errors, and basic codegen.
future work will add things like constructors and accessors.

The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing matrix_type testing like cast behavior.


Patch is 24.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159446.diff

10 Files Affected:

  • (modified) clang/include/clang/Driver/Options.td (+1-1)
  • (modified) clang/include/clang/Sema/HLSLExternalSemaSource.h (+1)
  • (modified) clang/lib/AST/TypePrinter.cpp (+31-6)
  • (modified) clang/lib/Headers/hlsl/hlsl_basic_types.h (+233)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+72)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (-1)
  • (added) clang/test/AST/HLSL/matrix-alias.hlsl (+49)
  • (added) clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl (+30)
  • (added) clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl (+5)
  • (added) clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl (+29)
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index a7c514e809aa9..e8869e04390f3 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
 def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
     Visibility<[ClangOption, CC1Option]>,
     HelpText<"Enable matrix data type and related builtin functions">,
-    MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
+    MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
 
 defm raw_string_literals : BoolFOption<"raw-string-literals",
     LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index d93fb8c8eef6b..049fc7b8fe3f2 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
 private:
   void defineTrivialHLSLTypes();
   void defineHLSLVectorAlias();
+  void defineHLSLMatrixAlias();
   void defineHLSLTypesWithForwardDeclarations();
   void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
 };
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cd59678d67f2f..82859c4015772 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -846,16 +846,41 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
   }
 }
 
-void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
-                                            raw_ostream &OS) {
-  printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
+static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
   OS << T->getNumRows() << ", " << T->getNumColumns();
+}
+
+static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
+  OS << "matrix<";
+  TP.printBefore(T->getElementType(), OS);
+}
+
+static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+  OS << ", ";
+  printDims(T, OS);
+  OS << ">";
+}
+
+static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
+  TP.printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  printDims(T, OS);
   OS << ")))";
 }
 
-void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
-                                           raw_ostream &OS) {
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, raw_ostream &OS) {
+  if (Policy.UseHLSLTypes) {
+    printHLSLMatrixBefore(*this, T, OS);
+    return;
+  }
+  printClangMatrixBefore(*this, T, OS);
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+  if (Policy.UseHLSLTypes) {
+    printHLSLMatrixAfter(T, OS);
+    return;
+  }
   printAfter(T->getElementType(), OS);
 }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f950..c750261952d4f 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
 typedef vector<float64_t, 3> float64_t3;
 typedef vector<float64_t, 4> float64_t4;
 
+ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
 } // namespace hlsl
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 3386d8da281e9..9ac60712b87b8 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -121,8 +121,80 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
   HLSLNamespace->addDecl(Template);
 }
 
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+ llvm::SmallVector<NamedDecl *> TemplateParams;
+
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(
+               TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+  TemplateParams.emplace_back(TypeParam);
+
+  // these should be 64 bit to be consistent with other clang matrices.
+  auto *RowsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+                            /*IsDefaulted=*/true);
+  RowsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+                                                  SourceLocation(), RowsParam));
+  TemplateParams.emplace_back(RowsParam);
+
+  auto *ColsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+      &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+                            /*IsDefaulted=*/true);
+  ColsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+                                                  SourceLocation(), ColsParam));
+  TemplateParams.emplace_back(ColsParam);
+
+  auto *ParamList =
+      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+                                    TemplateParams, SourceLocation(), nullptr);
+
+  IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+  QualType AliasType = AST.getDependentSizedMatrixType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+          DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+          DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, Record);
+
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}
+
 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
   defineHLSLVectorAlias();
+  defineHLSLMatrixAlias();
 }
 
 /// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 0af38472b0fec..38f81a24945c5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3285,7 +3285,6 @@ static void BuildFlattenedTypeList(QualType BaseTy,
   while (!WorkList.empty()) {
     QualType T = WorkList.pop_back_val();
     T = T.getCanonicalType().getUnqualifiedType();
-    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
     if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
       llvm::SmallVector<QualType, 16> ElementFields;
       // Generally I've avoided recursion in this algorithm, but arrays of
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 0000000000000..2758b6f0d202f
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::matrix<float, 2, 2> Mat2x2f;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:36>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2f 'hlsl::matrix<float, 2, 2>'
+
+  // Verify that you don't need to specify the namespace.
+  matrix<int, 2, 2> Mat2x2i;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Mat2x2i 'matrix<int, 2, 2>'
+
+  // Build a bigger matrix.
+  matrix<double, 4, 4> Mat4x4d;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:31>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4d 'matrix<double, 4, 4>'
+
+  // Verify that the implicit arguments generate the correct type.
+  matrix<> ImpMat4x4;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:44:3, col:21>
+  // CHECK-NEXT: ...
[truncated]

Copy link

github-actions bot commented Sep 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@farzonl farzonl force-pushed the feature/hlsl_matrix-issue-109839 branch from 5081bc3 to 9e18ebb Compare September 18, 2025 14:36
Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me at a high level, great to see more users of the matrix intrinsics!

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing we didn't do with the vector alias that we should do for the matrix one is add a concept requirement that the dimensions be in the range [1,4].

We removed the restriction on vector length in HLSL, but I don't believe we have any plans to do the same for matrices.

LANGOPT(HLSLStrictAvailability, 1, 0, NotCompatible,
"Strict availability diagnostic mode for HLSL built-in functions.")
LANGOPT(HLSLSpvUseUnknownImageFormat, 1, 0, NotCompatible, "For storage images and texel buffers, sets the default format to 'Unknown' when not specified via the `vk::image_format` attribute. If this option is not used, the format is inferred from the resource's data type.")
VALUE_LANGOPT(MaxMatrixDimension, 32, (1 << 20) - 1, NotCompatible, "maximum allowed matrix dimension")
Copy link
Member Author

@farzonl farzonl Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal should likely be to replace the constexpr MaxElementsPerDimension with this language option.

class ConstantMatrixType final : public MatrixType {
protected:
friend class ASTContext;
/// Number of rows and columns.
unsigned NumRows;
unsigned NumColumns;
static constexpr unsigned MaxElementsPerDimension = (1 << 20) - 1;

But that would require updating all of these static constexpr functions with something that takes a SemaPtr

/// Returns true if \p NumElements is a valid matrix dimension.
static constexpr bool isDimensionValid(size_t NumElements) {
return NumElements > 0 && NumElements <= MaxElementsPerDimension;
}
/// Returns the maximum number of elements per dimension.
static constexpr unsigned getMaxElementsPerDimension() {
return MaxElementsPerDimension;
}

For now we will make this an HLSL only language opt and create an issue to make it more generic:

Suggested change
VALUE_LANGOPT(MaxMatrixDimension, 32, (1 << 20) - 1, NotCompatible, "maximum allowed matrix dimension")
VALUE_LANGOPT(HLSLMaxMatrixDimension, 32, 4, NotCompatible, "maximum allowed matrix dimension")

issue #160190


auto *ParamList = TemplateParameterList::Create(
AST, SourceLocation(), SourceLocation(), TemplateParams, SourceLocation(),
RequiresExpr);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@llvm-beanz I had two questions on this commit for you.

  1. First should we be adding the size limits on the alias like I am doing above or should we be doing it on the underlying matrix_type?
  2. Second it was pretty easy to enforce the size limits without a concept. Are you ok with The RequiresExpr that TemplateParameterList takes as our means of enforcing a size limit?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. First should we be adding the size limits on the alias like I am doing above or should we be doing it on the underlying matrix_type?

I think this probably boils down to the quality of the error messages, and how we would handle larger matrices in the backend. Since we probably don't want to handle larger matrices we should probably restrict matrix_type. If we get better diagnostics from the requirement being on the template, we may want to do that also.

  1. Second it was pretty easy to enforce the size limits without a concept. Are you ok with The RequiresExpr that TemplateParameterList takes as our means of enforcing a size limit?

Requirement expressions are part of C++ concepts, that was my intention for how this would be done. I don't think we need a separately defined ConceptDecl.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'n that case I'm going to take enforcement of size on matrix_type as a change we will do on #160190 since it sounds like the diagnostics would be better to do this on the alias aswell as the underlying type.

@farzonl farzonl force-pushed the feature/hlsl_matrix-issue-109839 branch from dcb1a9e to d81fd36 Compare September 22, 2025 22:41
LANGOPT(HLSLStrictAvailability, 1, 0, NotCompatible,
"Strict availability diagnostic mode for HLSL built-in functions.")
LANGOPT(HLSLSpvUseUnknownImageFormat, 1, 0, NotCompatible, "For storage images and texel buffers, sets the default format to 'Unknown' when not specified via the `vk::image_format` attribute. If this option is not used, the format is inferred from the resource's data type.")
VALUE_LANGOPT(HLSLMaxMatrixDimension, 32, 4, NotCompatible, "maximum allowed matrix dimension")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a language option? Since it is just for HLSL it will always be 4.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to expand it to a lang opt for all matrix types in #160190.
see the diff d81fd36

Doing it in this pr was adding a bunch of changes I felt distracted from what I was trying to accomplish. so It will be in its own pr.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, but for this PR it could just be constant 4 in HLSLExternalSemaSource. Then we can discuss the language option separately.

@farzonl farzonl merged commit 6ac40aa into llvm:main Sep 23, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL] Enable clang extension matrices in HLSL
5 participants