Skip to content

Commit 26980a2

Browse files
committed
[HLSL] Add support for the HLSL matrix type
fixes #109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang `matrix_type` infra. The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL. Testing therefore is limited to exercising the alias. The main difference in this attempt is the type printer.
1 parent ef7de8d commit 26980a2

File tree

10 files changed

+451
-8
lines changed

10 files changed

+451
-8
lines changed

clang/include/clang/Driver/Options.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
45824582
def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
45834583
Visibility<[ClangOption, CC1Option]>,
45844584
HelpText<"Enable matrix data type and related builtin functions">,
4585-
MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
4585+
MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
45864586

45874587
defm raw_string_literals : BoolFOption<"raw-string-literals",
45884588
LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,

clang/include/clang/Sema/HLSLExternalSemaSource.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
4444
private:
4545
void defineTrivialHLSLTypes();
4646
void defineHLSLVectorAlias();
47+
void defineHLSLMatrixAlias();
4748
void defineHLSLTypesWithForwardDeclarations();
4849
void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
4950
};

clang/lib/AST/TypePrinter.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -846,16 +846,41 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
846846
}
847847
}
848848

849-
void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
850-
raw_ostream &OS) {
851-
printBefore(T->getElementType(), OS);
852-
OS << " __attribute__((matrix_type(";
849+
static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
853850
OS << T->getNumRows() << ", " << T->getNumColumns();
851+
}
852+
853+
static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
854+
OS << "matrix<";
855+
TP.printBefore(T->getElementType(), OS);
856+
}
857+
858+
static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
859+
OS << ", ";
860+
printDims(T, OS);
861+
OS << ">";
862+
}
863+
864+
static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
865+
TP.printBefore(T->getElementType(), OS);
866+
OS << " __attribute__((matrix_type(";
867+
printDims(T, OS);
854868
OS << ")))";
855869
}
856870

857-
void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
858-
raw_ostream &OS) {
871+
void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, raw_ostream &OS) {
872+
if (Policy.UseHLSLTypes) {
873+
printHLSLMatrixBefore(*this, T, OS);
874+
return;
875+
}
876+
printClangMatrixBefore(*this, T, OS);
877+
}
878+
879+
void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
880+
if (Policy.UseHLSLTypes) {
881+
printHLSLMatrixAfter(T, OS);
882+
return;
883+
}
859884
printAfter(T->getElementType(), OS);
860885
}
861886

clang/lib/Headers/hlsl/hlsl_basic_types.h

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
115115
typedef vector<float64_t, 3> float64_t3;
116116
typedef vector<float64_t, 4> float64_t4;
117117

