Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Sema/HLSLExternalSemaSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions;

void defineHLSLVectorAlias();
void defineHLSLMatrixAlias();
void defineTrivialHLSLTypes();
void defineHLSLTypesWithForwardDeclarations();

Expand Down
2 changes: 1 addition & 1 deletion clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ void ASTContext::InitBuiltinTypes(const TargetInfo &Target,
if (LangOpts.OpenACC && !LangOpts.OpenMP) {
InitBuiltinType(ArraySectionTy, BuiltinType::ArraySection);
}
if (LangOpts.MatrixTypes)
if (LangOpts.MatrixTypes || LangOpts.HLSL)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the only line this will be needed then feel free to ignore this comment.

My question is could there be places where LangOpts.MatrixTypes could diverge from HLSL?

If so instead is it possible for LangOpts.HLSL to turn on LangOpts.MatrixTypes?

Copy link
Collaborator

@llvm-beanz llvm-beanz Oct 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add MatrixTypes to the definition of HLSL's variants in LangStandards.def, which is probably a safer approach to this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into that. I didn't want to enable matrices by default because I didn't really want to allow the matrix attribute syntax in HLSL.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done this two ways now:
ae1a58b just enables matrices when HLSL is enabled
eeb3165 attempts to use LangStandards and make it dependent on HLSL versions

I confess I don't know why the second is preferable to the first. If I've misunderstood @llvm-beanz's suggestion, please let me know.

I'm actually unclear on how any of this prevents the problem @farzonl brought up. If what MatrixTypes represents changes so that it no longer matches what HLSL wants, then those changes will seep in without our notice. Given that the design discussion we had about this rejected the notion that we'd forge our own matrix type that inherited from the clang type, I don't see any way around that problem unless we pay attention to its evolution and weigh in when needed. Not that I'm saying that's a bad idea regardless.

All that the above changes alter is that now matrices can be declared using the clang extension which will sidestep at least the size restrictions we have not yet applied, but intend to.

I'm unconvinced this is an improvement, but it was easy enough to do that I thought I'd give us more concrete options to discuss.

InitBuiltinType(IncompleteMatrixIdxTy, BuiltinType::IncompleteMatrixIdx);

// Builtin types for 'id', 'Class', and 'SEL'.
Expand Down
40 changes: 28 additions & 12 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,34 +852,50 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {

void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
raw_ostream &OS) {
if (Policy.UseHLSLTypes)
OS << "matrix<";
printBefore(T->getElementType(), OS);
OS << " __attribute__((matrix_type(";
OS << T->getNumRows() << ", " << T->getNumColumns();
OS << ")))";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems wrong to me. IIUC this changes something like float __attribute__((matrix_type(4, 4)))* to float * __attribute__((matrix_type(4, 4))), which would mean the element type is float* rather than float.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my response to Florian here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may be simpler to read to duplicate some code but have the HSL and C++ matrixes types printed completely separately?

}

void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
raw_ostream &OS) {
printAfter(T->getElementType(), OS);
if (Policy.UseHLSLTypes) {
OS << ", ";
OS << T->getNumRows() << ", " << T->getNumColumns();
OS << ">";
} else {
OS << " __attribute__((matrix_type(";
OS << T->getNumRows() << ", " << T->getNumColumns();
OS << ")))";
}
}

void TypePrinter::printDependentSizedMatrixBefore(
const DependentSizedMatrixType *T, raw_ostream &OS) {
if (Policy.UseHLSLTypes)
OS << "matrix<";
printBefore(T->getElementType(), OS);
OS << " __attribute__((matrix_type(";
if (T->getRowExpr()) {
T->getRowExpr()->printPretty(OS, nullptr, Policy);
}
OS << ", ";
if (T->getColumnExpr()) {
T->getColumnExpr()->printPretty(OS, nullptr, Policy);
}
OS << ")))";
}

