Skip to content

Commit bf1c9e4

Browse files
authored
Change some getAsStructureType() uses to getAs<RecordType>() (microsoft#4707)
In some cases, code using getAsStructureType() seemed to expect any user- defined type to result in RecordType here, but only 'struct' types would, leaving 'class' types to fail certain code paths. Some code paths had an additional getAs<RecordType>() if the getAsStructureType() returned nullptr, but at that point, why bother with getAsStructureType() in the first place? This updates cases in CGHLSLMS.cpp that looked to be misusing getAsStructureType to simply use getAs<RecordType>() instead. One case is with constructing type annotations, where two branches are used and code is almost identical, except skipping size return when a member is a resource was only in the struct path. I think removing this separate path and checking for resource on any RecordType makes sense here.
1 parent d5aa3ff commit bf1c9e4

File tree

2 files changed

+43
-33
lines changed

2 files changed

+43
-33
lines changed

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,23 +1215,7 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
12151215
AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
12161216
// Resources don't count towards cbuffer size.
12171217
return 0;
1218-
} else if (const RecordType *RT = paramTy->getAsStructureType()) {
1219-
RecordDecl *RD = RT->getDecl();
1220-
llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
1221-
// Skip if already created.
1222-
if (DxilStructAnnotation *annotation = dxilTypeSys.GetStructAnnotation(ST)) {
1223-
unsigned structSize = annotation->GetCBufferSize();
1224-
return structSize;
1225-
}
1226-
DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
1227-
GetNumTemplateArgsForRecordDecl(RT->getDecl()));
1228-
DxilPayloadAnnotation *payloadAnnotation = nullptr;
1229-
if (ValidatePayloadDecl(RT->getDecl(), *m_pHLModule->GetShaderModel(), CGM.getDiags(), CGM.getCodeGenOpts()))
1230-
payloadAnnotation = dxilTypeSys.AddPayloadAnnotation(ST);
1231-
unsigned size = ConstructStructAnnotation(annotation, payloadAnnotation, RD, dxilTypeSys);
1232-
// Resources don't count towards cbuffer size.
1233-
return IsHLSLResourceType(Ty) ? 0 : size;
1234-
} else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
1218+
} else if (const RecordType *RT = paramTy->getAs<RecordType>()) {
12351219
// For this pointer.
12361220
RecordDecl *RD = RT->getDecl();
12371221
llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
@@ -1245,7 +1229,9 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
12451229
DxilPayloadAnnotation* payloadAnnotation = nullptr;
12461230
if (ValidatePayloadDecl(RT->getDecl(), *m_pHLModule->GetShaderModel(), CGM.getDiags(), CGM.getCodeGenOpts()))
12471231
payloadAnnotation = dxilTypeSys.AddPayloadAnnotation(ST);
1248-
return ConstructStructAnnotation(annotation, payloadAnnotation, RD, dxilTypeSys);
1232+
unsigned size = ConstructStructAnnotation(annotation, payloadAnnotation, RD, dxilTypeSys);
1233+
// Resources don't count towards cbuffer size.
1234+
return IsHLSLResourceType(Ty) ? 0 : size;
12491235
} else if (IsStringType(Ty)) {
12501236
// string won't be included in cbuffer
12511237
return 0;
@@ -3177,10 +3163,7 @@ static void CollectScalarTypes(std::vector<QualType> &ScalarTys, QualType Ty) {
31773163
CollectScalarTypes(ScalarTys, EltTy);
31783164
}
31793165
} else {
3180-
const RecordType *RT = Ty->getAsStructureType();
3181-
// For CXXRecord.
3182-
if (!RT)
3183-
RT = Ty->getAs<RecordType>();
3166+
const RecordType *RT = Ty->getAs<RecordType>();
31843167
RecordDecl *RD = RT->getDecl();
31853168
for (FieldDecl *field : RD->fields())
31863169
CollectScalarTypes(ScalarTys, field->getType());
@@ -3994,7 +3977,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
39943977
elts.emplace_back(Builder.CreateLoad(val));
39953978
eltTys.emplace_back(Ty);
39963979
} else {
3997-
RecordDecl *RD = Ty->getAsStructureType()->getDecl();
3980+
const RecordDecl *RD = Ty->getAs<RecordType>()->getDecl();
39983981
const CGRecordLayout& RL = CGF.getTypes().getCGRecordLayout(RD);
39993982

40003983
// Take care base.
@@ -4124,10 +4107,7 @@ static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVect
41244107
// Skip hlsl object.
41254108
idx++;
41264109
} else {
4127-
const RecordType *RT = Ty->getAsStructureType();
4128-
// For CXXRecord.
4129-
if (!RT)
4130-
RT = Ty->getAs<RecordType>();
4110+
const RecordType *RT = Ty->getAs<RecordType>();
41314111
RecordDecl *RD = RT->getDecl();
41324112
// Take care base.
41334113
if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
@@ -4210,10 +4190,7 @@ static void StoreInitListToDestPtr(Value *DestPtr,
42104190
} else {
42114191
Constant *zero = Builder.getInt32(0);
42124192

4213-
const RecordType *RT = Type->getAsStructureType();
4214-
// For CXXRecord.
4215-
if (!RT)
4216-
RT = Type->getAs<RecordType>();
4193+
const RecordType *RT = Type->getAs<RecordType>();
42174194
RecordDecl *RD = RT->getDecl();
42184195
const CGRecordLayout &RL = Types.getCGRecordLayout(RD);
42194196
// Take care base.
@@ -5353,7 +5330,7 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
53535330
EltTyList.push_back(Type);
53545331
return;
53555332
}
5356-
const clang::RecordType *RT = Type->getAsStructureType();
5333+
const clang::RecordType *RT = Type->getAs<RecordType>();
53575334
RecordDecl *RD = RT->getDecl();
53585335

53595336
const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
@@ -5802,7 +5779,7 @@ void CGMSHLSLRuntime::EmitHLSLSplat(
58025779
} else if (StructType *ST = dyn_cast<StructType>(Ty)) {
58035780
DXASSERT(!dxilutil::IsHLSLObjectType(ST), "cannot cast to hlsl object, Sema should reject");
58045781

5805-
const clang::RecordType *RT = Type->getAsStructureType();
5782+
const clang::RecordType *RT = Type->getAs<RecordType>();
58065783
RecordDecl *RD = RT->getDecl();
58075784

58085785
const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
2+
// RUN: %dxc -E main -T ps_6_0 -HV 2021 %s | FileCheck %s -check-prefix=ERROR
3+
4+
// CHECK: define void @main()
5+
// CHECK: %[[H:[^ ]+]] = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %{{[^,]+}}, i32 0)
6+
// CHECK: %[[f:[^ ]+]] = extractvalue %dx.types.CBufRet.f32 %[[H]], 0
7+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %[[f]])
8+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %[[f]])
9+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %[[f]])
10+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %[[f]])
11+
12+
// ERROR: error: no matching function for call to 'badCall'
13+
// ERROR: note: candidate function not viable: no known conversion from 'A' to 'B' for 1st argument
14+
15+
class A {
16+
float f;
17+
int i;
18+
};
19+
class B {
20+
float f;
21+
int i;
22+
};
23+
24+
float4 badCall(B data) {
25+
return (float4)data.f;
26+
}
27+
28+
A g_dnc;
29+
30+
float4 main() : SV_Target {
31+
A dnc = g_dnc;
32+
return badCall(dnc);
33+
}

0 commit comments

Comments
 (0)