118+
ifdef __HLSL_ENABLE_16_BIT
119+
typedef matrix<int16_t, 1, 1> int16_t1x1;
120+
typedef matrix<int16_t, 1, 2> int16_t1x2;
121+
typedef matrix<int16_t, 1, 3> int16_t1x3;
122+
typedef matrix<int16_t, 1, 4> int16_t1x4;
123+
typedef matrix<int16_t, 2, 1> int16_t2x1;
124+
typedef matrix<int16_t, 2, 2> int16_t2x2;
125+
typedef matrix<int16_t, 2, 3> int16_t2x3;
126+
typedef matrix<int16_t, 2, 4> int16_t2x4;
127+
typedef matrix<int16_t, 3, 1> int16_t3x1;
128+
typedef matrix<int16_t, 3, 2> int16_t3x2;
129+
typedef matrix<int16_t, 3, 3> int16_t3x3;
130+
typedef matrix<int16_t, 3, 4> int16_t3x4;
131+
typedef matrix<int16_t, 4, 1> int16_t4x1;
132+
typedef matrix<int16_t, 4, 2> int16_t4x2;
133+
typedef matrix<int16_t, 4, 3> int16_t4x3;
134+
typedef matrix<int16_t, 4, 4> int16_t4x4;
135+
typedef matrix<uint16_t, 1, 1> uint16_t1x1;
136+
typedef matrix<uint16_t, 1, 2> uint16_t1x2;
137+
typedef matrix<uint16_t, 1, 3> uint16_t1x3;
138+
typedef matrix<uint16_t, 1, 4> uint16_t1x4;
139+
typedef matrix<uint16_t, 2, 1> uint16_t2x1;
140+
typedef matrix<uint16_t, 2, 2> uint16_t2x2;
141+
typedef matrix<uint16_t, 2, 3> uint16_t2x3;
142+
typedef matrix<uint16_t, 2, 4> uint16_t2x4;
143+
typedef matrix<uint16_t, 3, 1> uint16_t3x1;
144+
typedef matrix<uint16_t, 3, 2> uint16_t3x2;
145+
typedef matrix<uint16_t, 3, 3> uint16_t3x3;
146+
typedef matrix<uint16_t, 3, 4> uint16_t3x4;
147+
typedef matrix<uint16_t, 4, 1> uint16_t4x1;
148+
typedef matrix<uint16_t, 4, 2> uint16_t4x2;
149+
typedef matrix<uint16_t, 4, 3> uint16_t4x3;
150+
typedef matrix<uint16_t, 4, 4> uint16_t4x4;
151+
#endif
152+
153+
typedef matrix<int, 1, 1> int1x1;
154+
typedef matrix<int, 1, 2> int1x2;
155+
typedef matrix<int, 1, 3> int1x3;
156+
typedef matrix<int, 1, 4> int1x4;
157+
typedef matrix<int, 2, 1> int2x1;
158+
typedef matrix<int, 2, 2> int2x2;
159+
typedef matrix<int, 2, 3> int2x3;
160+
typedef matrix<int, 2, 4> int2x4;
161+
typedef matrix<int, 3, 1> int3x1;
162+
typedef matrix<int, 3, 2> int3x2;
163+
typedef matrix<int, 3, 3> int3x3;
164+
typedef matrix<int, 3, 4> int3x4;
165+
typedef matrix<int, 4, 1> int4x1;
166+
typedef matrix<int, 4, 2> int4x2;
167+
typedef matrix<int, 4, 3> int4x3;
168+
typedef matrix<int, 4, 4> int4x4;
169+
typedef matrix<uint, 1, 1> uint1x1;
170+
typedef matrix<uint, 1, 2> uint1x2;
171+
typedef matrix<uint, 1, 3> uint1x3;
172+
typedef matrix<uint, 1, 4> uint1x4;
173+
typedef matrix<uint, 2, 1> uint2x1;
174+
typedef matrix<uint, 2, 2> uint2x2;
175+
typedef matrix<uint, 2, 3> uint2x3;
176+
typedef matrix<uint, 2, 4> uint2x4;
177+
typedef matrix<uint, 3, 1> uint3x1;
178+
typedef matrix<uint, 3, 2> uint3x2;
179+
typedef matrix<uint, 3, 3> uint3x3;
180+
typedef matrix<uint, 3, 4> uint3x4;
181+
typedef matrix<uint, 4, 1> uint4x1;
182+
typedef matrix<uint, 4, 2> uint4x2;
183+
typedef matrix<uint, 4, 3> uint4x3;
184+
typedef matrix<uint, 4, 4> uint4x4;
185+
typedef matrix<int32_t, 1, 1> int32_t1x1;
186+
typedef matrix<int32_t, 1, 2> int32_t1x2;
187+
typedef matrix<int32_t, 1, 3> int32_t1x3;
188+
typedef matrix<int32_t, 1, 4> int32_t1x4;
189+
typedef matrix<int32_t, 2, 1> int32_t2x1;
190+
typedef matrix<int32_t, 2, 2> int32_t2x2;
191+
typedef matrix<int32_t, 2, 3> int32_t2x3;
192+
typedef matrix<int32_t, 2, 4> int32_t2x4;
193+
typedef matrix<int32_t, 3, 1> int32_t3x1;
194+
typedef matrix<int32_t, 3, 2> int32_t3x2;
195+
typedef matrix<int32_t, 3, 3> int32_t3x3;
196+
typedef matrix<int32_t, 3, 4> int32_t3x4;
197+
typedef matrix<int32_t, 4, 1> int32_t4x1;
198+
typedef matrix<int32_t, 4, 2> int32_t4x2;
199+
typedef matrix<int32_t, 4, 3> int32_t4x3;
200+
typedef matrix<int32_t, 4, 4> int32_t4x4;
201+
typedef matrix<uint32_t, 1, 1> uint32_t1x1;
202+
typedef matrix<uint32_t, 1, 2> uint32_t1x2;
203+
typedef matrix<uint32_t, 1, 3> uint32_t1x3;
204+
typedef matrix<uint32_t, 1, 4> uint32_t1x4;
205+
typedef matrix<uint32_t, 2, 1> uint32_t2x1;
206+
typedef matrix<uint32_t, 2, 2> uint32_t2x2;
207+
typedef matrix<uint32_t, 2, 3> uint32_t2x3;
208+
typedef matrix<uint32_t, 2, 4> uint32_t2x4;
209+
typedef matrix<uint32_t, 3, 1> uint32_t3x1;
210+
typedef matrix<uint32_t, 3, 2> uint32_t3x2;
211+
typedef matrix<uint32_t, 3, 3> uint32_t3x3;
212+
typedef matrix<uint32_t, 3, 4> uint32_t3x4;
213+
typedef matrix<uint32_t, 4, 1> uint32_t4x1;
214+
typedef matrix<uint32_t, 4, 2> uint32_t4x2;
215+
typedef matrix<uint32_t, 4, 3> uint32_t4x3;
216+
typedef matrix<uint32_t, 4, 4> uint32_t4x4;
217+
typedef matrix<int64_t, 1, 1> int64_t1x1;
218+
typedef matrix<int64_t, 1, 2> int64_t1x2;
219+
typedef matrix<int64_t, 1, 3> int64_t1x3;
220+
typedef matrix<int64_t, 1, 4> int64_t1x4;
221+
typedef matrix<int64_t, 2, 1> int64_t2x1;
222+
typedef matrix<int64_t, 2, 2> int64_t2x2;
223+
typedef matrix<int64_t, 2, 3> int64_t2x3;
224+
typedef matrix<int64_t, 2, 4> int64_t2x4;
225+
typedef matrix<int64_t, 3, 1> int64_t3x1;
226+
typedef matrix<int64_t, 3, 2> int64_t3x2;
227+
typedef matrix<int64_t, 3, 3> int64_t3x3;
228+
typedef matrix<int64_t, 3, 4> int64_t3x4;
229+
typedef matrix<int64_t, 4, 1> int64_t4x1;
230+
typedef matrix<int64_t, 4, 2> int64_t4x2;
231+
typedef matrix<int64_t, 4, 3> int64_t4x3;
232+
typedef matrix<int64_t, 4, 4> int64_t4x4;
233+
typedef matrix<uint64_t, 1, 1> uint64_t1x1;
234+
typedef matrix<uint64_t, 1, 2> uint64_t1x2;
235+
typedef matrix<uint64_t, 1, 3> uint64_t1x3;
236+
typedef matrix<uint64_t, 1, 4> uint64_t1x4;
237+
typedef matrix<uint64_t, 2, 1> uint64_t2x1;
238+
typedef matrix<uint64_t, 2, 2> uint64_t2x2;
239+
typedef matrix<uint64_t, 2, 3> uint64_t2x3;
240+
typedef matrix<uint64_t, 2, 4> uint64_t2x4;
241+
typedef matrix<uint64_t, 3, 1> uint64_t3x1;
242+
typedef matrix<uint64_t, 3, 2> uint64_t3x2;
243+
typedef matrix<uint64_t, 3, 3> uint64_t3x3;
244+
typedef matrix<uint64_t, 3, 4> uint64_t3x4;
245+
typedef matrix<uint64_t, 4, 1> uint64_t4x1;
246+
typedef matrix<uint64_t, 4, 2> uint64_t4x2;
247+
typedef matrix<uint64_t, 4, 3> uint64_t4x3;
248+
typedef matrix<uint64_t, 4, 4> uint64_t4x4;
249+
250+
typedef matrix<half, 1, 1> half1x1;
251+
typedef matrix<half, 1, 2> half1x2;
252+
typedef matrix<half, 1, 3> half1x3;
253+
typedef matrix<half, 1, 4> half1x4;
254+
typedef matrix<half, 2, 1> half2x1;
255+
typedef matrix<half, 2, 2> half2x2;
256+
typedef matrix<half, 2, 3> half2x3;
257+
typedef matrix<half, 2, 4> half2x4;
258+
typedef matrix<half, 3, 1> half3x1;
259+
typedef matrix<half, 3, 2> half3x2;
260+
typedef matrix<half, 3, 3> half3x3;
261+
typedef matrix<half, 3, 4> half3x4;
262+
typedef matrix<half, 4, 1> half4x1;
263+
typedef matrix<half, 4, 2> half4x2;
264+
typedef matrix<half, 4, 3> half4x3;
265+
typedef matrix<half, 4, 4> half4x4;
266+
typedef matrix<float, 1, 1> float1x1;
267+
typedef matrix<float, 1, 2> float1x2;
268+
typedef matrix<float, 1, 3> float1x3;
269+
typedef matrix<float, 1, 4> float1x4;
270+
typedef matrix<float, 2, 1> float2x1;
271+
typedef matrix<float, 2, 2> float2x2;
272+
typedef matrix<float, 2, 3> float2x3;
273+
typedef matrix<float, 2, 4> float2x4;
274+
typedef matrix<float, 3, 1> float3x1;
275+
typedef matrix<float, 3, 2> float3x2;
276+
typedef matrix<float, 3, 3> float3x3;
277+
typedef matrix<float, 3, 4> float3x4;
278+
typedef matrix<float, 4, 1> float4x1;
279+
typedef matrix<float, 4, 2> float4x2;
280+
typedef matrix<float, 4, 3> float4x3;
281+
typedef matrix<float, 4, 4> float4x4;
282+
typedef matrix<double, 1, 1> double1x1;
283+
typedef matrix<double, 1, 2> double1x2;
284+
typedef matrix<double, 1, 3> double1x3;
285+
typedef matrix<double, 1, 4> double1x4;
286+
typedef matrix<double, 2, 1> double2x1;
287+
typedef matrix<double, 2, 2> double2x2;
288+
typedef matrix<double, 2, 3> double2x3;
289+
typedef matrix<double, 2, 4> double2x4;
290+
typedef matrix<double, 3, 1> double3x1;
291+
typedef matrix<double, 3, 2> double3x2;
292+
typedef matrix<double, 3, 3> double3x3;
293+
typedef matrix<double, 3, 4> double3x4;
294+
typedef matrix<double, 4, 1> double4x1;
295+
typedef matrix<double, 4, 2> double4x2;
296+
typedef matrix<double, 4, 3> double4x3;
297+
typedef matrix<double, 4, 4> double4x4;
298+
299+
#ifdef __HLSL_ENABLE_16_BIT
300+
typedef matrix<float16_t, 1, 1> float16_t1x1;
301+
typedef matrix<float16_t, 1, 2> float16_t1x2;
302+
typedef matrix<float16_t, 1, 3> float16_t1x3;
303+
typedef matrix<float16_t, 1, 4> float16_t1x4;
304+
typedef matrix<float16_t, 2, 1> float16_t2x1;
305+
typedef matrix<float16_t, 2, 2> float16_t2x2;
306+
typedef matrix<float16_t, 2, 3> float16_t2x3;
307+
typedef matrix<float16_t, 2, 4> float16_t2x4;
308+
typedef matrix<float16_t, 3, 1> float16_t3x1;
309+
typedef matrix<float16_t, 3, 2> float16_t3x2;
310+
typedef matrix<float16_t, 3, 3> float16_t3x3;
311+
typedef matrix<float16_t, 3, 4> float16_t3x4;
312+
typedef matrix<float16_t, 4, 1> float16_t4x1;
313+
typedef matrix<float16_t, 4, 2> float16_t4x2;
314+
typedef matrix<float16_t, 4, 3> float16_t4x3;
315+
typedef matrix<float16_t, 4, 4> float16_t4x4;
316+
#endif
317+
318+
typedef matrix<float32_t, 1, 1> float32_t1x1;
319+
typedef matrix<float32_t, 1, 2> float32_t1x2;
320+
typedef matrix<float32_t, 1, 3> float32_t1x3;
321+
typedef matrix<float32_t, 1, 4> float32_t1x4;
322+
typedef matrix<float32_t, 2, 1> float32_t2x1;
323+
typedef matrix<float32_t, 2, 2> float32_t2x2;
324+
typedef matrix<float32_t, 2, 3> float32_t2x3;
325+
typedef matrix<float32_t, 2, 4> float32_t2x4;
326+
typedef matrix<float32_t, 3, 1> float32_t3x1;
327+
typedef matrix<float32_t, 3, 2> float32_t3x2;
328+
typedef matrix<float32_t, 3, 3> float32_t3x3;
329+
typedef matrix<float32_t, 3, 4> float32_t3x4;
330+
typedef matrix<float32_t, 4, 1> float32_t4x1;
331+
typedef matrix<float32_t, 4, 2> float32_t4x2;
332+
typedef matrix<float32_t, 4, 3> float32_t4x3;
333+
typedef matrix<float32_t, 4, 4> float32_t4x4;
334+
typedef matrix<float64_t, 1, 1> float64_t1x1;
335+
typedef matrix<float64_t, 1, 2> float64_t1x2;
336+
typedef matrix<float64_t, 1, 3> float64_t1x3;
337+
typedef matrix<float64_t, 1, 4> float64_t1x4;
338+
typedef matrix<float64_t, 2, 1> float64_t2x1;
339+
typedef matrix<float64_t, 2, 2> float64_t2x2;
340+
typedef matrix<float64_t, 2, 3> float64_t2x3;
341+
typedef matrix<float64_t, 2, 4> float64_t2x4;
342+
typedef matrix<float64_t, 3, 1> float64_t3x1;
343+
typedef matrix<float64_t, 3, 2> float64_t3x2;
344+
typedef matrix<float64_t, 3, 3> float64_t3x3;
345+
typedef matrix<float64_t, 3, 4> float64_t3x4;
346+
typedef matrix<float64_t, 4, 1> float64_t4x1;
347+
typedef matrix<float64_t, 4, 2> float64_t4x2;
348+
typedef matrix<float64_t, 4, 3> float64_t4x3;
349+
typedef matrix<float64_t, 4, 4> float64_t4x4;
350+
118351
} // namespace hlsl
119352

