@@ -2084,12 +2084,8 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20842084 }
20852085
20862086 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2087- if (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
2088- Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2089- << ParamTy;
2090- IsInvalid = true ;
2091- }
2092- return isValid ();
2087+ IsInvalid |= checkSyclSpecialType (ParamTy, PD->getLocation ());
2088+ return isValid ();
20932089 }
20942090
20952091 bool handleArrayType (FieldDecl *FD, QualType FieldTy) final {
@@ -2240,9 +2236,7 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22402236 }
22412237
22422238 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2243- if (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
2244- unsupportedFreeFunctionParamType (); // TODO
2245- return true ;
2239+ return checkType (PD->getLocation (), ParamTy);
22462240 }
22472241
22482242 bool handleSyclSpecialType (const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -3026,7 +3020,6 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30263020 }
30273021
30283022 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3029- if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
30303023 const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
30313024 assert (RecordDecl && " The type must be a RecordDecl" );
30323025 CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
@@ -3043,8 +3036,6 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30433036 Params.back ()->addAttr (AddIRAttr->clone (SemaSYCLRef.getASTContext ()));
30443037 }
30453038 LastParamIndex = ParamIndex;
3046- } else // TODO
3047- unsupportedFreeFunctionParamType ();
30483039 return true ;
30493040 }
30503041
@@ -4538,7 +4529,6 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
45384529 // TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
45394530 // is closed.
45404531 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4541- if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
45424532 const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
45434533 AccessSpecifier DefaultConstructorAccess;
45444534 auto DefaultConstructor =
@@ -4549,34 +4539,32 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
45494539
45504540 QualType Ty = PD->getOriginalType ();
45514541 ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4552- VarDecl *WorkGroupMemoryClone = VarDecl::Create (
4542+ VarDecl *SpecialObjectClone = VarDecl::Create (
45534543 Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc,
45544544 FreeFunctionSrcLoc, PD->getIdentifier (), PD->getType (),
45554545 Ctx.getTrivialTypeSourceInfo (Ty), SC_None);
45564546 InitializedEntity VarEntity =
4557- InitializedEntity::InitializeVariable (WorkGroupMemoryClone );
4547+ InitializedEntity::InitializeVariable (SpecialObjectClone );
45584548 InitializationKind InitKind =
45594549 InitializationKind::CreateDefault (FreeFunctionSrcLoc);
45604550 InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
45614551 std::nullopt );
45624552 ExprResult Init = InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity,
45634553 InitKind, std::nullopt );
4564- WorkGroupMemoryClone ->setInit (
4554+ SpecialObjectClone ->setInit (
45654555 SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4566- WorkGroupMemoryClone ->setInitStyle (VarDecl::CallInit);
4556+ SpecialObjectClone ->setInitStyle (VarDecl::CallInit);
45674557 DefaultConstructor->setAccess (DefaultConstructorAccess);
45684558
45694559 Stmt *DS = new (SemaSYCLRef.getASTContext ())
4570- DeclStmt (DeclGroupRef (WorkGroupMemoryClone ), FreeFunctionSrcLoc,
4560+ DeclStmt (DeclGroupRef (SpecialObjectClone ), FreeFunctionSrcLoc,
45714561 FreeFunctionSrcLoc);
45724562 BodyStmts.push_back (DS);
45734563 Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4574- WorkGroupMemoryClone , Ty, VK_PRValue, FreeFunctionSrcLoc);
4564+ SpecialObjectClone , Ty, VK_PRValue, FreeFunctionSrcLoc);
45754565 createSpecialMethodCall (RecordDecl, InitMethodName, MemberBaseExpr,
45764566 BodyStmts);
45774567 ArgExprs.push_back (MemberBaseExpr);
4578- } else // TODO
4579- unsupportedFreeFunctionParamType ();
45804568 return true ;
45814569 }
45824570
@@ -4862,10 +4850,41 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48624850 }
48634851
48644852 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4865- if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
4853+ const auto *ClassTy = ParamTy->getAsCXXRecordDecl ();
4854+ assert (ClassTy && " Type must be a C++ record type" );
4855+ if (isSyclAccessorType (ParamTy)) {
4856+ const auto *AccTy =
4857+ cast<ClassTemplateSpecializationDecl>(ParamTy->getAsRecordDecl ());
4858+ assert (AccTy->getTemplateArgs ().size () >= 2 &&
4859+ " Incorrect template args for Accessor Type" );
4860+ int Dims = static_cast <int >(
4861+ AccTy->getTemplateArgs ()[1 ].getAsIntegral ().getExtValue ());
4862+ int Info = getAccessTarget (ParamTy, AccTy) | (Dims << 11 );
4863+
4864+ Header.addParamDesc (SYCLIntegrationHeader::kind_accessor, Info,
4865+ CurOffset);
4866+ } else if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::stream)) {
4867+ addParam (PD, ParamTy, SYCLIntegrationHeader::kind_stream);
4868+ } else if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
48664869 addParam (PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4867- else
4868- unsupportedFreeFunctionParamType (); // TODO
4870+ } else if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::sampler) ||
4871+ SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::annotated_ptr) ||
4872+ SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::annotated_arg)) {
4873+ CXXMethodDecl *InitMethod = getMethodByName (ClassTy, InitMethodName);
4874+ assert (InitMethod && " type must have __init method" );
4875+ const ParmVarDecl *InitArg = InitMethod->getParamDecl (0 );
4876+ assert (InitArg && " Init method must have arguments" );
4877+ QualType T = InitArg->getType ();
4878+ SYCLIntegrationHeader::kernel_param_kind_t ParamKind =
4879+ SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::sampler)
4880+ ? SYCLIntegrationHeader::kind_sampler
4881+ : (T->isPointerType () ? SYCLIntegrationHeader::kind_pointer
4882+ : SYCLIntegrationHeader::kind_std_layout);
4883+ addParam (PD, ParamTy, ParamKind);
4884+ } else {
4885+ llvm_unreachable (
4886+ " Unexpected SYCL special class when generating integration header" );
4887+ }
48694888 return true ;
48704889 }
48714890
0 commit comments