void TypePrinter::printDependentSizedMatrixAfter(
const DependentSizedMatrixType *T, raw_ostream &OS) {
printAfter(T->getElementType(), OS);
if (Policy.UseHLSLTypes)
OS << ", ";
else
OS << " __attribute__((matrix_type(";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You still have this in the "after" function. That will print this after pointer annotations which changes its meaning incorrectly.

I'm guessing there isn't an existing test case for this, otherwise this change would have broken something.

cc: @fhahn

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I guess I missed this one. AFAICT, DependentSizedMatrices are never printed.


if (Expr *E = T->getRowExpr())
E->printPretty(OS, nullptr, Policy);
OS << ", ";
if (Expr *E = T->getColumnExpr())
E->printPretty(OS, nullptr, Policy);

if (Policy.UseHLSLTypes)
OS << ">";
else
OS << ")))";
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to this file do alter how pointer and reference matrices are printed. I'd like to get @fhahn's opinion on this even if on nothing else.

}

void
Expand Down
232 changes: 232 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,238 @@ typedef vector<float64_t, 2> float64_t2;
typedef vector<float64_t, 3> float64_t3;
typedef vector<float64_t, 4> float64_t4;

#ifdef __HLSL_ENABLE_16_BIT
typedef matrix<int16_t, 1, 1> int16_t1x1;
typedef matrix<int16_t, 1, 2> int16_t1x2;
typedef matrix<int16_t, 1, 3> int16_t1x3;
typedef matrix<int16_t, 1, 4> int16_t1x4;
typedef matrix<int16_t, 2, 1> int16_t2x1;
typedef matrix<int16_t, 2, 2> int16_t2x2;
typedef matrix<int16_t, 2, 3> int16_t2x3;
typedef matrix<int16_t, 2, 4> int16_t2x4;
typedef matrix<int16_t, 3, 1> int16_t3x1;
typedef matrix<int16_t, 3, 2> int16_t3x2;
typedef matrix<int16_t, 3, 3> int16_t3x3;
typedef matrix<int16_t, 3, 4> int16_t3x4;
typedef matrix<int16_t, 4, 1> int16_t4x1;
typedef matrix<int16_t, 4, 2> int16_t4x2;
typedef matrix<int16_t, 4, 3> int16_t4x3;
typedef matrix<int16_t, 4, 4> int16_t4x4;
typedef matrix<uint16_t, 1, 1> uint16_t1x1;
typedef matrix<uint16_t, 1, 2> uint16_t1x2;
typedef matrix<uint16_t, 1, 3> uint16_t1x3;
typedef matrix<uint16_t, 1, 4> uint16_t1x4;
typedef matrix<uint16_t, 2, 1> uint16_t2x1;
typedef matrix<uint16_t, 2, 2> uint16_t2x2;
typedef matrix<uint16_t, 2, 3> uint16_t2x3;
typedef matrix<uint16_t, 2, 4> uint16_t2x4;
typedef matrix<uint16_t, 3, 1> uint16_t3x1;
typedef matrix<uint16_t, 3, 2> uint16_t3x2;
typedef matrix<uint16_t, 3, 3> uint16_t3x3;
typedef matrix<uint16_t, 3, 4> uint16_t3x4;
typedef matrix<uint16_t, 4, 1> uint16_t4x1;
typedef matrix<uint16_t, 4, 2> uint16_t4x2;
typedef matrix<uint16_t, 4, 3> uint16_t4x3;
typedef matrix<uint16_t, 4, 4> uint16_t4x4;
#endif
typedef matrix<int, 1, 1> int1x1;
typedef matrix<int, 1, 2> int1x2;
typedef matrix<int, 1, 3> int1x3;
typedef matrix<int, 1, 4> int1x4;
typedef matrix<int, 2, 1> int2x1;
typedef matrix<int, 2, 2> int2x2;
typedef matrix<int, 2, 3> int2x3;
typedef matrix<int, 2, 4> int2x4;
typedef matrix<int, 3, 1> int3x1;
typedef matrix<int, 3, 2> int3x2;
typedef matrix<int, 3, 3> int3x3;
typedef matrix<int, 3, 4> int3x4;
typedef matrix<int, 4, 1> int4x1;
typedef matrix<int, 4, 2> int4x2;
typedef matrix<int, 4, 3> int4x3;
typedef matrix<int, 4, 4> int4x4;
typedef matrix<uint, 1, 1> uint1x1;
typedef matrix<uint, 1, 2> uint1x2;
typedef matrix<uint, 1, 3> uint1x3;
typedef matrix<uint, 1, 4> uint1x4;
typedef matrix<uint, 2, 1> uint2x1;
typedef matrix<uint, 2, 2> uint2x2;
typedef matrix<uint, 2, 3> uint2x3;
typedef matrix<uint, 2, 4> uint2x4;
typedef matrix<uint, 3, 1> uint3x1;
typedef matrix<uint, 3, 2> uint3x2;
typedef matrix<uint, 3, 3> uint3x3;
typedef matrix<uint, 3, 4> uint3x4;
typedef matrix<uint, 4, 1> uint4x1;
typedef matrix<uint, 4, 2> uint4x2;
typedef matrix<uint, 4, 3> uint4x3;
typedef matrix<uint, 4, 4> uint4x4;
typedef matrix<int32_t, 1, 1> int32_t1x1;
typedef matrix<int32_t, 1, 2> int32_t1x2;
typedef matrix<int32_t, 1, 3> int32_t1x3;
typedef matrix<int32_t, 1, 4> int32_t1x4;
typedef matrix<int32_t, 2, 1> int32_t2x1;
typedef matrix<int32_t, 2, 2> int32_t2x2;
typedef matrix<int32_t, 2, 3> int32_t2x3;
typedef matrix<int32_t, 2, 4> int32_t2x4;
typedef matrix<int32_t, 3, 1> int32_t3x1;
typedef matrix<int32_t, 3, 2> int32_t3x2;
typedef matrix<int32_t, 3, 3> int32_t3x3;
typedef matrix<int32_t, 3, 4> int32_t3x4;
typedef matrix<int32_t, 4, 1> int32_t4x1;
typedef matrix<int32_t, 4, 2> int32_t4x2;
typedef matrix<int32_t, 4, 3> int32_t4x3;
typedef matrix<int32_t, 4, 4> int32_t4x4;
typedef matrix<uint32_t, 1, 1> uint32_t1x1;
typedef matrix<uint32_t, 1, 2> uint32_t1x2;
typedef matrix<uint32_t, 1, 3> uint32_t1x3;
typedef matrix<uint32_t, 1, 4> uint32_t1x4;
typedef matrix<uint32_t, 2, 1> uint32_t2x1;
typedef matrix<uint32_t, 2, 2> uint32_t2x2;
typedef matrix<uint32_t, 2, 3> uint32_t2x3;
typedef matrix<uint32_t, 2, 4> uint32_t2x4;
typedef matrix<uint32_t, 3, 1> uint32_t3x1;
typedef matrix<uint32_t, 3, 2> uint32_t3x2;
typedef matrix<uint32_t, 3, 3> uint32_t3x3;
typedef matrix<uint32_t, 3, 4> uint32_t3x4;
typedef matrix<uint32_t, 4, 1> uint32_t4x1;
typedef matrix<uint32_t, 4, 2> uint32_t4x2;
typedef matrix<uint32_t, 4, 3> uint32_t4x3;
typedef matrix<uint32_t, 4, 4> uint32_t4x4;
typedef matrix<int64_t, 1, 1> int64_t1x1;
typedef matrix<int64_t, 1, 2> int64_t1x2;
typedef matrix<int64_t, 1, 3> int64_t1x3;
typedef matrix<int64_t, 1, 4> int64_t1x4;
typedef matrix<int64_t, 2, 1> int64_t2x1;
typedef matrix<int64_t, 2, 2> int64_t2x2;
typedef matrix<int64_t, 2, 3> int64_t2x3;
typedef matrix<int64_t, 2, 4> int64_t2x4;
typedef matrix<int64_t, 3, 1> int64_t3x1;
typedef matrix<int64_t, 3, 2> int64_t3x2;
typedef matrix<int64_t, 3, 3> int64_t3x3;
typedef matrix<int64_t, 3, 4> int64_t3x4;
typedef matrix<int64_t, 4, 1> int64_t4x1;
typedef matrix<int64_t, 4, 2> int64_t4x2;
typedef matrix<int64_t, 4, 3> int64_t4x3;
typedef matrix<int64_t, 4, 4> int64_t4x4;
typedef matrix<uint64_t, 1, 1> uint64_t1x1;
typedef matrix<uint64_t, 1, 2> uint64_t1x2;
typedef matrix<uint64_t, 1, 3> uint64_t1x3;
typedef matrix<uint64_t, 1, 4> uint64_t1x4;
typedef matrix<uint64_t, 2, 1> uint64_t2x1;
typedef matrix<uint64_t, 2, 2> uint64_t2x2;
typedef matrix<uint64_t, 2, 3> uint64_t2x3;
typedef matrix<uint64_t, 2, 4> uint64_t2x4;
typedef matrix<uint64_t, 3, 1> uint64_t3x1;
typedef matrix<uint64_t, 3, 2> uint64_t3x2;
typedef matrix<uint64_t, 3, 3> uint64_t3x3;
typedef matrix<uint64_t, 3, 4> uint64_t3x4;
typedef matrix<uint64_t, 4, 1> uint64_t4x1;
typedef matrix<uint64_t, 4, 2> uint64_t4x2;
typedef matrix<uint64_t, 4, 3> uint64_t4x3;
typedef matrix<uint64_t, 4, 4> uint64_t4x4;