120353
#endif //_HLSL_HLSL_BASIC_TYPES_H_

clang/lib/Sema/HLSLExternalSemaSource.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,80 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
121121
HLSLNamespace->addDecl(Template);
122122
}
123123

124+
void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
125+
ASTContext &AST = SemaPtr->getASTContext();
126+
llvm::SmallVector<NamedDecl *> TemplateParams;
127+
128+
auto *TypeParam = TemplateTypeParmDecl::Create(
129+
AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
130+
&AST.Idents.get("element", tok::TokenKind::identifier), false, false);
131+
TypeParam->setDefaultArgument(
132+
AST, SemaPtr->getTrivialTemplateArgumentLoc(
133+
TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
134+
135+
TemplateParams.emplace_back(TypeParam);
136+
137+
// these should be 64 bit to be consistent with other clang matrices.
138+
auto *RowsParam = NonTypeTemplateParmDecl::Create(
139+
AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
140+
&AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
141+
false, AST.getTrivialTypeSourceInfo(AST.IntTy));
142+
llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
143+
TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
144+
/*IsDefaulted=*/true);
145+
RowsParam->setDefaultArgument(
146+
AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
147+
SourceLocation(), RowsParam));
148+
TemplateParams.emplace_back(RowsParam);
149+
150+
auto *ColsParam = NonTypeTemplateParmDecl::Create(
151+
AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
152+
&AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
153+
false, AST.getTrivialTypeSourceInfo(AST.IntTy));
154+
llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
155+
TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
156+
/*IsDefaulted=*/true);
157+
ColsParam->setDefaultArgument(
158+
AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
159+
SourceLocation(), ColsParam));
160+
TemplateParams.emplace_back(ColsParam);
161+
162+
auto *ParamList =
163+
TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
164+
TemplateParams, SourceLocation(), nullptr);
165+
166+
IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
167+
168+
QualType AliasType = AST.getDependentSizedMatrixType(
169+
AST.getTemplateTypeParmType(0, 0, false, TypeParam),
170+
DeclRefExpr::Create(
171+
AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
172+
DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
173+
AST.IntTy, VK_LValue),
174+
DeclRefExpr::Create(
175+
AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
176+
DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
177+
AST.IntTy, VK_LValue),
178+
SourceLocation());
179+
180+
auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
181+
SourceLocation(), &II,
182+
AST.getTrivialTypeSourceInfo(AliasType));
183+
Record->setImplicit(true);
184+
185+
auto *Template =
186+
TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
187+
Record->getIdentifier(), ParamList, Record);
188+
189+
Record->setDescribedAliasTemplate(Template);
190+
Template->setImplicit(true);
191+
Template->setLexicalDeclContext(Record->getDeclContext());
192+
HLSLNamespace->addDecl(Template);
193+
}
194+
124195
void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
125196
defineHLSLVectorAlias();
197+
defineHLSLMatrixAlias();
126198
}
127199

128200
/// Set up common members and attributes for buffer types

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3285,7 +3285,6 @@ static void BuildFlattenedTypeList(QualType BaseTy,
32853285
while (!WorkList.empty()) {
32863286
QualType T = WorkList.pop_back_val();
32873287
T = T.getCanonicalType().getUnqualifiedType();
3288-
assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
32893288
if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
32903289
llvm::SmallVector<QualType, 16> ElementFields;
32913290
// Generally I've avoided recursion in this algorithm, but arrays of

0 commit comments

Comments
 (0)