Skip to content

Commit 607518b

Browse files
committed
[HLSL] Update type for out arguments on template instantiation only for dependent params types
Non-dependent argument types have already been converted to a reference and the template instantiation should not change that. Fixes #163648
1 parent 3f3af56 commit 607518b

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

clang/lib/Sema/SemaTemplateInstantiateDecl.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,18 @@ static bool isRelevantAttr(Sema &S, const Decl *D, const Attr *A) {
765765

766766
static void instantiateDependentHLSLParamModifierAttr(
767767
Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
768-
const HLSLParamModifierAttr *Attr, Decl *New) {
769-
ParmVarDecl *P = cast<ParmVarDecl>(New);
770-
P->addAttr(Attr->clone(S.getASTContext()));
771-
P->setType(S.HLSL().getInoutParameterType(P->getType()));
768+
const HLSLParamModifierAttr *Attr, const Decl *Old, Decl *New) {
769+
ParmVarDecl *NewParm = cast<ParmVarDecl>(New);
770+
NewParm->addAttr(Attr->clone(S.getASTContext()));
771+
772+
const Type *OldParmTy = cast<ParmVarDecl>(Old)->getType().getTypePtr();
773+
if (OldParmTy->isDependentType())
774+
NewParm->setType(S.HLSL().getInoutParameterType(NewParm->getType()));
775+
776+
assert(!Attr->isAnyOut() || (NewParm->getType().isRestrictQualified() &&
777+
NewParm->getType()->isReferenceType()) &&
778+
"out or inout parameter type must be a "
779+
"reference and restrict qualified");
772780
}
773781

774782
void Sema::InstantiateAttrsForDecl(
@@ -923,7 +931,7 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
923931

924932
if (const auto *ParamAttr = dyn_cast<HLSLParamModifierAttr>(TmplAttr)) {
925933
instantiateDependentHLSLParamModifierAttr(*this, TemplateArgs, ParamAttr,
926-
New);
934+
Tmpl, New);
927935
continue;
928936
}
929937

clang/test/SemaHLSL/Language/TemplateOutArg.hlsl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,44 @@ T buzz(int X, T Y) {
195195
return X + Y;
196196
}
197197

198+
// Case 4: Verify that the parameter modifier attributes are instantiated
199+
// for both templated and non-templated arguments, and that the non-templated
200+
// out argument type is not modified by the template instantiation.
201+
202+
// CHECK-LABEL: FunctionTemplateDecl {{.*}} fizz_two
203+
204+
// Check the pattern decl.
205+
// CHECK: FunctionDecl {{.*}} fizz_two 'void (inout T, out int)'
206+
// CHECK-NEXT: ParmVarDecl {{.*}} referenced V 'T'
207+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
208+
// CHECK-NEXT: ParmVarDecl {{.*}} referenced I 'int &__restrict'
209+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} out
210+
211+
// Check the 3 instantiations (int, float, & double).
212+
213+
// CHECK-LABEL: FunctionDecl {{.*}} used fizz_two 'void (inout int, out int)' implicit_instantiation
214+
// CHECK: ParmVarDecl {{.*}} used V 'int &__restrict'
215+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
216+
// CHECK: ParmVarDecl {{.*}} used I 'int &__restrict'
217+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} out
218+
219+
// CHECK-LABEL: FunctionDecl {{.*}} used fizz_two 'void (inout float, out int)' implicit_instantiation
220+
// CHECK: ParmVarDecl {{.*}} used V 'float &__restrict'
221+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
222+
// CHECK: ParmVarDecl {{.*}} used I 'int &__restrict'
223+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} out
224+
225+
// CHECK-LABEL: FunctionDecl {{.*}} used fizz_two 'void (inout double, out int)' implicit_instantiation
226+
// CHECK: ParmVarDecl {{.*}} used V 'double &__restrict'
227+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
228+
// CHECK: ParmVarDecl {{.*}} used I 'int &__restrict'
229+
// CHECK-NEXT: HLSLParamModifierAttr {{.*}} out
230+
template <typename T>
231+
void fizz_two(inout T V, out int I) {
232+
V += 2;
233+
I = V;
234+
}
235+
198236
export void caller() {
199237
int X = 2;
200238
float Y = 3.3;
@@ -211,4 +249,8 @@ export void caller() {
211249
X = buzz(X, X);
212250
Y = buzz(X, Y);
213251
Z = buzz(X, Z);
252+
253+
fizz_two(X, X);
254+
fizz_two(Y, X);
255+
fizz_two(Z, X);
214256
}

0 commit comments

Comments
 (0)