-
Notifications
You must be signed in to change notification settings - Fork 795
[SYCL] Add support for work group memory free function kernel parameter #15861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 118 commits
652caa8
76daf77
21e082b
025cbc4
b94f7c9
0d6d694
448071f
9f2973a
852315f
4234022
ae5eb7e
8ce0280
4ee31a5
7b1b90b
cf7476e
50c0954
44811b8
ad1046f
d343a2e
3f1bc30
103e233
bfa5830
2031478
76f0acc
e0ad435
3513251
2cec997
4c8b196
a0b70e2
ed8f125
9460876
ac7130a
e2889b3
d6c78b9
8f7a07b
ae59899
8cff603
3ceead1
3228aeb
0e95ee5
52f13f0
71d1013
d48bc42
f6515bc
3e4c73c
c84229e
d2fddd8
0f677c2
6ef823e
6dc262a
4de6d50
5653f04
40eb63e
f6a0df7
026501c
2ce21b3
a9b2875
3821df4
d73b0b1
c1087ad
31481b8
396169f
dc37b2c
236139f
2beda8e
dbafe31
1b968df
e6b66c3
84ef6a8
f24af09
7dfa80b
3957cb5
91820d8
34bc23d
3acf835
604c640
5510208
d9418f9
e90a3b7
3b9a55a
b9ed6f4
af08c19
b2a97a2
77a6de1
3cb0ba4
1783f75
5a6085f
6affbc3
ed3c60f
fd89473
24f87b0
3f8ced8
084eb7d
06a3c37
d96f2e1
a2b44a2
494870a
13dd2b6
4be6df0
42edb22
6bc951e
5a91bfa
6f356ea
ca4b228
98186a5
377945b
c00c1f7
5cb2d21
f4b9b1a
c3426d7
7a6516c
872f671
805b00f
c3b494c
bff19f4
9bab966
69388f0
740f389
0fa7af4
883549a
0997c60
535890a
dcff700
4b2a148
3578639
54ab379
336c90e
5bd6e3b
ec91887
f08478e
3d326a3
dab03c5
ca4f2f3
299ecde
d35ebb5
e3899f7
7b3e5e3
71ee15f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1522,7 +1522,7 @@ class KernelObjVisitor { | |
| void visitParam(ParmVarDecl *Param, QualType ParamTy, | ||
| HandlerTys &...Handlers) { | ||
| if (isSyclSpecialType(ParamTy, SemaSYCLRef)) | ||
| KP_FOR_EACH(handleOtherType, Param, ParamTy); | ||
| KP_FOR_EACH(handleSyclSpecialType, Param, ParamTy); | ||
| else if (ParamTy->isStructureOrClassType()) { | ||
| if (KP_FOR_EACH(handleStructType, Param, ParamTy)) { | ||
| CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); | ||
|
|
@@ -2070,8 +2070,11 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |
| } | ||
|
|
||
| bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { | ||
| Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy; | ||
| IsInvalid = true; | ||
| if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) { | ||
| Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) | ||
| << ParamTy; | ||
| IsInvalid = true; | ||
| } | ||
| return isValid(); | ||
| } | ||
|
|
||
|
|
@@ -2223,8 +2226,8 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler { | |
| } | ||
|
|
||
| bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { | ||
| // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) | ||
| unsupportedFreeFunctionParamType(); // TODO | ||
| return true; | ||
| } | ||
|
|
||
|
|
@@ -3008,9 +3011,26 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { | |
| return handleSpecialType(FD, FieldTy); | ||
| } | ||
|
|
||
| bool handleSyclSpecialType(ParmVarDecl *, QualType) final { | ||
| // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { | ||
| if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) { | ||
| const auto *RecordDecl = ParamTy->getAsCXXRecordDecl(); | ||
| assert(RecordDecl && "The type must be a RecordDecl"); | ||
| CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName); | ||
| assert(InitMethod && "The type must have the __init method"); | ||
| // Don't do -1 here because we count on this to be the first parameter | ||
| // added (if any). | ||
| size_t ParamIndex = Params.size(); | ||
| for (const ParmVarDecl *Param : InitMethod->parameters()) { | ||
| QualType ParamTy = Param->getType(); | ||
| addParam(Param, ParamTy.getCanonicalType()); | ||
| // Propagate add_ir_attributes_kernel_parameter attribute. | ||
| if (const auto *AddIRAttr = | ||
| Param->getAttr<SYCLAddIRAttributesKernelParameterAttr>()) | ||
| Params.back()->addAttr(AddIRAttr->clone(SemaSYCLRef.getASTContext())); | ||
| } | ||
| LastParamIndex = ParamIndex; | ||
| } else // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| return true; | ||
| } | ||
|
|
||
|
|
@@ -3286,9 +3306,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { | |
| } | ||
|
|
||
| bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { | ||
| // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| return true; | ||
| return handleSpecialType(ParamTy); | ||
| } | ||
|
|
||
| bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS, | ||
|
|
@@ -4416,6 +4434,45 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { | |
| {}); | ||
| } | ||
|
|
||
| MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) { | ||
| DeclAccessPair MemberDAP = DeclAccessPair::make(Member, AS_none); | ||
| MemberExpr *Result = SemaSYCLRef.SemaRef.BuildMemberExpr( | ||
| Base, /*IsArrow */ false, FreeFunctionSrcLoc, NestedNameSpecifierLoc(), | ||
| FreeFunctionSrcLoc, Member, MemberDAP, | ||
| /*HadMultipleCandidates*/ false, | ||
| DeclarationNameInfo(Member->getDeclName(), FreeFunctionSrcLoc), | ||
| Member->getType(), VK_LValue, OK_Ordinary); | ||
| return Result; | ||
| } | ||
|
|
||
| void createSpecialMethodCall(const CXXRecordDecl *RD, StringRef MethodName, | ||
| Expr *MemberBaseExpr, | ||
| SmallVectorImpl<Stmt *> &AddTo) { | ||
| CXXMethodDecl *Method = getMethodByName(RD, MethodName); | ||
| if (!Method) | ||
| return; | ||
| unsigned NumParams = Method->getNumParams(); | ||
| llvm::SmallVector<Expr *, 4> ParamDREs(NumParams); | ||
| llvm::ArrayRef<ParmVarDecl *> KernelParameters = | ||
| DeclCreator.getParamVarDeclsForCurrentField(); | ||
| for (size_t I = 0; I < NumParams; ++I) { | ||
| QualType ParamType = KernelParameters[I]->getOriginalType(); | ||
| ParamDREs[I] = SemaSYCLRef.SemaRef.BuildDeclRefExpr( | ||
| KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc); | ||
| } | ||
| MemberExpr *MethodME = buildMemberExpr(MemberBaseExpr, Method); | ||
| QualType ResultTy = Method->getReturnType(); | ||
| ExprValueKind VK = Expr::getValueKindForType(ResultTy); | ||
| ResultTy = ResultTy.getNonLValueExprType(SemaSYCLRef.getASTContext()); | ||
| llvm::SmallVector<Expr *, 4> ParamStmts; | ||
| const auto *Proto = cast<FunctionProtoType>(Method->getType()); | ||
| SemaSYCLRef.SemaRef.GatherArgumentsForCall(FreeFunctionSrcLoc, Method, | ||
| Proto, 0, ParamDREs, ParamStmts); | ||
| AddTo.push_back(CXXMemberCallExpr::Create( | ||
| SemaSYCLRef.getASTContext(), MethodME, ParamStmts, ResultTy, VK, | ||
| FreeFunctionSrcLoc, FPOptionsOverride())); | ||
| } | ||
|
|
||
| public: | ||
| static constexpr const bool VisitInsideSimpleContainers = false; | ||
|
|
||
|
|
@@ -4435,9 +4492,37 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { | |
| return true; | ||
| } | ||
|
|
||
| bool handleSyclSpecialType(ParmVarDecl *, QualType) final { | ||
| // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { | ||
| if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) { | ||
| const auto *RecordDecl = ParamTy->getAsCXXRecordDecl(); | ||
| QualType Ty = PD->getOriginalType(); | ||
| ASTContext &Ctx = SemaSYCLRef.SemaRef.getASTContext(); | ||
| VarDecl *WorkGroupMemoryClone = VarDecl::Create( | ||
| Ctx, DeclCreator.getKernelDecl(), FreeFunctionSrcLoc, | ||
| FreeFunctionSrcLoc, PD->getIdentifier(), PD->getType(), | ||
| Ctx.getTrivialTypeSourceInfo(Ty), SC_None); | ||
| InitializedEntity VarEntity = | ||
| InitializedEntity::InitializeVariable(WorkGroupMemoryClone); | ||
| InitializationKind InitKind = | ||
| InitializationKind::CreateDefault(FreeFunctionSrcLoc); | ||
| InitializationSequence InitSeq(SemaSYCLRef.SemaRef, VarEntity, InitKind, | ||
| std::nullopt); | ||
| ExprResult Init = InitSeq.Perform(SemaSYCLRef.SemaRef, VarEntity, | ||
| InitKind, std::nullopt); | ||
| WorkGroupMemoryClone->setInit( | ||
| SemaSYCLRef.SemaRef.MaybeCreateExprWithCleanups(Init.get())); | ||
| WorkGroupMemoryClone->setInitStyle(VarDecl::CallInit); | ||
| Stmt *DS = new (SemaSYCLRef.getASTContext()) | ||
| DeclStmt(DeclGroupRef(WorkGroupMemoryClone), FreeFunctionSrcLoc, | ||
| FreeFunctionSrcLoc); | ||
| BodyStmts.push_back(DS); | ||
| Expr *MemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr( | ||
| WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc); | ||
| createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr, | ||
| BodyStmts); | ||
| ArgExprs.push_back(MemberBaseExpr); | ||
| } else // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| return true; | ||
| } | ||
|
|
||
|
|
@@ -4717,9 +4802,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { | |
| return true; | ||
| } | ||
|
|
||
| bool handleSyclSpecialType(ParmVarDecl *, QualType) final { | ||
| // TODO | ||
| unsupportedFreeFunctionParamType(); | ||
| bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final { | ||
| if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) | ||
| addParam(PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory); | ||
| else | ||
| unsupportedFreeFunctionParamType(); // TODO | ||
| return true; | ||
| } | ||
|
|
||
|
|
@@ -6196,7 +6283,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |
| O << "#include <sycl/detail/defines_elementary.hpp>\n"; | ||
| O << "#include <sycl/detail/kernel_desc.hpp>\n"; | ||
| O << "#include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n"; | ||
|
|
||
| // When using work group memory parameters in free kernel functions, the | ||
|
||
| // integration header emits incorrect forward declarations for the work group | ||
| // memory type because it ignores default arguments. This means the user | ||
| // cannot use work group memory types with parameters omitted such as | ||
| // work_group_memory<int> where the hidden second parameter has a default | ||
| // value. To circumvent this, we include the correct forward declaration | ||
| // ourselves. | ||
| O << "#include <tuple>\n"; | ||
| O << "#include " | ||
| "<sycl/ext/oneapi/experimental/work_group_memory_forward_decl.hpp>\n"; | ||
| O << "\n"; | ||
|
|
||
| LangOptions LO; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| #pragma once | ||
| // Dummy header file for the purpose of SemaSYCL testing. | ||
| // It shadows the file | ||
| // sycl/include/sycl/ext/oneapi/experimental/work_group_memory_forward_decl.hpp | ||
| namespace sycl { | ||
| inline namespace _V1 { | ||
| namespace ext { | ||
| namespace oneapi { | ||
| namespace experimental { | ||
| template <typename DataT, typename PropertiesT = int> | ||
| class work_group_memory; | ||
| } // namespace experimental | ||
| } // namespace oneapi | ||
| } // namespace ext | ||
| } // namespace _V1 | ||
| } // namespace sycl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not a new issue but there is a significant overlap between code for kernel body and free function body creator. I wonder if this can be refactored (not in this PR since it is an orthogonal issue) so that we don't duplicate so much code. @Fznamznon can you weigh in here since you are implementing free function functionality now.