typedef matrix<half, 1, 1> half1x1;
typedef matrix<half, 1, 2> half1x2;
typedef matrix<half, 1, 3> half1x3;
typedef matrix<half, 1, 4> half1x4;
typedef matrix<half, 2, 1> half2x1;
typedef matrix<half, 2, 2> half2x2;
typedef matrix<half, 2, 3> half2x3;
typedef matrix<half, 2, 4> half2x4;
typedef matrix<half, 3, 1> half3x1;
typedef matrix<half, 3, 2> half3x2;
typedef matrix<half, 3, 3> half3x3;
typedef matrix<half, 3, 4> half3x4;
typedef matrix<half, 4, 1> half4x1;
typedef matrix<half, 4, 2> half4x2;
typedef matrix<half, 4, 3> half4x3;
typedef matrix<half, 4, 4> half4x4;
typedef matrix<float, 1, 1> float1x1;
typedef matrix<float, 1, 2> float1x2;
typedef matrix<float, 1, 3> float1x3;
typedef matrix<float, 1, 4> float1x4;
typedef matrix<float, 2, 1> float2x1;
typedef matrix<float, 2, 2> float2x2;
typedef matrix<float, 2, 3> float2x3;
typedef matrix<float, 2, 4> float2x4;
typedef matrix<float, 3, 1> float3x1;
typedef matrix<float, 3, 2> float3x2;
typedef matrix<float, 3, 3> float3x3;
typedef matrix<float, 3, 4> float3x4;
typedef matrix<float, 4, 1> float4x1;
typedef matrix<float, 4, 2> float4x2;
typedef matrix<float, 4, 3> float4x3;
typedef matrix<float, 4, 4> float4x4;
typedef matrix<double, 1, 1> double1x1;
typedef matrix<double, 1, 2> double1x2;
typedef matrix<double, 1, 3> double1x3;
typedef matrix<double, 1, 4> double1x4;
typedef matrix<double, 2, 1> double2x1;
typedef matrix<double, 2, 2> double2x2;
typedef matrix<double, 2, 3> double2x3;
typedef matrix<double, 2, 4> double2x4;
typedef matrix<double, 3, 1> double3x1;
typedef matrix<double, 3, 2> double3x2;
typedef matrix<double, 3, 3> double3x3;
typedef matrix<double, 3, 4> double3x4;
typedef matrix<double, 4, 1> double4x1;
typedef matrix<double, 4, 2> double4x2;
typedef matrix<double, 4, 3> double4x3;
typedef matrix<double, 4, 4> double4x4;

