@@ -1522,7 +1522,7 @@ class KernelObjVisitor {
15221522  void  visitParam (ParmVarDecl *Param, QualType ParamTy,
15231523                  HandlerTys &...Handlers) {
15241524    if  (isSyclSpecialType (ParamTy, SemaSYCLRef))
1525-       KP_FOR_EACH (handleOtherType , Param, ParamTy);
1525+       KP_FOR_EACH (handleSyclSpecialType , Param, ParamTy);
15261526    else  if  (ParamTy->isStructureOrClassType ()) {
15271527      if  (KP_FOR_EACH (handleStructType, Param, ParamTy)) {
15281528        CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl ();
@@ -2075,8 +2075,11 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20752075  }
20762076
20772077  bool  handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final  {
2078-     Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type) << ParamTy;
2079-     IsInvalid = true ;
2078+     if  (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
2079+       Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2080+           << ParamTy;
2081+       IsInvalid = true ;
2082+     }
20802083    return  isValid ();
20812084  }
20822085
@@ -2228,8 +2231,8 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22282231  }
22292232
22302233  bool  handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final  {
2231-     //  TODO 
2232-     unsupportedFreeFunctionParamType ();
2234+     if  (! SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) 
2235+        unsupportedFreeFunctionParamType ();  //  TODO 
22332236    return  true ;
22342237  }
22352238
@@ -3013,9 +3016,26 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30133016    return  handleSpecialType (FD, FieldTy);
30143017  }
30153018
3016-   bool  handleSyclSpecialType (ParmVarDecl *, QualType) final  {
3017-     //  TODO
3018-     unsupportedFreeFunctionParamType ();
3019+   bool  handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final  {
3020+     if  (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
3021+       const  auto  *RecordDecl = ParamTy->getAsCXXRecordDecl ();
3022+       assert (RecordDecl && " The type must be a RecordDecl"  );
3023+       CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
3024+       assert (InitMethod && " The type must have the __init method"  );
3025+       //  Don't do -1 here because we count on this to be the first parameter
3026+       //  added (if any).
3027+       size_t  ParamIndex = Params.size ();
3028+       for  (const  ParmVarDecl *Param : InitMethod->parameters ()) {
3029+         QualType ParamTy = Param->getType ();
3030+         addParam (Param, ParamTy.getCanonicalType ());
3031+         //  Propagate add_ir_attributes_kernel_parameter attribute.
3032+         if  (const  auto  *AddIRAttr =
3033+                 Param->getAttr <SYCLAddIRAttributesKernelParameterAttr>())
3034+           Params.back ()->addAttr (AddIRAttr->clone (SemaSYCLRef.getASTContext ()));
3035+       }
3036+       LastParamIndex = ParamIndex;
3037+     } else  //  TODO
3038+       unsupportedFreeFunctionParamType ();
30193039    return  true ;
30203040  }
30213041
@@ -3291,9 +3311,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
32913311  }
32923312
32933313  bool  handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final  {
3294-     //  TODO
3295-     unsupportedFreeFunctionParamType ();
3296-     return  true ;
3314+     return  handleSpecialType (ParamTy);
32973315  }
32983316
32993317  bool  handleSyclSpecialType (const  CXXRecordDecl *, const  CXXBaseSpecifier &BS,
@@ -4442,6 +4460,45 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44424460                                {});
44434461  }
44444462
4463+   MemberExpr *buildMemberExpr (Expr *Base, ValueDecl *Member) {
4464+     DeclAccessPair MemberDAP = DeclAccessPair::make (Member, AS_none);
4465+     MemberExpr *Result = SemaSYCLRef.SemaRef .BuildMemberExpr (
4466+         Base, /* IsArrow */   false , FreeFunctionSrcLoc, NestedNameSpecifierLoc (),
4467+         FreeFunctionSrcLoc, Member, MemberDAP,
4468+         /* HadMultipleCandidates*/   false ,
4469+         DeclarationNameInfo (Member->getDeclName (), FreeFunctionSrcLoc),
4470+         Member->getType (), VK_LValue, OK_Ordinary);
4471+     return  Result;
4472+   }
4473+ 
4474+   void  createSpecialMethodCall (const  CXXRecordDecl *RD, StringRef MethodName,
4475+                                Expr *MemberBaseExpr,
4476+                                SmallVectorImpl<Stmt *> &AddTo) {
4477+     CXXMethodDecl *Method = getMethodByName (RD, MethodName);
4478+     if  (!Method)
4479+       return ;
4480+     unsigned  NumParams = Method->getNumParams ();
4481+     llvm::SmallVector<Expr *, 4 > ParamDREs (NumParams);
4482+     llvm::ArrayRef<ParmVarDecl *> KernelParameters =
4483+         DeclCreator.getParamVarDeclsForCurrentField ();
4484+     for  (size_t  I = 0 ; I < NumParams; ++I) {
4485+       QualType ParamType = KernelParameters[I]->getOriginalType ();
4486+       ParamDREs[I] = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4487+           KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc);
4488+     }
4489+     MemberExpr *MethodME = buildMemberExpr (MemberBaseExpr, Method);
4490+     QualType ResultTy = Method->getReturnType ();
4491+     ExprValueKind VK = Expr::getValueKindForType (ResultTy);
4492+     ResultTy = ResultTy.getNonLValueExprType (SemaSYCLRef.getASTContext ());
4493+     llvm::SmallVector<Expr *, 4 > ParamStmts;
4494+     const  auto  *Proto = cast<FunctionProtoType>(Method->getType ());
4495+     SemaSYCLRef.SemaRef .GatherArgumentsForCall (FreeFunctionSrcLoc, Method,
4496+                                                Proto, 0 , ParamDREs, ParamStmts);
4497+     AddTo.push_back (CXXMemberCallExpr::Create (
4498+         SemaSYCLRef.getASTContext (), MethodME, ParamStmts, ResultTy, VK,
4499+         FreeFunctionSrcLoc, FPOptionsOverride ()));
4500+   }
4501+ 
44454502public: 
44464503  static  constexpr  const  bool  VisitInsideSimpleContainers = false ;
44474504
@@ -4461,9 +4518,53 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44614518    return  true ;
44624519  }
44634520
4464-   bool  handleSyclSpecialType (ParmVarDecl *, QualType) final  {
4465-     //  TODO
4466-     unsupportedFreeFunctionParamType ();
4521+   //  Default inits the type, then calls the init-method in the body.
4522+   //  A type may not have a public default constructor as per its spec so
4523+   //  typically if this is the case the default constructor will be private and
4524+   //  in such cases we must manually override the access specifier from private
4525+   //  to public just for the duration of this default initialization.
4526+   //  TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
4527+   //  is closed.
4528+   bool  handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final  {
4529+     if  (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
4530+       const  auto  *RecordDecl = ParamTy->getAsCXXRecordDecl ();
4531+       AccessSpecifier DefaultConstructorAccess;
4532+       auto  DefaultConstructor =
4533+           std::find_if (RecordDecl->ctor_begin (), RecordDecl->ctor_end (),
4534+                        [](auto  it) { return  it->isDefaultConstructor (); });
4535+       DefaultConstructorAccess = DefaultConstructor->getAccess ();
4536+       DefaultConstructor->setAccess (AS_public);
4537+ 
4538+       QualType Ty = PD->getOriginalType ();
4539+       ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4540+       VarDecl *WorkGroupMemoryClone = VarDecl::Create (
4541+           Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc,
4542+           FreeFunctionSrcLoc, PD->getIdentifier (), PD->getType (),
4543+           Ctx.getTrivialTypeSourceInfo (Ty), SC_None);
4544+       InitializedEntity VarEntity =
4545+           InitializedEntity::InitializeVariable (WorkGroupMemoryClone);
4546+       InitializationKind InitKind =
4547+           InitializationKind::CreateDefault (FreeFunctionSrcLoc);
4548+       InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
4549+                                      std::nullopt );
4550+       ExprResult Init = InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity,
4551+                                         InitKind, std::nullopt );
4552+       WorkGroupMemoryClone->setInit (
4553+           SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4554+       WorkGroupMemoryClone->setInitStyle (VarDecl::CallInit);
4555+       DefaultConstructor->setAccess (DefaultConstructorAccess);
4556+ 
4557+       Stmt *DS = new  (SemaSYCLRef.getASTContext ())
4558+           DeclStmt (DeclGroupRef (WorkGroupMemoryClone), FreeFunctionSrcLoc,
4559+                    FreeFunctionSrcLoc);
4560+       BodyStmts.push_back (DS);
4561+       Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4562+           WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4563+       createSpecialMethodCall (RecordDecl, InitMethodName, MemberBaseExpr,
4564+                               BodyStmts);
4565+       ArgExprs.push_back (MemberBaseExpr);
4566+     } else  //  TODO
4567+       unsupportedFreeFunctionParamType ();
44674568    return  true ;
44684569  }
44694570
@@ -4748,9 +4849,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
47484849    return  true ;
47494850  }
47504851
4751-   bool  handleSyclSpecialType (ParmVarDecl *, QualType) final  {
4752-     //  TODO
4753-     unsupportedFreeFunctionParamType ();
4852+   bool  handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final  {
4853+     if  (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
4854+       addParam (PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4855+     else 
4856+       unsupportedFreeFunctionParamType (); //  TODO
47544857    return  true ;
47554858  }
47564859
@@ -6227,7 +6330,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
62276330  O << " #include <sycl/detail/defines_elementary.hpp>\n "  ;
62286331  O << " #include <sycl/detail/kernel_desc.hpp>\n "  ;
62296332  O << " #include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n "  ;
6230- 
62316333  O << " \n "  ;
62326334
62336335  LangOptions LO;
@@ -6502,6 +6604,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65026604
65036605    O << " \n "  ;
65046606    O << " // Forward declarations of kernel and its argument types:\n "  ;
6607+     Policy.SuppressDefaultTemplateArgs  = false ;
65056608    FwdDeclEmitter.Visit (K.SyclKernel ->getType ());
65066609    O << " \n "  ;
65076610
@@ -6579,6 +6682,8 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65796682    }
65806683    O << " ;\n "  ;
65816684    O << " }\n "  ;
6685+     Policy.SuppressDefaultTemplateArgs  = true ;
6686+     Policy.EnforceDefaultTemplateArgs  = false ;
65826687
65836688    //  Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
65846689    O << " namespace sycl {\n "  ;
0 commit comments