Skip to content

Commit 30a7579

Browse files
[spirv] Fixes vk::BufferPointer constructor expression construction. (microsoft#7331)
Constructors are now properly attached to the template class declaration instead of a specialization. Closes microsoft#6489 (again). --------- Co-authored-by: Nathan Gauër <[email protected]>
1 parent 47e11af commit 30a7579

File tree

3 files changed

+102
-23
lines changed

3 files changed

+102
-23
lines changed

tools/clang/lib/AST/ASTContextHLSL.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,19 +1390,27 @@ CXXRecordDecl *hlsl::DeclareVkBufferPointerType(ASTContext &context,
13901390
DeclarationName(&context.Idents.get("Get")), true);
13911391
CanQualType canQualType =
13921392
recordDecl->getTypeForDecl()->getCanonicalTypeUnqualified();
1393-
CreateConstructorDeclarationWithParams(
1393+
auto *copyConstructorDecl = CreateConstructorDeclarationWithParams(
13941394
context, recordDecl, context.VoidTy,
13951395
{context.getRValueReferenceType(canQualType)}, {"bufferPointer"},
1396-
context.DeclarationNames.getCXXConstructorName(canQualType), false);
1397-
CreateConstructorDeclarationWithParams(
1396+
context.DeclarationNames.getCXXConstructorName(canQualType), false, true);
1397+
auto *addressConstructorDecl = CreateConstructorDeclarationWithParams(
13981398
context, recordDecl, context.VoidTy, {context.UnsignedIntTy}, {"address"},
1399-
context.DeclarationNames.getCXXConstructorName(canQualType), false);
1399+
context.DeclarationNames.getCXXConstructorName(canQualType), false, true);
1400+
hlsl::CreateFunctionTemplateDecl(
1401+
context, recordDecl, copyConstructorDecl,
1402+
Builder.getTemplateDecl()->getTemplateParameters()->begin(), 2);
1403+
hlsl::CreateFunctionTemplateDecl(
1404+
context, recordDecl, addressConstructorDecl,
1405+
Builder.getTemplateDecl()->getTemplateParameters()->begin(), 2);
14001406

14011407
StringRef OpcodeGroup = GetHLOpcodeGroupName(HLOpcodeGroup::HLIntrinsic);
14021408
unsigned Opcode = static_cast<unsigned>(IntrinsicOp::MOP_GetBufferContents);
14031409
methodDecl->addAttr(
14041410
HLSLIntrinsicAttr::CreateImplicit(context, OpcodeGroup, "", Opcode));
14051411
methodDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));
1412+
copyConstructorDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));
1413+
addressConstructorDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));
14061414

14071415
return Builder.completeDefinition();
14081416
}

