Skip to content

Commit 9c3cae8

Browse files
committed
Add special type parameter support for free function kernels
1 parent ff1c3b4 commit 9c3cae8

File tree

1 file changed

+43
-24
lines changed

1 file changed

+43
-24
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,12 +2084,8 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20842084
}
20852085

20862086
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
2087-
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
2088-
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
2089-
<< ParamTy;
2090-
IsInvalid = true;
2091-
}
2092-
return isValid();
2087+
IsInvalid |= checkSyclSpecialType(ParamTy, PD->getLocation());
2088+
return isValid();
20932089
}
20942090

20952091
bool handleArrayType(FieldDecl *FD, QualType FieldTy) final {
@@ -2240,9 +2236,7 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22402236
}
22412237

22422238
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
2243-
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
2244-
unsupportedFreeFunctionParamType(); // TODO
2245-
return true;
2239+
return checkType(PD->getLocation(), ParamTy);
22462240
}
22472241

22482242
bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -3026,7 +3020,6 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30263020
}
30273021

30283022
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
3029-
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
30303023
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
30313024
assert(RecordDecl && "The type must be a RecordDecl");
30323025
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
@@ -3043,8 +3036,6 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30433036
Params.back()->addAttr(AddIRAttr->clone(SemaSYCLRef.getASTContext()));
30443037
}
30453038
LastParamIndex = ParamIndex;
3046-
} else // TODO
3047-
unsupportedFreeFunctionParamType();
30483039
return true;
30493040
}
30503041

@@ -4538,7 +4529,6 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
45384529
// TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
45394530
// is closed.
45404531
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
4541-
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
45424532
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
45434533
AccessSpecifier DefaultConstructorAccess;
45444534
auto DefaultConstructor =
@@ -4549,34 +4539,32 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
45494539

45504540
QualType Ty = PD->getOriginalType();
45514541
ASTContext &Ctx = SemaSYCLRef.SemaRef.getASTContext();
4552-
VarDecl *WorkGroupMemoryClone = VarDecl::Create(
4542+
VarDecl *SpecialObjectClone = VarDecl::Create(
45534543
Ctx, DeclCreator.getKernelDecl(), FreeFunctionSrcLoc,
45544544
FreeFunctionSrcLoc, PD->getIdentifier(), PD->getType(),
45554545
Ctx.getTrivialTypeSourceInfo(Ty), SC_None);
45564546
InitializedEntity VarEntity =
4557-
InitializedEntity::InitializeVariable(WorkGroupMemoryClone);
4547+
InitializedEntity::InitializeVariable(SpecialObjectClone);
45584548
InitializationKind InitKind =
45594549
InitializationKind::CreateDefault(FreeFunctionSrcLoc);
45604550
InitializationSequence InitSeq(SemaSYCLRef.SemaRef, VarEntity, InitKind,
45614551
std::nullopt);
45624552
ExprResult Init = InitSeq.Perform(SemaSYCLRef.SemaRef, VarEntity,
45634553
InitKind, std::nullopt);
4564-
WorkGroupMemoryClone->setInit(
4554+
SpecialObjectClone->setInit(
45654555
SemaSYCLRef.SemaRef.MaybeCreateExprWithCleanups(Init.get()));
4566-
WorkGroupMemoryClone->setInitStyle(VarDecl::CallInit);
4556+
SpecialObjectClone->setInitStyle(VarDecl::CallInit);
45674557
DefaultConstructor->setAccess(DefaultConstructorAccess);
45684558

45694559
Stmt *DS = new (SemaSYCLRef.getASTContext())
4570-
DeclStmt(DeclGroupRef(WorkGroupMemoryClone), FreeFunctionSrcLoc,
4560+
DeclStmt(DeclGroupRef(SpecialObjectClone), FreeFunctionSrcLoc,
45714561
FreeFunctionSrcLoc);
45724562
BodyStmts.push_back(DS);
45734563
Expr *MemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
4574-
WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4564+
SpecialObjectClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
45754565
createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr,
45764566
BodyStmts);
45774567
ArgExprs.push_back(MemberBaseExpr);
4578-
} else // TODO
4579-
unsupportedFreeFunctionParamType();
45804568
return true;
45814569
}
45824570

@@ -4862,10 +4850,41 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48624850
}
48634851

48644852
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
4865-
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
4853+
const auto *ClassTy = ParamTy->getAsCXXRecordDecl();
4854+
assert(ClassTy && "Type must be a C++ record type");
4855+
if (isSyclAccessorType(ParamTy)) {
4856+
const auto *AccTy =
4857+
cast<ClassTemplateSpecializationDecl>(ParamTy->getAsRecordDecl());
4858+
assert(AccTy->getTemplateArgs().size() >= 2 &&
4859+
"Incorrect template args for Accessor Type");
4860+
int Dims = static_cast<int>(
4861+
AccTy->getTemplateArgs()[1].getAsIntegral().getExtValue());
4862+
int Info = getAccessTarget(ParamTy, AccTy) | (Dims << 11);
4863+
4864+
Header.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info,
4865+
CurOffset);
4866+
} else if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::stream)) {
4867+
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_stream);
4868+
} else if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
48664869
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4867-
else
4868-
unsupportedFreeFunctionParamType(); // TODO
4870+
} else if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::sampler) ||
4871+
SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::annotated_ptr) ||
4872+
SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::annotated_arg)) {
4873+
CXXMethodDecl *InitMethod = getMethodByName(ClassTy, InitMethodName);
4874+
assert(InitMethod && "type must have __init method");
4875+
const ParmVarDecl *InitArg = InitMethod->getParamDecl(0);
4876+
assert(InitArg && "Init method must have arguments");
4877+
QualType T = InitArg->getType();
4878+
SYCLIntegrationHeader::kernel_param_kind_t ParamKind =
4879+
SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::sampler)
4880+
? SYCLIntegrationHeader::kind_sampler
4881+
: (T->isPointerType() ? SYCLIntegrationHeader::kind_pointer
4882+
: SYCLIntegrationHeader::kind_std_layout);
4883+
addParam(PD, ParamTy, ParamKind);
4884+
} else {
4885+
llvm_unreachable(
4886+
"Unexpected SYCL special class when generating integration header");
4887+
}
48694888
return true;
48704889
}
48714890

0 commit comments

Comments
 (0)