-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[HLSL] Add support for the HLSL matrix type #159446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
fixes llvm#109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang `matrix_type` infra. The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL. Testing therefore is limited to exercising the alias. The main difference in this attempt is the type printer.
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) Changesfixes #109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL. Testing therefore is limited to exercising the alias, sema errors, and basic codegen. The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing Patch is 24.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159446.diff 10 Files Affected:
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index a7c514e809aa9..e8869e04390f3 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
Visibility<[ClangOption, CC1Option]>,
HelpText<"Enable matrix data type and related builtin functions">,
- MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
+ MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
defm raw_string_literals : BoolFOption<"raw-string-literals",
LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index d93fb8c8eef6b..049fc7b8fe3f2 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
private:
void defineTrivialHLSLTypes();
void defineHLSLVectorAlias();
+ void defineHLSLMatrixAlias();
void defineHLSLTypesWithForwardDeclarations();
void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
};
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cd59678d67f2f..82859c4015772 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -846,16 +846,41 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
}
}
-void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
- raw_ostream &OS) {
- printBefore(T->getElementType(), OS);
- OS << " __attribute__((matrix_type(";
+static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
OS << T->getNumRows() << ", " << T->getNumColumns();
+}
+
+static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
+ OS << "matrix<";
+ TP.printBefore(T->getElementType(), OS);
+}
+
+static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+ OS << ", ";
+ printDims(T, OS);
+ OS << ">";
+}
+
+static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
+ TP.printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ printDims(T, OS);
OS << ")))";
}
-void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
- raw_ostream &OS) {
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, raw_ostream &OS) {
+ if (Policy.UseHLSLTypes) {
+ printHLSLMatrixBefore(*this, T, OS);
+ return;
+ }
+ printClangMatrixBefore(*this, T, OS);
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+ if (Policy.UseHLSLTypes) {
+ printHLSLMatrixAfter(T, OS);
+ return;
+ }
printAfter(T->getElementType(), OS);
}
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f950..c750261952d4f 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
typedef vector<float64_t, 3> float64_t3;
typedef vector<float64_t, 4> float64_t4;
+ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
} // namespace hlsl
#endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 3386d8da281e9..9ac60712b87b8 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -121,8 +121,80 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
HLSLNamespace->addDecl(Template);
}
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+ ASTContext &AST = SemaPtr->getASTContext();
+ llvm::SmallVector<NamedDecl *> TemplateParams;
+
+ auto *TypeParam = TemplateTypeParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+ &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+ TypeParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(
+ TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+ TemplateParams.emplace_back(TypeParam);
+
+ // these should be 64 bit to be consistent with other clang matrices.
+ auto *RowsParam = NonTypeTemplateParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+ &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+ false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+ llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+ TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+ /*IsDefaulted=*/true);
+ RowsParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+ SourceLocation(), RowsParam));
+ TemplateParams.emplace_back(RowsParam);
+
+ auto *ColsParam = NonTypeTemplateParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+ &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+ false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+ llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+ TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+ /*IsDefaulted=*/true);
+ ColsParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+ SourceLocation(), ColsParam));
+ TemplateParams.emplace_back(ColsParam);
+
+ auto *ParamList =
+ TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+ TemplateParams, SourceLocation(), nullptr);
+
+ IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+ QualType AliasType = AST.getDependentSizedMatrixType(
+ AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+ DeclRefExpr::Create(
+ AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+ DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+ AST.IntTy, VK_LValue),
+ DeclRefExpr::Create(
+ AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+ DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+ AST.IntTy, VK_LValue),
+ SourceLocation());
+
+ auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+ SourceLocation(), &II,
+ AST.getTrivialTypeSourceInfo(AliasType));
+ Record->setImplicit(true);
+
+ auto *Template =
+ TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+ Record->getIdentifier(), ParamList, Record);
+
+ Record->setDescribedAliasTemplate(Template);
+ Template->setImplicit(true);
+ Template->setLexicalDeclContext(Record->getDeclContext());
+ HLSLNamespace->addDecl(Template);
+}
+
void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
defineHLSLVectorAlias();
+ defineHLSLMatrixAlias();
}
/// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 0af38472b0fec..38f81a24945c5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3285,7 +3285,6 @@ static void BuildFlattenedTypeList(QualType BaseTy,
while (!WorkList.empty()) {
QualType T = WorkList.pop_back_val();
T = T.getCanonicalType().getUnqualifiedType();
- assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
llvm::SmallVector<QualType, 16> ElementFields;
// Generally I've avoided recursion in this algorithm, but arrays of
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 0000000000000..2758b6f0d202f
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+ // Verify that the alias is generated inside the hlsl namespace.
+ hlsl::matrix<float, 2, 2> Mat2x2f;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:36>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2f 'hlsl::matrix<float, 2, 2>'
+
+ // Verify that you don't need to specify the namespace.
+ matrix<int, 2, 2> Mat2x2i;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Mat2x2i 'matrix<int, 2, 2>'
+
+ // Build a bigger matrix.
+ matrix<double, 4, 4> Mat4x4d;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:31>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4d 'matrix<double, 4, 4>'
+
+ // Verify that the implicit arguments generate the correct type.
+ matrix<> ImpMat4x4;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:44:3, col:21>
+ // CHECK-NEXT: ...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
5081bc3
to
9e18ebb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me at a high level, great to see more users of the matrix intrinsics!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing we didn't do with the vector alias that we should do for the matrix
one is add a concept requirement that the dimensions be in the range [1,4].
We removed the restriction on vector length in HLSL, but I don't believe we have any plans to do the same for matrices.
LANGOPT(HLSLStrictAvailability, 1, 0, NotCompatible, | ||
"Strict availability diagnostic mode for HLSL built-in functions.") | ||
LANGOPT(HLSLSpvUseUnknownImageFormat, 1, 0, NotCompatible, "For storage images and texel buffers, sets the default format to 'Unknown' when not specified via the `vk::image_format` attribute. If this option is not used, the format is inferred from the resource's data type.") | ||
VALUE_LANGOPT(MaxMatrixDimension, 32, (1 << 20) - 1, NotCompatible, "maximum allowed matrix dimension") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The goal should likely be to replace the constexpr MaxElementsPerDimension
with this language option.
llvm-project/clang/include/clang/AST/TypeBase.h
Lines 4371 to 4379 in 129c683
class ConstantMatrixType final : public MatrixType { | |
protected: | |
friend class ASTContext; | |
/// Number of rows and columns. | |
unsigned NumRows; | |
unsigned NumColumns; | |
static constexpr unsigned MaxElementsPerDimension = (1 << 20) - 1; |
But that would require updating all of these static constexpr functions with something that takes a SemaPtr
llvm-project/clang/include/clang/AST/TypeBase.h
Lines 4399 to 4407 in 129c683
/// Returns true if \p NumElements is a valid matrix dimension. | |
static constexpr bool isDimensionValid(size_t NumElements) { | |
return NumElements > 0 && NumElements <= MaxElementsPerDimension; | |
} | |
/// Returns the maximum number of elements per dimension. | |
static constexpr unsigned getMaxElementsPerDimension() { | |
return MaxElementsPerDimension; | |
} |
For now we will make this an HLSL only language opt and create an issue to make it more generic:
VALUE_LANGOPT(MaxMatrixDimension, 32, (1 << 20) - 1, NotCompatible, "maximum allowed matrix dimension") | |
VALUE_LANGOPT(HLSLMaxMatrixDimension, 32, 4, NotCompatible, "maximum allowed matrix dimension") |
issue #160190
|
||
auto *ParamList = TemplateParameterList::Create( | ||
AST, SourceLocation(), SourceLocation(), TemplateParams, SourceLocation(), | ||
RequiresExpr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@llvm-beanz I had two questions on this commit for you.
- First should we be adding the size limits on the alias like I am doing above or should we be doing it on the underlying
matrix_type
? - Second it was pretty easy to enforce the size limits without a concept. Are you ok with The RequiresExpr that TemplateParameterList takes as our means of enforcing a size limit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- First should we be adding the size limits on the alias like I am doing above or should we be doing it on the underlying
matrix_type
?
I think this probably boils down to the quality of the error messages, and how we would handle larger matrices in the backend. Since we probably don't want to handle larger matrices we should probably restrict matrix_type
. If we get better diagnostics from the requirement being on the template, we may want to do that also.
- Second it was pretty easy to enforce the size limits without a concept. Are you ok with The RequiresExpr that TemplateParameterList takes as our means of enforcing a size limit?
Requirement expressions are part of C++ concepts, that was my intention for how this would be done. I don't think we need a separately defined ConceptDecl
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'n that case I'm going to take enforcement of size on matrix_type
as a change we will do on #160190 since it sounds like the diagnostics would be better to do this on the alias aswell as the underlying type.
dcb1a9e
to
d81fd36
Compare
LANGOPT(HLSLStrictAvailability, 1, 0, NotCompatible, | ||
"Strict availability diagnostic mode for HLSL built-in functions.") | ||
LANGOPT(HLSLSpvUseUnknownImageFormat, 1, 0, NotCompatible, "For storage images and texel buffers, sets the default format to 'Unknown' when not specified via the `vk::image_format` attribute. If this option is not used, the format is inferred from the resource's data type.") | ||
VALUE_LANGOPT(HLSLMaxMatrixDimension, 32, 4, NotCompatible, "maximum allowed matrix dimension") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this need to be a language option? Since it is just for HLSL it will always be 4.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, but for this PR it could just be constant 4 in HLSLExternalSemaSource. Then we can discuss the language option separately.
fixes #109839
This change is really simple. It creates a matrix alias that will let HLSL use the existing clang
matrix_type
infra.The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL.
Testing therefore is limited to exercising the alias, sema errors, and basic codegen.
future work will add things like constructors and accessors.
The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing
matrix_type
testing like cast behavior.