Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/LangOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ ENUM_LANGOPT(HLSLVersion, HLSLLangStd, 16, HLSL_Unset, NotCompatible, "HLSL Vers
LANGOPT(HLSLStrictAvailability, 1, 0, NotCompatible,
"Strict availability diagnostic mode for HLSL built-in functions.")
LANGOPT(HLSLSpvUseUnknownImageFormat, 1, 0, NotCompatible, "For storage images and texel buffers, sets the default format to 'Unknown' when not specified via the `vk::image_format` attribute. If this option is not used, the format is inferred from the resource's data type.")
VALUE_LANGOPT(HLSLMaxMatrixDimension, 32, 4, NotCompatible, "maximum allowed matrix dimension")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a language option? Since it is just for HLSL it will always be 4.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to expand it to a lang opt for all matrix types in #160190.
see the diff d81fd36

Doing it in this pr was adding a bunch of changes I felt distracted from what I was trying to accomplish. so It will be in its own pr.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, but for this PR it could just be constant 4 in HLSLExternalSemaSource. Then we can discuss the language option separately.


LANGOPT(CUDAIsDevice , 1, 0, NotCompatible, "compiling for CUDA device")
LANGOPT(CUDAAllowVariadicFunctions, 1, 0, NotCompatible, "allowing variadic functions in CUDA device code")
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
Visibility<[ClangOption, CC1Option]>,
HelpText<"Enable matrix data type and related builtin functions">,
MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;

defm raw_string_literals : BoolFOption<"raw-string-literals",
LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/HLSLExternalSemaSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
private:
void defineTrivialHLSLTypes();
void defineHLSLVectorAlias();
void defineHLSLMatrixAlias();
void defineHLSLTypesWithForwardDeclarations();
void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
};
Expand Down
37 changes: 33 additions & 4 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,16 +846,45 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
}
}

void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
raw_ostream &OS) {
printBefore(T->getElementType(), OS);
OS << " __attribute__((matrix_type(";
static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
OS << T->getNumRows() << ", " << T->getNumColumns();
}

static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
raw_ostream &OS) {
OS << "matrix<";
TP.printBefore(T->getElementType(), OS);
}

static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
OS << ", ";
printDims(T, OS);
OS << ">";
}

static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
raw_ostream &OS) {
TP.printBefore(T->getElementType(), OS);
OS << " __attribute__((matrix_type(";
printDims(T, OS);
OS << ")))";
}

void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
raw_ostream &OS) {
if (Policy.UseHLSLTypes) {
printHLSLMatrixBefore(*this, T, OS);
return;
}
printClangMatrixBefore(*this, T, OS);
}

void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
raw_ostream &OS) {
if (Policy.UseHLSLTypes) {
printHLSLMatrixAfter(T, OS);
return;
}
printAfter(T->getElementType(), OS);
}

Expand Down
233 changes: 233 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
typedef vector<float64_t, 3> float64_t3;
typedef vector<float64_t, 4> float64_t4;

#ifdef __HLSL_ENABLE_16_BIT
typedef matrix<int16_t, 1, 1> int16_t1x1;
typedef matrix<int16_t, 1, 2> int16_t1x2;
typedef matrix<int16_t, 1, 3> int16_t1x3;
typedef matrix<int16_t, 1, 4> int16_t1x4;
typedef matrix<int16_t, 2, 1> int16_t2x1;
typedef matrix<int16_t, 2, 2> int16_t2x2;
typedef matrix<int16_t, 2, 3> int16_t2x3;
typedef matrix<int16_t, 2, 4> int16_t2x4;
typedef matrix<int16_t, 3, 1> int16_t3x1;
typedef matrix<int16_t, 3, 2> int16_t3x2;
typedef matrix<int16_t, 3, 3> int16_t3x3;
typedef matrix<int16_t, 3, 4> int16_t3x4;
typedef matrix<int16_t, 4, 1> int16_t4x1;
typedef matrix<int16_t, 4, 2> int16_t4x2;
typedef matrix<int16_t, 4, 3> int16_t4x3;
typedef matrix<int16_t, 4, 4> int16_t4x4;
typedef matrix<uint16_t, 1, 1> uint16_t1x1;
typedef matrix<uint16_t, 1, 2> uint16_t1x2;
typedef matrix<uint16_t, 1, 3> uint16_t1x3;
typedef matrix<uint16_t, 1, 4> uint16_t1x4;
typedef matrix<uint16_t, 2, 1> uint16_t2x1;
typedef matrix<uint16_t, 2, 2> uint16_t2x2;
typedef matrix<uint16_t, 2, 3> uint16_t2x3;
typedef matrix<uint16_t, 2, 4> uint16_t2x4;
typedef matrix<uint16_t, 3, 1> uint16_t3x1;
typedef matrix<uint16_t, 3, 2> uint16_t3x2;
typedef matrix<uint16_t, 3, 3> uint16_t3x3;
typedef matrix<uint16_t, 3, 4> uint16_t3x4;
typedef matrix<uint16_t, 4, 1> uint16_t4x1;
typedef matrix<uint16_t, 4, 2> uint16_t4x2;
typedef matrix<uint16_t, 4, 3> uint16_t4x3;
typedef matrix<uint16_t, 4, 4> uint16_t4x4;
#endif