#ifdef __HLSL_ENABLE_16_BIT
typedef matrix<float16_t, 1, 1> float16_t1x1;
typedef matrix<float16_t, 1, 2> float16_t1x2;
typedef matrix<float16_t, 1, 3> float16_t1x3;
typedef matrix<float16_t, 1, 4> float16_t1x4;
typedef matrix<float16_t, 2, 1> float16_t2x1;
typedef matrix<float16_t, 2, 2> float16_t2x2;
typedef matrix<float16_t, 2, 3> float16_t2x3;
typedef matrix<float16_t, 2, 4> float16_t2x4;
typedef matrix<float16_t, 3, 1> float16_t3x1;
typedef matrix<float16_t, 3, 2> float16_t3x2;
typedef matrix<float16_t, 3, 3> float16_t3x3;
typedef matrix<float16_t, 3, 4> float16_t3x4;
typedef matrix<float16_t, 4, 1> float16_t4x1;
typedef matrix<float16_t, 4, 2> float16_t4x2;
typedef matrix<float16_t, 4, 3> float16_t4x3;
typedef matrix<float16_t, 4, 4> float16_t4x4;
#endif

typedef matrix<float32_t, 1, 1> float32_t1x1;
typedef matrix<float32_t, 1, 2> float32_t1x2;
typedef matrix<float32_t, 1, 3> float32_t1x3;
typedef matrix<float32_t, 1, 4> float32_t1x4;
typedef matrix<float32_t, 2, 1> float32_t2x1;
typedef matrix<float32_t, 2, 2> float32_t2x2;
typedef matrix<float32_t, 2, 3> float32_t2x3;
typedef matrix<float32_t, 2, 4> float32_t2x4;
typedef matrix<float32_t, 3, 1> float32_t3x1;
typedef matrix<float32_t, 3, 2> float32_t3x2;
typedef matrix<float32_t, 3, 3> float32_t3x3;
typedef matrix<float32_t, 3, 4> float32_t3x4;
typedef matrix<float32_t, 4, 1> float32_t4x1;
typedef matrix<float32_t, 4, 2> float32_t4x2;
typedef matrix<float32_t, 4, 3> float32_t4x3;
typedef matrix<float32_t, 4, 4> float32_t4x4;
typedef matrix<float64_t, 1, 1> float64_t1x1;
typedef matrix<float64_t, 1, 2> float64_t1x2;
typedef matrix<float64_t, 1, 3> float64_t1x3;
typedef matrix<float64_t, 1, 4> float64_t1x4;
typedef matrix<float64_t, 2, 1> float64_t2x1;
typedef matrix<float64_t, 2, 2> float64_t2x2;
typedef matrix<float64_t, 2, 3> float64_t2x3;
typedef matrix<float64_t, 2, 4> float64_t2x4;
typedef matrix<float64_t, 3, 1> float64_t3x1;
typedef matrix<float64_t, 3, 2> float64_t3x2;
typedef matrix<float64_t, 3, 3> float64_t3x3;
typedef matrix<float64_t, 3, 4> float64_t3x4;
typedef matrix<float64_t, 4, 1> float64_t4x1;
typedef matrix<float64_t, 4, 2> float64_t4x2;
typedef matrix<float64_t, 4, 3> float64_t4x3;
typedef matrix<float64_t, 4, 4> float64_t4x4;

} // namespace hlsl

#endif //_HLSL_HLSL_BASIC_TYPES_H_
Loading
Loading