77// ===----------------------------------------------------------------------===//
88// This implements Semantic Analysis for SYCL constructs.
99// ===----------------------------------------------------------------------===//
10-
10+ # include < iostream >
1111#include " clang/Sema/SemaSYCL.h"
1212#include " TreeTransform.h"
1313#include " clang/AST/AST.h"
@@ -1521,8 +1521,9 @@ class KernelObjVisitor {
15211521 template <typename ... HandlerTys>
15221522 void visitParam (ParmVarDecl *Param, QualType ParamTy,
15231523 HandlerTys &...Handlers) {
1524- if (isSyclSpecialType (ParamTy, SemaSYCLRef))
1525- KP_FOR_EACH (handleOtherType, Param, ParamTy);
1524+ if (isSyclSpecialType (ParamTy, SemaSYCLRef)){
1525+ KP_FOR_EACH (handleSyclSpecialType, Param, ParamTy);
1526+ }
15261527 else if (ParamTy->isStructureOrClassType ()) {
15271528 if (KP_FOR_EACH (handleStructType, Param, ParamTy)) {
15281529 CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl ();
@@ -2070,8 +2071,28 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20702071 }
20712072
20722073 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2073- Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type) << ParamTy;
2074- IsInvalid = true ;
2074+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
2075+ CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl ();
2076+ // For free functions all struct/class kernel arguments are forward declared
2077+ // in integration header, that adds additional restrictions for kernel
2078+ // arguments.
2079+ NotForwardDeclarableReason NFDR =
2080+ isForwardDeclarable (RD, SemaSYCLRef, /* DiagForFreeFunction=*/ true );
2081+ if (NFDR != NotForwardDeclarableReason::None) {
2082+ Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2083+ << ParamTy;
2084+ Diag.Report (PD->getLocation (),
2085+ diag::note_free_function_kernel_param_type_not_fwd_declarable)
2086+ << ParamTy;
2087+ IsInvalid = true ;
2088+ } else
2089+ IsInvalid = false ;
2090+ }
2091+ else {
2092+ Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2093+ << ParamTy;
2094+ IsInvalid = true ;
2095+ }
20752096 return isValid ();
20762097 }
20772098
@@ -2224,6 +2245,7 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22242245
22252246 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
22262247 // TODO
2248+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) return true ;
22272249 unsupportedFreeFunctionParamType ();
22282250 return true ;
22292251 }
@@ -2262,7 +2284,7 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
22622284 return true ;
22632285 }
22642286
2265- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
2287+ bool handleSyclSpecialType (ParmVarDecl *, QualType Ty ) final {
22662288 // TODO We don't support special types in free function kernel parameters,
22672289 // but track them to diagnose the case properly.
22682290 CollectionStack.back () = true ;
@@ -3008,10 +3030,29 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30083030 return handleSpecialType (FD, FieldTy);
30093031 }
30103032
3011- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
3033+ bool handleSyclSpecialType (ParmVarDecl *PD , QualType ParamTy ) final {
30123034 // TODO
3013- unsupportedFreeFunctionParamType ();
3014- return true ;
3035+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
3036+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
3037+ assert (RecordDecl && " The type must be a RecordDecl" );
3038+ CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
3039+ assert (InitMethod && " The type must have the __init method" );
3040+
3041+ // Don't do -1 here because we count on this to be the first parameter
3042+ // added (if any).
3043+ size_t ParamIndex = Params.size ();
3044+ for (const ParmVarDecl *Param : InitMethod->parameters ()) {
3045+ QualType ParamTy = Param->getType ();
3046+ addParam (Param, ParamTy.getCanonicalType ());
3047+ // Propagate add_ir_attributes_kernel_parameter attribute.
3048+ if (const auto *AddIRAttr =
3049+ Param->getAttr <SYCLAddIRAttributesKernelParameterAttr>())
3050+ Params.back ()->addAttr (AddIRAttr->clone (SemaSYCLRef.getASTContext ()));
3051+ }
3052+ LastParamIndex = ParamIndex;
3053+ } else
3054+ unsupportedFreeFunctionParamType ();
3055+ return true ;
30153056 }
30163057
30173058 RecordDecl *wrapField (FieldDecl *Field, QualType FieldTy) {
@@ -3286,9 +3327,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
32863327 }
32873328
32883329 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3289- // TODO
3290- unsupportedFreeFunctionParamType ();
3291- return true ;
3330+ return handleSpecialType (ParamTy);
32923331 }
32933332
32943333 bool handleSyclSpecialType (const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -3601,6 +3640,8 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
36013640
36023641 BodyStmts.insert (BodyStmts.end (), FinalizeStmts.begin (),
36033642 FinalizeStmts.end ());
3643+ // CompoundStmt::Create(SemaSYCLRef.getASTContext(), BodyStmts,
3644+ // FPOptionsOverride(), {}, {})->dumpPretty(SemaSYCLRef.getASTContext());
36043645
36053646 return CompoundStmt::Create (SemaSYCLRef.getASTContext (), BodyStmts,
36063647 FPOptionsOverride (), {}, {});
@@ -4118,6 +4159,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
41184159 annotateHierarchicalParallelismAPICalls ();
41194160 CompoundStmt *KernelBody = createKernelBody ();
41204161 DeclCreator.setBody (KernelBody);
4162+ DeclCreator.getKernelDecl ()->dump ();
41214163 }
41224164
41234165 bool handleSyclSpecialType (FieldDecl *FD, QualType Ty) final {
@@ -4308,15 +4350,11 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
43084350 SourceLocation FreeFunctionSrcLoc; // Free function source location.
43094351 llvm::SmallVector<Expr *, 8 > ArgExprs;
43104352
4311- // Creates a DeclRefExpr to the ParmVar that represents the current free
4312- // function parameter.
4313- Expr *createParamReferenceExpr () {
4314- ParmVarDecl *FreeFunctionParameter =
4315- DeclCreator.getParamVarDeclsForCurrentField ()[0 ];
4316-
4317- QualType FreeFunctionParamType = FreeFunctionParameter->getOriginalType ();
4353+ // Creates a DeclRefExpr to the ParmVar given be PD
4354+ Expr *createParamReferenceExpr (ParmVarDecl *PD = nullptr ) {
4355+ ParmVarDecl *FreeFunctionParameter = PD ? PD : DeclCreator.getParamVarDeclsForCurrentField ()[0 ];
43184356 Expr *DRE = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4319- FreeFunctionParameter, FreeFunctionParamType , VK_LValue,
4357+ FreeFunctionParameter, FreeFunctionParameter-> getType () , VK_LValue,
43204358 FreeFunctionSrcLoc);
43214359 DRE = SemaSYCLRef.SemaRef .DefaultLvalueConversion (DRE).get ();
43224360 return DRE;
@@ -4412,8 +4450,52 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44124450 auto CallExpr = CallExpr::Create (Context, Fn, ArgExprs, ResultTy, VK,
44134451 FreeFunctionSrcLoc, FPOptionsOverride ());
44144452 BodyStmts.push_back (CallExpr);
4415- return CompoundStmt::Create (Context, BodyStmts, FPOptionsOverride (), {},
4416- {});
4453+ // CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), {},
4454+ // {})->dumpPretty(SemaSYCLRef.getASTContext());
4455+ return CompoundStmt::Create (Context, BodyStmts, FPOptionsOverride (), {},
4456+ {});
4457+ }
4458+
4459+ MemberExpr *buildMemberExpr (Expr *Base, ValueDecl *Member) {
4460+ DeclAccessPair MemberDAP = DeclAccessPair::make (Member, AS_none);
4461+ MemberExpr *Result = SemaSYCLRef.SemaRef .BuildMemberExpr (
4462+ Base, /* IsArrow */ false , FreeFunctionSrcLoc, NestedNameSpecifierLoc (),
4463+ FreeFunctionSrcLoc, Member, MemberDAP,
4464+ /* HadMultipleCandidates*/ false ,
4465+ DeclarationNameInfo (Member->getDeclName (), FreeFunctionSrcLoc),
4466+ Member->getType (), VK_LValue, OK_Ordinary);
4467+ return Result;
4468+ }
4469+
4470+ void createSpecialMethodCall (const CXXRecordDecl *RD, StringRef MethodName,
4471+ Expr *MemberBaseExpr,
4472+ SmallVectorImpl<Stmt *> &AddTo) {
4473+ CXXMethodDecl *Method = getMethodByName (RD, MethodName);
4474+ if (!Method)
4475+ return ;
4476+
4477+ unsigned NumParams = Method->getNumParams ();
4478+ llvm::SmallVector<Expr *, 4 > ParamDREs (NumParams);
4479+ llvm::ArrayRef<ParmVarDecl *> KernelParameters =
4480+ DeclCreator.getParamVarDeclsForCurrentField ();
4481+ for (size_t I = 0 ; I < NumParams; ++I) {
4482+ QualType ParamType = KernelParameters[I]->getOriginalType ();
4483+ ParamDREs[I] = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4484+ KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc);
4485+ }
4486+
4487+ MemberExpr *MethodME = buildMemberExpr (MemberBaseExpr, Method);
4488+
4489+ QualType ResultTy = Method->getReturnType ();
4490+ ExprValueKind VK = Expr::getValueKindForType (ResultTy);
4491+ ResultTy = ResultTy.getNonLValueExprType (SemaSYCLRef.getASTContext ());
4492+ llvm::SmallVector<Expr *, 4 > ParamStmts;
4493+ const auto *Proto = cast<FunctionProtoType>(Method->getType ());
4494+ SemaSYCLRef.SemaRef .GatherArgumentsForCall (FreeFunctionSrcLoc, Method,
4495+ Proto, 0 , ParamDREs, ParamStmts);
4496+ AddTo.push_back (CXXMemberCallExpr::Create (
4497+ SemaSYCLRef.getASTContext (), MethodME, ParamStmts, ResultTy, VK,
4498+ FreeFunctionSrcLoc, FPOptionsOverride ()));
44174499 }
44184500
44194501public:
@@ -4427,6 +4509,7 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44274509 ~FreeFunctionKernelBodyCreator () {
44284510 CompoundStmt *KernelBody = createFreeFunctionKernelBody ();
44294511 DeclCreator.setBody (KernelBody);
4512+ DeclCreator.getKernelDecl ()->dump ();
44304513 }
44314514
44324515 bool handleSyclSpecialType (FieldDecl *FD, QualType Ty) final {
@@ -4435,9 +4518,41 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44354518 return true ;
44364519 }
44374520
4438- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
4521+ bool handleSyclSpecialType (ParmVarDecl *PD , QualType ParamTy ) final {
44394522 // TODO
4440- unsupportedFreeFunctionParamType ();
4523+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
4524+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
4525+ QualType Ty = PD->getOriginalType ();
4526+ ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4527+ VarDecl *WorkGroupMemoryClone = VarDecl::Create (
4528+ Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc, FreeFunctionSrcLoc, PD->getIdentifier (),
4529+ PD->getType (),
4530+ Ctx.getTrivialTypeSourceInfo (Ty), SC_None);
4531+ InitializedEntity VarEntity =
4532+ InitializedEntity::InitializeVariable (WorkGroupMemoryClone);
4533+ InitializationKind InitKind =
4534+ InitializationKind::CreateDefault (FreeFunctionSrcLoc);
4535+ InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
4536+ std::nullopt );
4537+ ExprResult Init = InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity,
4538+ InitKind, std::nullopt );
4539+ WorkGroupMemoryClone->setInit (
4540+ SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4541+ WorkGroupMemoryClone->setInitStyle (VarDecl::CallInit);
4542+ Stmt *DS = new (SemaSYCLRef.getASTContext ())
4543+ DeclStmt (DeclGroupRef (WorkGroupMemoryClone), FreeFunctionSrcLoc,
4544+ FreeFunctionSrcLoc);
4545+ BodyStmts.push_back (DS);
4546+ Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4547+ WorkGroupMemoryClone, Ty, VK_LValue, FreeFunctionSrcLoc);
4548+ // createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr,
4549+ // BodyStmts);
4550+ Expr *RvalueMemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4551+ WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4552+
4553+ ArgExprs.push_back (RvalueMemberBaseExpr);
4554+ } else
4555+ unsupportedFreeFunctionParamType ();
44414556 return true ;
44424557 }
44434558
@@ -4717,9 +4832,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
47174832 return true ;
47184833 }
47194834
4720- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
4835+ bool handleSyclSpecialType (ParmVarDecl *PD , QualType ParamTy ) final {
47214836 // TODO
4722- unsupportedFreeFunctionParamType ();
4837+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
4838+ addParam (PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4839+ else unsupportedFreeFunctionParamType ();
47234840 return true ;
47244841 }
47254842
@@ -4907,7 +5024,6 @@ class SYCLKernelNameTypeVisitor
49075024 void Visit (QualType T) {
49085025 if (T.isNull ())
49095026 return ;
4910-
49115027 const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
49125028 // If KernelNameType has template args visit each template arg via
49135029 // ConstTemplateArgumentVisitor
@@ -5566,11 +5682,9 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
55665682 KernelObjVisitor Visitor{*this };
55675683
55685684 DiagnosingSYCLKernel = true ;
5569-
55705685 // Check parameters of free function.
55715686 Visitor.VisitFunctionParameters (FD, DecompMarker, FieldChecker,
55725687 UnionChecker);
5573-
55745688 DiagnosingSYCLKernel = false ;
55755689
55765690 // Ignore the free function if any of the checkers fail validation.
@@ -5911,7 +6025,7 @@ class SYCLFwdDeclEmitter
59116025 void Visit (QualType T) {
59126026 if (T.isNull ())
59136027 return ;
5914- InnerTypeVisitor::Visit (T.getTypePtr ());
6028+ InnerTypeVisitor::Visit (T.getTypePtr ());
59156029 }
59166030
59176031 void VisitReferenceType (const ReferenceType *RT) {
@@ -5935,7 +6049,7 @@ class SYCLFwdDeclEmitter
59356049 }
59366050
59376051 void VisitTagType (const TagType *T) {
5938- TagDecl *TD = T->getDecl ();
6052+ TagDecl *TD = T->getDecl ();
59396053 if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(TD)) {
59406054 // - first, recurse into template parameters and emit needed forward
59416055 // declarations
@@ -6015,13 +6129,13 @@ class SYCLFwdDeclEmitter
60156129 VisitTemplateArgs (TA.getPackAsArray ());
60166130 }
60176131
6018- void VisitFunctionProtoType (const FunctionProtoType *T) {
6019- for (const auto Ty : T->getParamTypes ())
6132+ void VisitFunctionProtoType (const FunctionProtoType *T) {
6133+ for (const auto Ty : T->getParamTypes ())
60206134 Visit (Ty.getCanonicalType ());
60216135 // So far this visitor method is only used for free function kernels whose
60226136 // return type is void anyway, so it is not visited. Otherwise, add if
60236137 // required.
6024- }
6138+ }
60256139};
60266140
60276141class SYCLKernelNameTypePrinter
@@ -6196,6 +6310,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
61966310 O << " #include <sycl/detail/defines_elementary.hpp>\n " ;
61976311 O << " #include <sycl/detail/kernel_desc.hpp>\n " ;
61986312 O << " #include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n " ;
6313+ O << " #include <sycl/ext/oneapi/properties/properties.hpp>\n " ;
6314+ O << " namespace sycl { inline namespace _V1 { namespace ext { namespace "
6315+ " oneapi { namespace experimental {\n "
6316+ " template<typename DataT, typename PropertiesT = "
6317+ " properties<std::tuple<>>> class work_group_memory;\n "
6318+ " }\n "
6319+ " }\n "
6320+ " }\n "
6321+ " }\n "
6322+ " }\n " ;
61996323
62006324 O << " \n " ;
62016325
@@ -6465,8 +6589,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
64656589 O << " \n " ;
64666590 O << " // Forward declarations of kernel and its argument types:\n " ;
64676591 FwdDeclEmitter.Visit (K.SyclKernel ->getType ());
6468- O << " \n " ;
6469-
6592+ O << " \n " ;
64706593 if (K.SyclKernel ->getLanguageLinkage () == CLanguageLinkage)
64716594 O << " extern \" C\" " ;
64726595 std::string ParmList;
0 commit comments