typedef matrix<int, 1, 1> int1x1;
typedef matrix<int, 1, 2> int1x2;
typedef matrix<int, 1, 3> int1x3;
typedef matrix<int, 1, 4> int1x4;
typedef matrix<int, 2, 1> int2x1;
typedef matrix<int, 2, 2> int2x2;
typedef matrix<int, 2, 3> int2x3;
typedef matrix<int, 2, 4> int2x4;
typedef matrix<int, 3, 1> int3x1;
typedef matrix<int, 3, 2> int3x2;
typedef matrix<int, 3, 3> int3x3;
typedef matrix<int, 3, 4> int3x4;
typedef matrix<int, 4, 1> int4x1;
typedef matrix<int, 4, 2> int4x2;
typedef matrix<int, 4, 3> int4x3;
typedef matrix<int, 4, 4> int4x4;
typedef matrix<uint, 1, 1> uint1x1;
typedef matrix<uint, 1, 2> uint1x2;
typedef matrix<uint, 1, 3> uint1x3;
typedef matrix<uint, 1, 4> uint1x4;
typedef matrix<uint, 2, 1> uint2x1;
typedef matrix<uint, 2, 2> uint2x2;
typedef matrix<uint, 2, 3> uint2x3;
typedef matrix<uint, 2, 4> uint2x4;
typedef matrix<uint, 3, 1> uint3x1;
typedef matrix<uint, 3, 2> uint3x2;
typedef matrix<uint, 3, 3> uint3x3;
typedef matrix<uint, 3, 4> uint3x4;
typedef matrix<uint, 4, 1> uint4x1;
typedef matrix<uint, 4, 2> uint4x2;
typedef matrix<uint, 4, 3> uint4x3;
typedef matrix<uint, 4, 4> uint4x4;
typedef matrix<int32_t, 1, 1> int32_t1x1;
typedef matrix<int32_t, 1, 2> int32_t1x2;
typedef matrix<int32_t, 1, 3> int32_t1x3;
typedef matrix<int32_t, 1, 4> int32_t1x4;
typedef matrix<int32_t, 2, 1> int32_t2x1;
typedef matrix<int32_t, 2, 2> int32_t2x2;
typedef matrix<int32_t, 2, 3> int32_t2x3;
typedef matrix<int32_t, 2, 4> int32_t2x4;
typedef matrix<int32_t, 3, 1> int32_t3x1;
typedef matrix<int32_t, 3, 2> int32_t3x2;
typedef matrix<int32_t, 3, 3> int32_t3x3;
typedef matrix<int32_t, 3, 4> int32_t3x4;
typedef matrix<int32_t, 4, 1> int32_t4x1;
typedef matrix<int32_t, 4, 2> int32_t4x2;
typedef matrix<int32_t, 4, 3> int32_t4x3;
typedef matrix<int32_t, 4, 4> int32_t4x4;
typedef matrix<uint32_t, 1, 1> uint32_t1x1;
typedef matrix<uint32_t, 1, 2> uint32_t1x2;
typedef matrix<uint32_t, 1, 3> uint32_t1x3;
typedef matrix<uint32_t, 1, 4> uint32_t1x4;
typedef matrix<uint32_t, 2, 1> uint32_t2x1;
typedef matrix<uint32_t, 2, 2> uint32_t2x2;
typedef matrix<uint32_t, 2, 3> uint32_t2x3;
typedef matrix<uint32_t, 2, 4> uint32_t2x4;
typedef matrix<uint32_t, 3, 1> uint32_t3x1;
typedef matrix<uint32_t, 3, 2> uint32_t3x2;
typedef matrix<uint32_t, 3, 3> uint32_t3x3;
typedef matrix<uint32_t, 3, 4> uint32_t3x4;
typedef matrix<uint32_t, 4, 1> uint32_t4x1;
typedef matrix<uint32_t, 4, 2> uint32_t4x2;
typedef matrix<uint32_t, 4, 3> uint32_t4x3;
typedef matrix<uint32_t, 4, 4> uint32_t4x4;
typedef matrix<int64_t, 1, 1> int64_t1x1;
typedef matrix<int64_t, 1, 2> int64_t1x2;
typedef matrix<int64_t, 1, 3> int64_t1x3;
typedef matrix<int64_t, 1, 4> int64_t1x4;
typedef matrix<int64_t, 2, 1> int64_t2x1;
typedef matrix<int64_t, 2, 2> int64_t2x2;
typedef matrix<int64_t, 2, 3> int64_t2x3;
typedef matrix<int64_t, 2, 4> int64_t2x4;
typedef matrix<int64_t, 3, 1> int64_t3x1;
typedef matrix<int64_t, 3, 2> int64_t3x2;
typedef matrix<int64_t, 3, 3> int64_t3x3;
typedef matrix<int64_t, 3, 4> int64_t3x4;
typedef matrix<int64_t, 4, 1> int64_t4x1;
typedef matrix<int64_t, 4, 2> int64_t4x2;
typedef matrix<int64_t, 4, 3> int64_t4x3;
typedef matrix<int64_t, 4, 4> int64_t4x4;
typedef matrix<uint64_t, 1, 1> uint64_t1x1;
typedef matrix<uint64_t, 1, 2> uint64_t1x2;
typedef matrix<uint64_t, 1, 3> uint64_t1x3;
typedef matrix<uint64_t, 1, 4> uint64_t1x4;
typedef matrix<uint64_t, 2, 1> uint64_t2x1;
typedef matrix<uint64_t, 2, 2> uint64_t2x2;
typedef matrix<uint64_t, 2, 3> uint64_t2x3;
typedef matrix<uint64_t, 2, 4> uint64_t2x4;
typedef matrix<uint64_t, 3, 1> uint64_t3x1;
typedef matrix<uint64_t, 3, 2> uint64_t3x2;
typedef matrix<uint64_t, 3, 3> uint64_t3x3;
typedef matrix<uint64_t, 3, 4> uint64_t3x4;
typedef matrix<uint64_t, 4, 1> uint64_t4x1;
typedef matrix<uint64_t, 4, 2> uint64_t4x2;
typedef matrix<uint64_t, 4, 3> uint64_t4x3;
typedef matrix<uint64_t, 4, 4> uint64_t4x4;