tools/clang/lib/Sema/SemaExprCXX.cpp

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,26 +1057,51 @@ Sema::BuildCXXTypeConstructExpr(TypeSourceInfo *TInfo,
10571057
Expr *Arg = Exprs[0];
10581058
#ifdef ENABLE_SPIRV_CODEGEN
10591059
if (hlsl::IsVKBufferPointerType(Ty) && Arg->getType()->isIntegerType()) {
1060-
for (auto *ctor : Ty->getAsCXXRecordDecl()->ctors()) {
1061-
if (auto *functionType = ctor->getType()->getAs<FunctionProtoType>()) {
1062-
if (functionType->getNumParams() != 1 ||
1063-
!functionType->getParamType(0)->isIntegerType())
1064-
continue;
1065-
1066-
CanQualType argType = Arg->getType()->getCanonicalTypeUnqualified();
1067-
if (!Arg->isRValue()) {
1068-
Arg = ImpCastExprToType(Arg, argType, CK_LValueToRValue).get();
1069-
}
1070-
if (argType != Context.UnsignedLongLongTy) {
1071-
Arg = ImpCastExprToType(Arg, Context.UnsignedLongLongTy,
1072-
CK_IntegralCast)
1073-
.get();
1074-
}
1075-
return CXXConstructExpr::Create(
1076-
Context, Ty, TyBeginLoc, ctor, false, {Arg}, false, false, false,
1077-
false, CXXConstructExpr::ConstructionKind::CK_Complete,
1078-
SourceRange(LParenLoc, RParenLoc));
1060+
typedef DeclContext::specific_decl_iterator<FunctionTemplateDecl> ft_iter;
1061+
auto *recordDecl = Ty->getAsCXXRecordDecl();
1062+
auto *specDecl = cast<ClassTemplateSpecializationDecl>(recordDecl);
1063+
auto *templatedDecl =
1064+
specDecl->getSpecializedTemplate()->getTemplatedDecl();
1065+
auto functionTemplateDecls =
1066+
llvm::iterator_range<ft_iter>(ft_iter(templatedDecl->decls_begin()),
1067+
ft_iter(templatedDecl->decls_end()));
1068+
for (auto *ftd : functionTemplateDecls) {
1069+
auto *fd = ftd->getTemplatedDecl();
1070+
if (fd->getNumParams() != 1 ||
1071+
!fd->getParamDecl(0)->getType()->isIntegerType())
1072+
continue;
1073+
1074+
void *insertPos;
1075+
auto templateArgs = ftd->getInjectedTemplateArgs();
1076+
auto *functionDecl = ftd->findSpecialization(templateArgs, insertPos);
1077+
if (!functionDecl) {
1078+
DeclarationNameInfo DInfo(ftd->getDeclName(),
1079+
recordDecl->getLocation());
1080+
auto *templateArgList = TemplateArgumentList::CreateCopy(
1081+
Context, templateArgs.data(), templateArgs.size());
1082+
functionDecl = CXXConstructorDecl::Create(
1083+
Context, recordDecl, Arg->getLocStart(), DInfo, Ty, TInfo, false,
1084+
false, false, false);
1085+
functionDecl->setFunctionTemplateSpecialization(ftd, templateArgList,
1086+
insertPos);
1087+
} else if (functionDecl->getDeclKind() != Decl::Kind::CXXConstructor) {
1088+
continue;
1089+
}
1090+
1091+
CanQualType argType = Arg->getType()->getCanonicalTypeUnqualified();
1092+
if (!Arg->isRValue()) {
1093+
Arg = ImpCastExprToType(Arg, argType, CK_LValueToRValue).get();
1094+
}
1095+
if (argType != Context.UnsignedLongLongTy) {
1096+
Arg = ImpCastExprToType(Arg, Context.UnsignedLongLongTy,
1097+
CK_IntegralCast)
1098+
.get();
10791099
}
1100+
return CXXConstructExpr::Create(
1101+
Context, Ty, TyBeginLoc, cast<CXXConstructorDecl>(functionDecl),
1102+
false, {Arg}, false, false, false, false,
1103+
CXXConstructExpr::ConstructionKind::CK_Complete,
1104+
SourceRange(LParenLoc, RParenLoc));
10801105
}
10811106
}
10821107
#endif
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %dxc -spirv -Od -T cs_6_7 %s | FileCheck %s
2+
// RUN: %dxc -spirv -Od -T cs_6_7 -DALIGN_16 %s | FileCheck %s
3+
// RUN: %dxc -spirv -Od -T cs_6_7 -DNO_PC %s | FileCheck %s
4+
5+
// Was getting bogus type errors with the defined changes
6+
7+
#ifdef ALIGN_16
8+
typedef vk::BufferPointer<uint, 16> BufferType;
9+
#else
10+
typedef vk::BufferPointer<uint, 32> BufferType;
11+
#endif
12+
#ifndef NO_PC
13+
struct PushConstantStruct {
14+
BufferType push_buffer;
15+
};
16+
[[vk::push_constant]] PushConstantStruct push_constant;
17+
#endif
18+
19+
RWStructuredBuffer<uint> output;
20+
21+
// CHECK: [[INT:%[_0-9A-Za-z]*]] = OpTypeInt 32 1
22+
// CHECK: [[I0:%[_0-9A-Za-z]*]] = OpConstant [[INT]] 0
23+
// CHECK: [[UINT:%[_0-9A-Za-z]*]] = OpTypeInt 32 0
24+
// CHECK: [[U0:%[_0-9A-Za-z]*]] = OpConstant [[UINT]] 0
25+
// CHECK: [[PPUINT:%[_0-9A-Za-z]*]] = OpTypePointer PhysicalStorageBuffer [[UINT]]
26+
// CHECK: [[PFPPUINT:%[_0-9A-Za-z]*]] = OpTypePointer Function [[PPUINT]]
27+
// CHECK: [[PUUINT:%[_0-9A-Za-z]*]] = OpTypePointer Uniform [[UINT]]
28+
// CHECK: [[OUTPUT:%[_0-9A-Za-z]*]] = OpVariable %{{[_0-9A-Za-z]*}} Uniform
29+
30+
[numthreads(1, 1, 1)]
31+
void main() {
32+
uint64_t addr = 123;
33+
vk::BufferPointer<uint, 32> test = vk::BufferPointer<uint, 32>(addr);
34+
output[0] = test.Get();
35+
}
36+
37+
// CHECK: [[TEST:%[_0-9A-Za-z]*]] = OpVariable [[PFPPUINT]] Function
38+
// CHECK: [[X1:%[_0-9A-Za-z]*]] = OpConvertUToPtr [[PPUINT]]
39+
// CHECK: OpStore [[TEST]] [[X1]]
40+
// CHECK: [[X2:%[_0-9A-Za-z]*]] = OpLoad [[PPUINT]] [[TEST]] Aligned 32
41+
// CHECK: [[X3:%[_0-9A-Za-z]*]] = OpLoad [[UINT]] [[X2]] Aligned 4
42+
// CHECK: [[X4:%[_0-9A-Za-z]*]] = OpAccessChain [[PUUINT]] [[OUTPUT]] [[I0]] [[U0]]
43+
// CHECK: OpStore [[X4]] [[X3]]
44+
// CHECK: OpReturn
45+
// CHECK: OpFunctionEnd
46+

0 commit comments

Comments
 (0)