-
Notifications
You must be signed in to change notification settings - Fork 796
[clang][SYCL] Allow structs as free function kernel arguments #15334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
80491c3
3fbb50c
3e5332e
c8bf832
95cfe15
6d8a28f
7b17d82
efcfc47
4e293da
34524f4
dd52ed5
f9a1641
ef3ca91
431a405
0b88367
ce98062
9885f5a
2eb3d61
9622e3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1263,6 +1263,12 @@ class KernelObjVisitor { | |||||||||||||||||||
| std::initializer_list<int>{(result = result && tn(BD, BDTy), 0)...}; | ||||||||||||||||||||
| return result; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| template <typename... Tn> | ||||||||||||||||||||
| bool handleField(ParmVarDecl *PD, QualType PDTy, Tn &&...tn) { | ||||||||||||||||||||
| bool result = true; | ||||||||||||||||||||
| std::initializer_list<int>{(result = result && tn(PD, PDTy), 0)...}; | ||||||||||||||||||||
| return result; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // This definition using std::bind is necessary because of a gcc 7.x bug. | ||||||||||||||||||||
| #define KF_FOR_EACH(FUNC, Item, Qt) \ | ||||||||||||||||||||
|
|
@@ -1443,9 +1449,12 @@ class KernelObjVisitor { | |||||||||||||||||||
| HandlerTys &...Handlers) { | ||||||||||||||||||||
| if (isSyclSpecialType(ParamTy, SemaSYCLRef)) | ||||||||||||||||||||
| KP_FOR_EACH(handleOtherType, Param, ParamTy); | ||||||||||||||||||||
| else if (ParamTy->isStructureOrClassType()) | ||||||||||||||||||||
| KP_FOR_EACH(handleOtherType, Param, ParamTy); | ||||||||||||||||||||
| else if (ParamTy->isUnionType()) | ||||||||||||||||||||
| else if (ParamTy->isStructureOrClassType()) { | ||||||||||||||||||||
| if (KF_FOR_EACH(handleStructType, Param, ParamTy)) { | ||||||||||||||||||||
| CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); | ||||||||||||||||||||
| visitRecord(RD, Param, RD, ParamTy, Handlers...); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } else if (ParamTy->isUnionType()) | ||||||||||||||||||||
| KP_FOR_EACH(handleOtherType, Param, ParamTy); | ||||||||||||||||||||
| else if (ParamTy->isReferenceType()) | ||||||||||||||||||||
| KP_FOR_EACH(handleOtherType, Param, ParamTy); | ||||||||||||||||||||
|
|
@@ -1957,8 +1966,25 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| bool handleStructType(ParmVarDecl *PD, QualType ParamTy) final { | ||||||||||||||||||||
| Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy; | ||||||||||||||||||||
| IsInvalid = true; | ||||||||||||||||||||
| CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); | ||||||||||||||||||||
| // For free functions all struct/class kernel arguments are forward declared | ||||||||||||||||||||
| // in integration header, that adds additional restrictions for kernel | ||||||||||||||||||||
| // arguments. | ||||||||||||||||||||
| // Lambdas are not forward declarable. So, diagnose them properly. | ||||||||||||||||||||
| if (RD->isLambda()) { | ||||||||||||||||||||
| Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) | ||||||||||||||||||||
| << ParamTy; | ||||||||||||||||||||
| IsInvalid = true; | ||||||||||||||||||||
| return isValid(); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Check that the type is defined at namespace scope. | ||||||||||||||||||||
| const DeclContext *DeclCtx = RD->getDeclContext(); | ||||||||||||||||||||
| if (!DeclCtx->isTranslationUnit() && !isa<NamespaceDecl>(DeclCtx)) { | ||||||||||||||||||||
| Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) | ||||||||||||||||||||
| << ParamTy; | ||||||||||||||||||||
| IsInvalid = true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
||||||||||||||||||||
| return isValid(); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -2037,14 +2063,16 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| // TODO manipulate struct depth once special types are supported for free | ||||||||||||||||||||
| // function kernels. | ||||||||||||||||||||
| // ++StructFieldDepth; | ||||||||||||||||||||
|
Comment on lines
+2137
to
+2139
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we at least diagnose cases that involve SYCL special types now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIRC this is already diagnosed by calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually maybe not. It is not obvious to me whether we will hit that code since we aren't decomposing yet. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, we didn't diagnose because we don't decompose yet. I added diagnosing. |
||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| // TODO manipulate struct depth once special types are supported for free | ||||||||||||||||||||
| // function kernels. | ||||||||||||||||||||
| // --StructFieldDepth; | ||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -2162,8 +2190,7 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { | |||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| bool handlePointerType(ParmVarDecl *, QualType) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| PointerStack.back() = targetRequiresNewType(SemaSYCLRef.getASTContext()); | ||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -2194,8 +2221,10 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { | |||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| bool enterStruct(const CXXRecordDecl *, ParmVarDecl *, QualType) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| // TODO handle decomposition once special type arguments are supported | ||||||||||||||||||||
| // for free function kernels. | ||||||||||||||||||||
| // CollectionStack.push_back(false); | ||||||||||||||||||||
| PointerStack.push_back(false); | ||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -2221,10 +2250,24 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler { | |||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| bool leaveStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, | ||||||||||||||||||||
| bool leaveStruct(const CXXRecordDecl *, ParmVarDecl *, | ||||||||||||||||||||
| QualType ParamTy) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl(); | ||||||||||||||||||||
| assert(RD && "should not be null."); | ||||||||||||||||||||
| // TODO handle decomposition once special type arguments are supported | ||||||||||||||||||||
| // for free function kernels. | ||||||||||||||||||||
| // if (CollectionStack.pop_back_val()) { | ||||||||||||||||||||
| // if (!RD->hasAttr<SYCLRequiresDecompositionAttr>()) | ||||||||||||||||||||
| // RD->addAttr(SYCLRequiresDecompositionAttr::CreateImplicit( | ||||||||||||||||||||
| // SemaSYCLRef.getASTContext())); | ||||||||||||||||||||
| // CollectionStack.back() = true; | ||||||||||||||||||||
| // PointerStack.pop_back(); | ||||||||||||||||||||
| if (PointerStack.pop_back_val()) { | ||||||||||||||||||||
| PointerStack.back() = true; | ||||||||||||||||||||
| if (!RD->hasAttr<SYCLGenerateNewTypeAttr>()) | ||||||||||||||||||||
| RD->addAttr(SYCLGenerateNewTypeAttr::CreateImplicit( | ||||||||||||||||||||
| SemaSYCLRef.getASTContext())); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -2974,8 +3017,15 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { | |||||||||||||||||||
|
|
||||||||||||||||||||
| bool handleNonDecompStruct(const CXXRecordDecl *RD, ParmVarDecl *PD, | ||||||||||||||||||||
| QualType ParamTy) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| // This is a field which should not be decomposed. | ||||||||||||||||||||
| CXXRecordDecl *FieldRecordDecl = ParamTy->getAsCXXRecordDecl(); | ||||||||||||||||||||
|
||||||||||||||||||||
| assert(FieldRecordDecl && "Type must be a C++ record type"); | ||||||||||||||||||||
| // Check if we need to generate a new type for this record, | ||||||||||||||||||||
| // i.e. this record contains pointers. | ||||||||||||||||||||
| if (FieldRecordDecl->hasAttr<SYCLGenerateNewTypeAttr>()) | ||||||||||||||||||||
| addParam(PD, GenerateNewRecordType(FieldRecordDecl)); | ||||||||||||||||||||
| else | ||||||||||||||||||||
| addParam(PD, ParamTy); | ||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -3203,8 +3253,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler { | |||||||||||||||||||
|
|
||||||||||||||||||||
| bool handleNonDecompStruct(const CXXRecordDecl *, ParmVarDecl *, | ||||||||||||||||||||
| QualType ParamTy) final { | ||||||||||||||||||||
| // TODO | ||||||||||||||||||||
| unsupportedFreeFunctionParamType(); | ||||||||||||||||||||
| addParam(ParamTy); | ||||||||||||||||||||
| return true; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -4194,7 +4243,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { | |||||||||||||||||||
|
|
||||||||||||||||||||
| // Creates a DeclRefExpr to the ParmVar that represents the current pointer | ||||||||||||||||||||
| // parameter. | ||||||||||||||||||||
| Expr *createPointerParamReferenceExpr(QualType PointerTy, bool Wrapped) { | ||||||||||||||||||||
| Expr *createPointerParamReferenceExpr(QualType PointerTy) { | ||||||||||||||||||||
| ParmVarDecl *FreeFunctionParameter = | ||||||||||||||||||||
| DeclCreator.getParamVarDeclsForCurrentField()[0]; | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -4212,6 +4261,50 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler { | |||||||||||||||||||
| return DRE; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Expr *createGetAddressOf(Expr *E) { | ||||||||||||||||||||
| return UnaryOperator::Create( | ||||||||||||||||||||
| SemaSYCLRef.getASTContext(), E, UO_AddrOf, | ||||||||||||||||||||
| SemaSYCLRef.getASTContext().getPointerType(E->getType()), VK_PRValue, | ||||||||||||||||||||
| OK_Ordinary, SourceLocation(), false, | ||||||||||||||||||||
| SemaSYCLRef.SemaRef.CurFPFeatureOverrides()); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Expr *createDerefOp(Expr *E) { | ||||||||||||||||||||
| return UnaryOperator::Create(SemaSYCLRef.getASTContext(), E, UO_Deref, | ||||||||||||||||||||
| E->getType()->getPointeeType(), VK_LValue, | ||||||||||||||||||||
| OK_Ordinary, SourceLocation(), false, | ||||||||||||||||||||
| SemaSYCLRef.SemaRef.CurFPFeatureOverrides()); | ||||||||||||||||||||
|
Comment on lines
+4355
to
+4358
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For consistency with surrounding functions. Clang-format might have other ideas though.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clang-format doesn't agree. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. clang-format makes bad choices sometimes :) |
||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Expr *createReinterpretCastExpr(Expr *E, QualType To) { | ||||||||||||||||||||
| return CXXReinterpretCastExpr::Create( | ||||||||||||||||||||
| SemaSYCLRef.getASTContext(), To, VK_PRValue, CK_BitCast, E, | ||||||||||||||||||||
| /*Path=*/nullptr, | ||||||||||||||||||||
| SemaSYCLRef.getASTContext().getTrivialTypeSourceInfo(To), | ||||||||||||||||||||
| SourceLocation(), SourceLocation(), SourceRange()); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Expr *createStructTemporary(ParmVarDecl *OrigFunctionParameter) { | ||||||||||||||||||||
|
||||||||||||||||||||
| Expr *createStructTemporary(ParmVarDecl *OrigFunctionParameter) { | |
| Expr *createCopyInitExpr(ParmVarDecl *OrigFunctionParameter) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passing
RDfor both the 1st and 3rd arguments seems surprising here. The situation doesn't seem quite analogous tovisitField()above. I'm having a difficult time figuring out exactly whatvisitRecord()is actually intending to do; the owner/wrapper distinction seems weird to me.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, a lot of code around the visitors is intended to have some "owner" because the original use case is lambda/functor whose fields need to be visited. However it doesn't seem to affect things anyhow and I suspect not having it is fine. Also this comment
llvm/clang/lib/Sema/SemaSYCL.cpp
Line 1301 in 729d6f6
suggests so.
I transformed this argument to
nullptrto avoid confusion.