typedef matrix<half, 1, 1> half1x1;
typedef matrix<half, 1, 2> half1x2;
typedef matrix<half, 1, 3> half1x3;
typedef matrix<half, 1, 4> half1x4;
typedef matrix<half, 2, 1> half2x1;
typedef matrix<half, 2, 2> half2x2;
typedef matrix<half, 2, 3> half2x3;
typedef matrix<half, 2, 4> half2x4;
typedef matrix<half, 3, 1> half3x1;
typedef matrix<half, 3, 2> half3x2;
typedef matrix<half, 3, 3> half3x3;
typedef matrix<half, 3, 4> half3x4;
typedef matrix<half, 4, 1> half4x1;
typedef matrix<half, 4, 2> half4x2;
typedef matrix<half, 4, 3> half4x3;
typedef matrix<half, 4, 4> half4x4;
typedef matrix<float, 1, 1> float1x1;
typedef matrix<float, 1, 2> float1x2;
typedef matrix<float, 1, 3> float1x3;
typedef matrix<float, 1, 4> float1x4;
typedef matrix<float, 2, 1> float2x1;
typedef matrix<float, 2, 2> float2x2;
typedef matrix<float, 2, 3> float2x3;
typedef matrix<float, 2, 4> float2x4;
typedef matrix<float, 3, 1> float3x1;
typedef matrix<float, 3, 2> float3x2;
typedef matrix<float, 3, 3> float3x3;
typedef matrix<float, 3, 4> float3x4;
typedef matrix<float, 4, 1> float4x1;
typedef matrix<float, 4, 2> float4x2;
typedef matrix<float, 4, 3> float4x3;
typedef matrix<float, 4, 4> float4x4;
typedef matrix<double, 1, 1> double1x1;
typedef matrix<double, 1, 2> double1x2;
typedef matrix<double, 1, 3> double1x3;
typedef matrix<double, 1, 4> double1x4;
typedef matrix<double, 2, 1> double2x1;
typedef matrix<double, 2, 2> double2x2;
typedef matrix<double, 2, 3> double2x3;
typedef matrix<double, 2, 4> double2x4;
typedef matrix<double, 3, 1> double3x1;
typedef matrix<double, 3, 2> double3x2;
typedef matrix<double, 3, 3> double3x3;
typedef matrix<double, 3, 4> double3x4;
typedef matrix<double, 4, 1> double4x1;
typedef matrix<double, 4, 2> double4x2;
typedef matrix<double, 4, 3> double4x3;
typedef matrix<double, 4, 4> double4x4;

#ifdef __HLSL_ENABLE_16_BIT
typedef matrix<float16_t, 1, 1> float16_t1x1;
typedef matrix<float16_t, 1, 2> float16_t1x2;
typedef matrix<float16_t, 1, 3> float16_t1x3;
typedef matrix<float16_t, 1, 4> float16_t1x4;
typedef matrix<float16_t, 2, 1> float16_t2x1;
typedef matrix<float16_t, 2, 2> float16_t2x2;
typedef matrix<float16_t, 2, 3> float16_t2x3;
typedef matrix<float16_t, 2, 4> float16_t2x4;
typedef matrix<float16_t, 3, 1> float16_t3x1;
typedef matrix<float16_t, 3, 2> float16_t3x2;
typedef matrix<float16_t, 3, 3> float16_t3x3;
typedef matrix<float16_t, 3, 4> float16_t3x4;
typedef matrix<float16_t, 4, 1> float16_t4x1;
typedef matrix<float16_t, 4, 2> float16_t4x2;
typedef matrix<float16_t, 4, 3> float16_t4x3;
typedef matrix<float16_t, 4, 4> float16_t4x4;
#endif

typedef matrix<float32_t, 1, 1> float32_t1x1;
typedef matrix<float32_t, 1, 2> float32_t1x2;
typedef matrix<float32_t, 1, 3> float32_t1x3;
typedef matrix<float32_t, 1, 4> float32_t1x4;
typedef matrix<float32_t, 2, 1> float32_t2x1;
typedef matrix<float32_t, 2, 2> float32_t2x2;
typedef matrix<float32_t, 2, 3> float32_t2x3;
typedef matrix<float32_t, 2, 4> float32_t2x4;
typedef matrix<float32_t, 3, 1> float32_t3x1;
typedef matrix<float32_t, 3, 2> float32_t3x2;
typedef matrix<float32_t, 3, 3> float32_t3x3;
typedef matrix<float32_t, 3, 4> float32_t3x4;
typedef matrix<float32_t, 4, 1> float32_t4x1;
typedef matrix<float32_t, 4, 2> float32_t4x2;
typedef matrix<float32_t, 4, 3> float32_t4x3;
typedef matrix<float32_t, 4, 4> float32_t4x4;
typedef matrix<float64_t, 1, 1> float64_t1x1;
typedef matrix<float64_t, 1, 2> float64_t1x2;
typedef matrix<float64_t, 1, 3> float64_t1x3;
typedef matrix<float64_t, 1, 4> float64_t1x4;
typedef matrix<float64_t, 2, 1> float64_t2x1;
typedef matrix<float64_t, 2, 2> float64_t2x2;
typedef matrix<float64_t, 2, 3> float64_t2x3;
typedef matrix<float64_t, 2, 4> float64_t2x4;
typedef matrix<float64_t, 3, 1> float64_t3x1;
typedef matrix<float64_t, 3, 2> float64_t3x2;
typedef matrix<float64_t, 3, 3> float64_t3x3;
typedef matrix<float64_t, 3, 4> float64_t3x4;
typedef matrix<float64_t, 4, 1> float64_t4x1;
typedef matrix<float64_t, 4, 2> float64_t4x2;
typedef matrix<float64_t, 4, 3> float64_t4x3;
typedef matrix<float64_t, 4, 4> float64_t4x4;

} // namespace hlsl

#endif //_HLSL_HLSL_BASIC_TYPES_H_
Loading